-
Notifications
You must be signed in to change notification settings - Fork 65
test: Add unit tests to test multiple files in single dataset #412
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
ce82af1
3fe7425
e89002d
4ba1c04
3fce172
68a0f50
5905e23
6f13d9a
83d0127
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -432,9 +432,11 @@ def test_process_data_args_throws_error_where_needed(data_args, packing): | |
| (DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSON), | ||
| (DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSONL), | ||
| (DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_PARQUET), | ||
| (DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_ARROW), | ||
| (DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSON), | ||
| (DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSONL), | ||
| (DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_PARQUET), | ||
| (DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_ARROW), | ||
| ( | ||
| DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, | ||
| TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, | ||
|
|
@@ -447,6 +449,10 @@ def test_process_data_args_throws_error_where_needed(data_args, packing): | |
| DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, | ||
| TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET, | ||
| ), | ||
| ( | ||
| DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, | ||
| TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW, | ||
| ), | ||
| ], | ||
| ) | ||
| def test_process_dataconfig_file(data_config_path, data_path): | ||
|
|
@@ -491,6 +497,234 @@ def test_process_dataconfig_file(data_config_path, data_path): | |
| assert formatted_dataset_field in set(train_set.column_names) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| "data_config_path, list_data_path", | ||
|
||
| [ | ||
| ( | ||
| DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, | ||
| [TWITTER_COMPLAINTS_DATA_JSON, TWITTER_COMPLAINTS_DATA_JSON], | ||
| ), | ||
| ( | ||
| DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, | ||
| [TWITTER_COMPLAINTS_DATA_JSONL, TWITTER_COMPLAINTS_DATA_JSONL], | ||
| ), | ||
| ( | ||
| DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, | ||
| [TWITTER_COMPLAINTS_DATA_PARQUET, TWITTER_COMPLAINTS_DATA_PARQUET], | ||
| ), | ||
| ( | ||
| DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, | ||
| [TWITTER_COMPLAINTS_DATA_ARROW, TWITTER_COMPLAINTS_DATA_ARROW], | ||
| ), | ||
| ( | ||
| DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, | ||
| [TWITTER_COMPLAINTS_TOKENIZED_JSON, TWITTER_COMPLAINTS_TOKENIZED_JSON], | ||
| ), | ||
| ( | ||
| DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, | ||
| [TWITTER_COMPLAINTS_TOKENIZED_JSONL, TWITTER_COMPLAINTS_TOKENIZED_JSONL], | ||
| ), | ||
| ( | ||
| DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, | ||
| [ | ||
| TWITTER_COMPLAINTS_TOKENIZED_PARQUET, | ||
| TWITTER_COMPLAINTS_TOKENIZED_PARQUET, | ||
| ], | ||
| ), | ||
| ( | ||
| DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, | ||
| [TWITTER_COMPLAINTS_TOKENIZED_ARROW, TWITTER_COMPLAINTS_TOKENIZED_ARROW], | ||
| ), | ||
| ( | ||
| DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, | ||
| [ | ||
| TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, | ||
| TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, | ||
| ], | ||
| ), | ||
| ( | ||
| DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, | ||
| [ | ||
| TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL, | ||
| TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL, | ||
| ], | ||
| ), | ||
| ( | ||
| DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, | ||
| [ | ||
| TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET, | ||
| TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET, | ||
| ], | ||
| ), | ||
| ( | ||
| DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, | ||
| [ | ||
| TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW, | ||
| TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW, | ||
| ], | ||
| ), | ||
| ], | ||
| ) | ||
| def test_process_dataconfig_multiple_files(data_config_path, list_data_path): | ||
|
||
| """Ensure that datasets with multiple files are formatted and validated correctly based on the arguments passed in config file.""" | ||
| with open(data_config_path, "r") as f: | ||
| yaml_content = yaml.safe_load(f) | ||
| yaml_content["datasets"][0]["data_paths"] = list_data_path | ||
| datasets_name = yaml_content["datasets"][0]["name"] | ||
|
|
||
| # Modify input_field_name and output_field_name according to dataset | ||
| if datasets_name == "text_dataset_input_output_masking": | ||
| yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = { | ||
| "input_field_name": "input", | ||
| "output_field_name": "output", | ||
| } | ||
|
|
||
| # Modify dataset_text_field and template according to dataset | ||
| formatted_dataset_field = "formatted_data_field" | ||
| if datasets_name == "apply_custom_data_template": | ||
| template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}" | ||
| yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = { | ||
| "dataset_text_field": formatted_dataset_field, | ||
| "template": template, | ||
| } | ||
|
|
||
| with tempfile.NamedTemporaryFile( | ||
| "w", delete=False, suffix=".yaml" | ||
| ) as temp_yaml_file: | ||
| yaml.dump(yaml_content, temp_yaml_file) | ||
| temp_yaml_file_path = temp_yaml_file.name | ||
| data_args = configs.DataArguments(data_config_path=temp_yaml_file_path) | ||
|
|
||
| tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) | ||
| (train_set, _, _) = _process_dataconfig_file(data_args, tokenizer) | ||
| assert isinstance(train_set, Dataset) | ||
| if datasets_name == "text_dataset_input_output_masking": | ||
| column_names = set(["input_ids", "attention_mask", "labels"]) | ||
| assert set(train_set.column_names) == column_names | ||
| elif datasets_name == "pretokenized_dataset": | ||
| assert set(["input_ids", "labels"]).issubset(set(train_set.column_names)) | ||
| elif datasets_name == "apply_custom_data_template": | ||
| assert formatted_dataset_field in set(train_set.column_names) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| "data_config_path, list_data_path", | ||
| [ | ||
| ( | ||
| DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, | ||
| [TWITTER_COMPLAINTS_DATA_JSON, TWITTER_COMPLAINTS_DATA_PARQUET], | ||
| ), | ||
| ( | ||
| DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, | ||
| [ | ||
| TWITTER_COMPLAINTS_TOKENIZED_JSONL, | ||
| TWITTER_COMPLAINTS_TOKENIZED_ARROW, | ||
| TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET, | ||
| ], | ||
| ), | ||
| ( | ||
| DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, | ||
| [ | ||
| TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET, | ||
| TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL, | ||
| ], | ||
| ), | ||
| ], | ||
| ) | ||
| def test_process_dataconfig_multiple_files_varied_data_formats( | ||
| data_config_path, list_data_path | ||
| ): | ||
| """Ensure that datasets with multiple files with different formats raise assertion error when passed in config file.""" | ||
| with open(data_config_path, "r") as f: | ||
| yaml_content = yaml.safe_load(f) | ||
| yaml_content["datasets"][0]["data_paths"] = list_data_path | ||
| datasets_name = yaml_content["datasets"][0]["name"] | ||
|
|
||
| # Modify input_field_name and output_field_name according to dataset | ||
| if datasets_name == "text_dataset_input_output_masking": | ||
| yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = { | ||
| "input_field_name": "input", | ||
| "output_field_name": "output", | ||
| } | ||
|
|
||
| # Modify dataset_text_field and template according to dataset | ||
| formatted_dataset_field = "formatted_data_field" | ||
| if datasets_name == "apply_custom_data_template": | ||
| template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}" | ||
| yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = { | ||
| "dataset_text_field": formatted_dataset_field, | ||
| "template": template, | ||
| } | ||
|
|
||
| with tempfile.NamedTemporaryFile( | ||
| "w", delete=False, suffix=".yaml" | ||
| ) as temp_yaml_file: | ||
| yaml.dump(yaml_content, temp_yaml_file) | ||
| temp_yaml_file_path = temp_yaml_file.name | ||
| data_args = configs.DataArguments(data_config_path=temp_yaml_file_path) | ||
|
|
||
| tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) | ||
| with pytest.raises(AssertionError): | ||
| (_, _, _) = _process_dataconfig_file(data_args, tokenizer) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| "data_config_path, list_data_path", | ||
| [ | ||
| ( | ||
| DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, | ||
| [TWITTER_COMPLAINTS_DATA_JSON, TWITTER_COMPLAINTS_TOKENIZED_JSON], | ||
| ), | ||
| ( | ||
| DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, | ||
| [ | ||
| TWITTER_COMPLAINTS_TOKENIZED_JSON, | ||
| TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, | ||
| ], | ||
| ), | ||
| ( | ||
| DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, | ||
| [TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, TWITTER_COMPLAINTS_DATA_JSON], | ||
| ), | ||
| ], | ||
| ) | ||
| def test_process_dataconfig_multiple_files_varied_types( | ||
|
||
| data_config_path, list_data_path | ||
| ): | ||
| """Ensure that datasets with multiple files with different formats raise assertion error when passed in config file.""" | ||
| with open(data_config_path, "r") as f: | ||
| yaml_content = yaml.safe_load(f) | ||
| yaml_content["datasets"][0]["data_paths"] = list_data_path | ||
| datasets_name = yaml_content["datasets"][0]["name"] | ||
|
|
||
| # Modify input_field_name and output_field_name according to dataset | ||
| if datasets_name == "text_dataset_input_output_masking": | ||
| yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = { | ||
| "input_field_name": "input", | ||
| "output_field_name": "output", | ||
| } | ||
|
|
||
| # Modify dataset_text_field and template according to dataset | ||
| formatted_dataset_field = "formatted_data_field" | ||
| if datasets_name == "apply_custom_data_template": | ||
| template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}" | ||
| yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = { | ||
| "dataset_text_field": formatted_dataset_field, | ||
| "template": template, | ||
| } | ||
|
|
||
| with tempfile.NamedTemporaryFile( | ||
| "w", delete=False, suffix=".yaml" | ||
| ) as temp_yaml_file: | ||
| yaml.dump(yaml_content, temp_yaml_file) | ||
| temp_yaml_file_path = temp_yaml_file.name | ||
| data_args = configs.DataArguments(data_config_path=temp_yaml_file_path) | ||
|
|
||
| tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) | ||
| with pytest.raises(datasets.exceptions.DatasetGenerationCastError): | ||
| (_, _, _) = _process_dataconfig_file(data_args, tokenizer) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| "data_args", | ||
| [ | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for doing this segregation.