@@ -26,9 +26,10 @@
 
 # First Party
 from tests.artifacts.predefined_data_configs import (
-    APPLY_CUSTOM_TEMPLATE_YAML,
-    PRETOKENIZE_JSON_DATA_YAML,
-    TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
+    DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
+    DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
+    DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML,
+    DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
 )
 from tests.artifacts.testdata import (
     MODEL_NAME,
@@ -428,22 +429,22 @@ def test_process_data_args_throws_error_where_needed(data_args, packing): |
 @pytest.mark.parametrize(
     "data_config_path, data_path",
     [
-        (APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSON),
-        (APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSONL),
-        (APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_PARQUET),
-        (PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSON),
-        (PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSONL),
-        (PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_PARQUET),
-        (
-            TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
+        (DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSON),
+        (DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSONL),
+        (DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_PARQUET),
+        (DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSON),
+        (DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSONL),
+        (DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_PARQUET),
+        (
+            DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
             TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
         ),
         (
-            TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
+            DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
             TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
         ),
         (
-            TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
+            DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
             TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
         ),
     ],
@@ -709,3 +710,105 @@ def test_process_dataset_configs(datafile, column_names, datasetconfigname): |
     with open(datafile, "r") as file:
         data = json.load(file)
     assert len(train_dataset) == len(data)
+
+
+@pytest.mark.parametrize(
+    "datafiles, sampling, datasetconfigname",
+    [
+        (
+            [
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW,
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
+            ],
+            [0.3, None, 0.3],
+            DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
+        ),
+        (
+            [
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW,
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
+            ],
+            [0.3, 0.5, 0.3],
+            DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
+        ),
+    ],
+)
+def test_process_dataset_configs_with_sampling_error(
+    datafiles, sampling, datasetconfigname
+):
+
+    data_args = configs.DataArguments()
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    TRAIN_ARGS = configs.TrainingArguments(
+        packing=False,
+        max_seq_length=1024,
+        output_dir="tmp",  # Not needed but positional
+    )
+
+    with tempfile.NamedTemporaryFile(
+        "w", delete=False, suffix=".yaml"
+    ) as temp_yaml_file:
+        with open(datasetconfigname, "r") as f:
+            data = yaml.safe_load(f)
+        datasets = data["datasets"]
+        for i in range(len(datasets)):
+            d = datasets[i]
+            d["data_paths"][0] = datafiles[i]
+            d["sampling"] = sampling[i]
+        yaml.dump(data, temp_yaml_file)
+        data_args.data_config_path = temp_yaml_file.name
+
+    with pytest.raises(ValueError):
+        (_, _, _, _, _, _) = process_dataargs(
+            data_args=data_args, tokenizer=tokenizer, train_args=TRAIN_ARGS
+        )
+
+
+@pytest.mark.parametrize(
+    "datafiles, datasetconfigname",
+    [
+        (
+            [
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW,
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
+            ],
+            DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
+        ),
+    ],
+)
+def test_process_dataset_configs_with_sampling(datafiles, datasetconfigname):
+
+    data_args = configs.DataArguments()
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    TRAIN_ARGS = configs.TrainingArguments(
+        packing=False,
+        max_seq_length=1024,
+        output_dir="tmp",  # Not needed but positional
+    )
+
+    with tempfile.NamedTemporaryFile(
+        "w", delete=False, suffix=".yaml"
+    ) as temp_yaml_file:
+        with open(datasetconfigname, "r") as f:
+            data = yaml.safe_load(f)
+        datasets = data["datasets"]
+        for i in range(len(datasets)):
+            d = datasets[i]
+            d["data_paths"][0] = datafiles[i]
+        yaml.dump(data, temp_yaml_file)
+        data_args.data_config_path = temp_yaml_file.name
+
+    (train_set, eval_set, _, _, _, _) = process_dataargs(
+        data_args=data_args, tokenizer=tokenizer, train_args=TRAIN_ARGS
+    )
+
+    assert isinstance(train_set, Dataset)
+    if eval_set:
+        assert isinstance(eval_set, Dataset)
+
+    assert set(["input_ids", "labels"]).issubset(set(train_set.column_names))
+    if eval_set:
+        assert set(["input_ids", "labels"]).issubset(set(eval_set.column_names))
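
For context, a minimal sketch of the shape of config these tests rewrite. Only the "datasets", "data_paths", and "sampling" keys are grounded in the test code above; the dataset names, file paths, and any other fields of the real DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML artifact are assumptions.

# Hypothetical sketch, not the actual tests/artifacts config. The keys
# "datasets", "data_paths", and "sampling" mirror what the tests mutate;
# the names and paths below are invented for illustration.
import yaml

config = {
    "datasets": [
        {"name": "dataset_1", "data_paths": ["complaints.arrow"], "sampling": 0.3},
        {"name": "dataset_2", "data_paths": ["complaints.jsonl"], "sampling": 0.4},
        {"name": "dataset_3", "data_paths": ["complaints.parquet"], "sampling": 0.3},
    ]
}

# These ratios sum to 1.0. The error-path test above injects [0.3, None, 0.3]
# and [0.3, 0.5, 0.3], both of which process_dataargs rejects with ValueError.
print(yaml.dump(config, sort_keys=False))

Note that the success-path test only swaps in the three data files and leaves the artifact's own sampling values untouched, which is presumably why no ValueError is expected there.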