 import datasets
 import pytest
 import yaml
-
+import sys
+import os
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))
 # First Party
 from tests.artifacts.predefined_data_configs import (
     DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
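Note on the new `sys.path` lines above: appending the directory two levels up from the test file puts the repo root on the import path, so the `tests.artifacts.predefined_data_configs` import resolves no matter which directory pytest is launched from. A minimal standalone illustration, using a hypothetical file location:

```python
# Hypothetical test-file path, only to show what '../../' resolves to.
import os

test_file = "/repo/tests/data/test_data_preprocessing_utils.py"  # assumed location
repo_root = os.path.abspath(os.path.join(os.path.dirname(test_file), "../../"))
print(repo_root)  # -> /repo, which makes the top-level `tests` package importable
```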
@@ -498,7 +500,7 @@ def test_process_dataconfig_file(data_config_path, data_path):
 
 
 @pytest.mark.parametrize(
-    "data_config_path, list_data_path",
+    "data_config_path, data_path_list",
     [
         (
             DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
@@ -571,11 +573,11 @@ def test_process_dataconfig_file(data_config_path, data_path):
         ),
     ],
 )
-def test_process_dataconfig_multiple_files(data_config_path, list_data_path):
+def test_process_dataconfig_multiple_files(data_config_path, data_path_list):
     """Ensure that datasets with multiple files are formatted and validated correctly based on the arguments passed in config file."""
     with open(data_config_path, "r") as f:
         yaml_content = yaml.safe_load(f)
-    yaml_content["datasets"][0]["data_paths"] = list_data_path
+    yaml_content["datasets"][0]["data_paths"] = data_path_list
     datasets_name = yaml_content["datasets"][0]["name"]
 
     # Modify input_field_name and output_field_name according to dataset
@@ -635,7 +637,7 @@ def test_process_dataconfig_multiple_files(data_config_path, list_data_path):
         ),
     ],
 )
-def test_process_dataconfig_multiple_datasets_datafiles(datafiles, datasetconfigname):
+def test_process_dataconfig_multiple_datasets_datafiles_sampling(datafiles, datasetconfigname):
     """Ensure that multiple datasets with multiple files are formatted and validated correctly."""
     with open(datasetconfigname, "r") as f:
         yaml_content = yaml.safe_load(f)
@@ -651,14 +653,26 @@ def test_process_dataconfig_multiple_datasets_datafiles(datafiles, datasetconfigname):
     data_args = configs.DataArguments(data_config_path=temp_yaml_file_path)
 
     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
-    (train_set, _, _) = _process_dataconfig_file(data_args, tokenizer)
+    TRAIN_ARGS = configs.TrainingArguments(
+        packing=False,
+        max_seq_length=1024,
+        output_dir="tmp",
+    )
+    (train_set, eval_set, _, _, _, _) = process_dataargs(
+        data_args=data_args, tokenizer=tokenizer, train_args=TRAIN_ARGS
+    )
+
     assert isinstance(train_set, Dataset)
-    column_names = set(["input_ids", "attention_mask", "labels"])
-    assert set(train_set.column_names) == column_names
+    if eval_set:
+        assert isinstance(eval_set, Dataset)
+
+    assert set(["input_ids", "attention_mask", "labels"]).issubset(set(train_set.column_names))
+    if eval_set:
+        assert set(["input_ids", "attention_mask", "labels"]).issubset(set(eval_set.column_names))
 
 
 @pytest.mark.parametrize(
-    "data_config_path, list_data_path",
+    "data_config_path, data_path_list",
     [
         (
             DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
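The hunk above swaps the private `_process_dataconfig_file` helper for the public `process_dataargs` entry point, which additionally requires `TrainingArguments` and returns an eval split that may be `None`. Because the processed dataset can carry columns beyond the tokenizer outputs, the strict column-set equality is also relaxed to a subset check. A self-contained sketch of that guarded pattern on a toy dataset (not the project's real output):

```python
from datasets import Dataset

# Toy stand-ins; process_dataargs may return eval_set=None when the
# data config defines no validation split.
train_set = Dataset.from_dict(
    {"input_ids": [[1, 2]], "attention_mask": [[1, 1]], "labels": [[1, 2]], "extra": [0]}
)
eval_set = None

required = {"input_ids", "attention_mask", "labels"}
assert required.issubset(train_set.column_names)  # extra columns are tolerated
if eval_set:  # eval assertions run only when the split exists
    assert required.issubset(eval_set.column_names)
```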
@@ -682,12 +696,12 @@ def test_process_dataconfig_multiple_datasets_datafiles(datafiles, datasetconfigname):
     ],
 )
 def test_process_dataconfig_multiple_files_varied_data_formats(
-    data_config_path, list_data_path
+    data_config_path, data_path_list
 ):
     """Ensure that datasets with multiple files with different formats raise assertion error when passed in config file."""
     with open(data_config_path, "r") as f:
         yaml_content = yaml.safe_load(f)
-    yaml_content["datasets"][0]["data_paths"] = list_data_path
+    yaml_content["datasets"][0]["data_paths"] = data_path_list
     datasets_name = yaml_content["datasets"][0]["name"]
 
     # Modify input_field_name and output_field_name according to dataset
@@ -719,7 +733,7 @@ def test_process_dataconfig_multiple_files_varied_data_formats(
 
 
 @pytest.mark.parametrize(
-    "data_config_path, list_data_path",
+    "data_config_path, data_path_list",
     [
         (
             DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
@@ -739,12 +753,12 @@ def test_process_dataconfig_multiple_files_varied_data_formats(
     ],
 )
 def test_process_dataconfig_multiple_files_varied_types(
-    data_config_path, list_data_path
+    data_config_path, data_path_list
 ):
     """Ensure that datasets with multiple files with different formats raise assertion error when passed in config file."""
     with open(data_config_path, "r") as f:
         yaml_content = yaml.safe_load(f)
-    yaml_content["datasets"][0]["data_paths"] = list_data_path
+    yaml_content["datasets"][0]["data_paths"] = data_path_list
     datasets_name = yaml_content["datasets"][0]["name"]
 
     # Modify input_field_name and output_field_name according to dataset
@@ -1048,51 +1062,3 @@ def test_process_dataset_configs_with_sampling_error(
     (_, _, _, _, _, _) = process_dataargs(
         data_args=data_args, tokenizer=tokenizer, train_args=TRAIN_ARGS
     )
-
-
-@pytest.mark.parametrize(
-    "datafiles, datasetconfigname",
-    [
-        (
-            [
-                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW,
-                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
-                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
-            ],
-            DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
-        ),
-    ],
-)
-def test_process_dataset_configs_with_sampling(datafiles, datasetconfigname):
-
-    data_args = configs.DataArguments()
-    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
-    TRAIN_ARGS = configs.TrainingArguments(
-        packing=False,
-        max_seq_length=1024,
-        output_dir="tmp",  # Not needed but positional
-    )
-
-    with tempfile.NamedTemporaryFile(
-        "w", delete=False, suffix=".yaml"
-    ) as temp_yaml_file:
-        with open(datasetconfigname, "r") as f:
-            data = yaml.safe_load(f)
-        datasets = data["datasets"]
-        for i in range(len(datasets)):
-            d = datasets[i]
-            d["data_paths"][0] = datafiles[i]
-        yaml.dump(data, temp_yaml_file)
-        data_args.data_config_path = temp_yaml_file.name
-
-    (train_set, eval_set, _, _, _, _) = process_dataargs(
-        data_args=data_args, tokenizer=tokenizer, train_args=TRAIN_ARGS
-    )
-
-    assert isinstance(train_set, Dataset)
-    if eval_set:
-        assert isinstance(eval_set, Dataset)
-
-    assert set(["input_ids", "labels"]).issubset(set(train_set.column_names))
-    if eval_set:
-        assert set(["input_ids", "labels"]).issubset(set(eval_set.column_names))
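The deleted `test_process_dataset_configs_with_sampling` appears to be consolidated into the renamed `test_process_dataconfig_multiple_datasets_datafiles_sampling` above, which now performs the same `process_dataargs` call and guarded column assertions. Its temp-YAML rewiring pattern, pointing each dataset entry of a config at a concrete fixture file before handing the path to `DataArguments`, still underpins the surviving test; a standalone sketch with hypothetical paths:

```python
import tempfile
import yaml

datafiles = ["a.arrow", "b.jsonl", "c.parquet"]  # hypothetical fixture paths
config = {"datasets": [{"data_paths": [None]} for _ in datafiles]}  # stand-in config

with tempfile.NamedTemporaryFile("w", delete=False, suffix=".yaml") as tmp:
    for entry, path in zip(config["datasets"], datafiles):
        entry["data_paths"][0] = path  # rewire each dataset to its fixture
    yaml.dump(config, tmp)

# tmp.name is now a ready-to-use value for DataArguments(data_config_path=...)
```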