Skip to content

Commit e89002d

Browse files
committed
test: multiple datasets with multiple datafiles column names
Signed-off-by: Will Johnson <mwjohnson728@gmail.com>
1 parent 3fe7425 commit e89002d

File tree

1 file changed

+44
-0
lines changed

1 file changed

+44
-0
lines changed

tests/data/test_data_preprocessing_utils.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -607,6 +607,50 @@ def test_process_dataconfig_multiple_files(data_config_path, list_data_path):
607607
assert formatted_dataset_field in set(train_set.column_names)
608608

609609

610+
@pytest.mark.parametrize(
611+
"datafiles, datasetconfigname",
612+
[
613+
(
614+
[
615+
[
616+
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
617+
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
618+
],
619+
[
620+
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
621+
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
622+
],
623+
[
624+
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
625+
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
626+
],
627+
],
628+
DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
629+
),
630+
],
631+
)
632+
def test_process_dataconfig_multiple_datasets_datafiles(datafiles, datasetconfigname):
633+
"""Ensure that multiple datasets with multiple files are formatted and validated correctly."""
634+
with open(datasetconfigname, "r") as f:
635+
yaml_content = yaml.safe_load(f)
636+
yaml_content["datasets"][0]["data_paths"] = datafiles[0]
637+
yaml_content["datasets"][1]["data_paths"] = datafiles[1]
638+
yaml_content["datasets"][2]["data_paths"] = datafiles[2]
639+
640+
with tempfile.NamedTemporaryFile(
641+
"w", delete=False, suffix=".yaml"
642+
) as temp_yaml_file:
643+
yaml.dump(yaml_content, temp_yaml_file)
644+
temp_yaml_file_path = temp_yaml_file.name
645+
data_args = configs.DataArguments(data_config_path=temp_yaml_file_path)
646+
647+
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
648+
(train_set, _, _) = _process_dataconfig_file(data_args, tokenizer)
649+
assert isinstance(train_set, Dataset)
650+
column_names = set(["input_ids", "attention_mask", "labels"])
651+
assert set(train_set.column_names) == column_names
652+
653+
610654
@pytest.mark.parametrize(
611655
"data_config_path, list_data_path",
612656
[

0 commit comments

Comments
 (0)