@@ -607,6 +607,50 @@ def test_process_dataconfig_multiple_files(data_config_path, list_data_path):
     assert formatted_dataset_field in set(train_set.column_names)


+@pytest.mark.parametrize(
+    "datafiles, datasetconfigname",
+    [
+        (
+            [
+                [
+                    TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
+                    TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
+                ],
+                [
+                    TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
+                    TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
+                ],
+                [
+                    TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
+                    TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
+                ],
+            ],
+            DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
+        ),
+    ],
+)
+def test_process_dataconfig_multiple_datasets_datafiles(datafiles, datasetconfigname):
+    """Ensure that multiple datasets with multiple files are formatted and validated correctly."""
+    with open(datasetconfigname, "r") as f:
+        yaml_content = yaml.safe_load(f)
+    yaml_content["datasets"][0]["data_paths"] = datafiles[0]
+    yaml_content["datasets"][1]["data_paths"] = datafiles[1]
+    yaml_content["datasets"][2]["data_paths"] = datafiles[2]
+
+    with tempfile.NamedTemporaryFile(
+        "w", delete=False, suffix=".yaml"
+    ) as temp_yaml_file:
+        yaml.dump(yaml_content, temp_yaml_file)
+        temp_yaml_file_path = temp_yaml_file.name
+        data_args = configs.DataArguments(data_config_path=temp_yaml_file_path)
+
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    (train_set, _, _) = _process_dataconfig_file(data_args, tokenizer)
+    assert isinstance(train_set, Dataset)
+    column_names = set(["input_ids", "attention_mask", "labels"])
+    assert set(train_set.column_names) == column_names


 @pytest.mark.parametrize(
     "data_config_path, list_data_path",
     [
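Note on the new test: it only relies on the loaded data config exposing a datasets list with at least three entries, each carrying a data_paths key it can overwrite. The YAML behind DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML is not part of this diff, so the sketch below is only a guess at its shape; the entry names and sampling ratios are hypothetical, inferred from the file name.

# Hypothetical shape of the config the test loads and mutates.
# Only "datasets" and each entry's "data_paths" are exercised above;
# "name" and "sampling" are assumptions based on the YAML's file name.
yaml_content = {
    "datasets": [
        {"name": "dataset_1", "sampling": 0.3, "data_paths": []},
        {"name": "dataset_2", "sampling": 0.4, "data_paths": []},
        {"name": "dataset_3", "sampling": 0.3, "data_paths": []},
    ],
}

# The test then swaps in one list of files per dataset entry, e.g.:
datafiles = [
    ["a.parquet", "b.parquet"],
    ["a.json", "b.json"],
    ["a.jsonl", "b.jsonl"],
]
for entry, files in zip(yaml_content["datasets"], datafiles):
    entry["data_paths"] = files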