43 | 43 | # Local |
44 | 44 | from tuning.config import configs |
45 | 45 | from tuning.data.data_config import DataPreProcessorConfig, DataSetConfig |
46 | | -from tuning.data.data_preprocessing_utils import ( |
47 | | - combine_sequence, |
48 | | - get_data_collator, |
49 | | - validate_data_args, |
50 | | -) |
51 | | -from tuning.data.data_processors import HFBasedDataPreProcessor, get_datapreprocessor |
| 46 | +from tuning.data.data_preprocessing_utils import get_data_collator |
| 47 | +from tuning.data.data_processors import DataPreProcessor, get_datapreprocessor |
52 | 48 | from tuning.data.setup_dataprocessor import ( |
53 | 49 | _process_dataconfig_file, |
54 | 50 | is_pretokenized_dataset, |
55 | 51 | process_dataargs, |
56 | 52 | ) |
57 | 53 |
58 | 54 |
59 | | -@pytest.mark.parametrize( |
60 | | - "input_element,output_element,expected_res", |
61 | | - [ |
62 | | - ("foo ", "bar", "foo bar"), |
63 | | - ("foo\n", "bar", "foo\nbar"), |
64 | | - ("foo\t", "bar", "foo\tbar"), |
65 | | - ("foo", "bar", "foo bar"), |
66 | | - ], |
67 | | -) |
68 | | -def test_combine_sequence(input_element, output_element, expected_res): |
69 | | - """Ensure that input / output elements are combined with correct whitespace handling.""" |
70 | | - comb_seq = combine_sequence(input_element, output_element) |
71 | | - assert isinstance(comb_seq, str) |
72 | | - assert comb_seq == expected_res |
73 | | - |
74 | | - |
75 | | -@pytest.mark.parametrize( |
76 | | - "input_element,output_element,expected_res", |
77 | | - [ |
78 | | - ("foo ", "bar", "foo bar"), |
79 | | - ("foo\n", "bar", "foo\nbar"), |
80 | | - ("foo\t", "bar", "foo\tbar"), |
81 | | - ("foo", "bar", "foo bar"), |
82 | | - ], |
83 | | -) |
84 | | -def test_combine_sequence_adds_eos(input_element, output_element, expected_res): |
85 | | - """Ensure that input / output elements are combined with correct whitespace handling.""" |
86 | | - tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) |
87 | | - comb_seq = combine_sequence(input_element, output_element, tokenizer.eos_token) |
88 | | - expected_res += tokenizer.eos_token |
89 | | - assert isinstance(comb_seq, str) |
90 | | - assert comb_seq == expected_res |
91 | | - |
92 | | - |
93 | 55 | @pytest.mark.parametrize( |
94 | 56 | "datafile, column_names", |
95 | 57 | [ |
@@ -222,7 +184,6 @@ def test_load_dataset_without_dataconfig_and_datafile(): |
222 | 184 | ) |
223 | 185 | def test_is_pretokenized_data(data, result): |
224 | 186 | """Ensure that is_pretokenized_dataset correctly identifies whether a dataset is pretokenized"""
225 | | - |
226 | 187 | assert is_pretokenized_dataset(data=data) == result |
227 | 188 |
228 | 189 |
@@ -361,43 +322,16 @@ def test_get_data_collator( |
361 | 322 | ), |
362 | 323 | ], |
363 | 324 | ) |
364 | | -def test_validate_args(data_args, packing): |
| 325 | +def test_process_data_args_throws_error_where_needed(data_args, packing): |
365 | 326 | """Ensure that respective errors are thrown for incorrect data arguments""" |
366 | 327 | with pytest.raises(ValueError): |
367 | | - is_traindata_tokenized = is_pretokenized_dataset(data_args.training_data_path) |
368 | | - is_evaldata_tokenized = is_pretokenized_dataset(data_args.validation_data_path) |
369 | | - validate_data_args( |
370 | | - data_args, packing, is_traindata_tokenized, is_evaldata_tokenized |
| 328 | + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) |
| 329 | + TRAIN_ARGS = configs.TrainingArguments( |
| 330 | + packing=packing, |
| 331 | + max_seq_length=1024, |
| 332 | + output_dir="tmp", # Not needed but positional |
371 | 333 | ) |
372 | | - |
373 | | - |
374 | | -@pytest.mark.parametrize( |
375 | | - "data_args, packing", |
376 | | - [ |
377 | | - # pretokenized train dataset and no validation dataset passed |
378 | | - ( |
379 | | - configs.DataArguments( |
380 | | - training_data_path=TWITTER_COMPLAINTS_TOKENIZED_JSONL, |
381 | | - ), |
382 | | - False, |
383 | | - ), |
384 | | - # pretokenized train and validation datasets |
385 | | - ( |
386 | | - configs.DataArguments( |
387 | | - training_data_path=TWITTER_COMPLAINTS_TOKENIZED_JSONL, |
388 | | - validation_data_path=TWITTER_COMPLAINTS_TOKENIZED_JSONL, |
389 | | - ), |
390 | | - False, |
391 | | - ), |
392 | | - ], |
393 | | -) |
394 | | -def test_validate_args_pretokenized(data_args, packing): |
395 | | - """Ensure that supported data args do not error out when passing pretokenized datasets""" |
396 | | - is_traindata_tokenized = is_pretokenized_dataset(data_args.training_data_path) |
397 | | - is_evaldata_tokenized = is_pretokenized_dataset(data_args.validation_data_path) |
398 | | - validate_data_args( |
399 | | - data_args, packing, is_traindata_tokenized, is_evaldata_tokenized |
400 | | - ) |
| 334 | + (_, _, _, _, _, _) = process_dataargs(data_args, tokenizer, TRAIN_ARGS) |
401 | 335 |
402 | 336 |
403 | 337 | @pytest.mark.parametrize( |
@@ -448,11 +382,7 @@ def test_process_dataconfig_file(data_config_path, data_path): |
448 | 382 | data_args = configs.DataArguments(data_config_path=temp_yaml_file_path) |
449 | 383 |
450 | 384 | tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) |
451 | | - packing = (False,) |
452 | | - max_seq_length = 1024 |
453 | | - (train_set, _, _, _, _, _) = _process_dataconfig_file( |
454 | | - data_args, tokenizer, packing, max_seq_length |
455 | | - ) |
| 385 | + (train_set, _, _) = _process_dataconfig_file(data_args, tokenizer) |
456 | 386 | assert isinstance(train_set, Dataset) |
457 | 387 | if datasets_name == "text_dataset_input_output_masking": |
458 | 388 | column_names = set(["input_ids", "attention_mask", "labels"]) |
@@ -625,7 +555,7 @@ def test_process_dataset_configs(datafile, column_names, datasetconfigname): |
625 | 555 | """Test process_dataset_configs for expected output.""" |
626 | 556 | dataprocessor_config = DataPreProcessorConfig() |
627 | 557 | tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) |
628 | | - processor = HFBasedDataPreProcessor( |
| 558 | + processor = DataPreProcessor( |
629 | 559 | processor_config=dataprocessor_config, |
630 | 560 | tokenizer=tokenizer, |
631 | 561 | ) |
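
A minimal usage sketch of the refactored entry point, based only on the call shapes visible in this diff: `process_dataargs` now takes the tokenizer and a `TrainingArguments` object and returns six values, and the pretokenized-dataset cases previously covered by the removed `test_validate_args_pretokenized` are exercised through it instead of `validate_data_args`. `MODEL_NAME` and `TWITTER_COMPLAINTS_TOKENIZED_JSONL` are the constants defined earlier in this test module; the unpacked names below are illustrative, not the library's documented return names.

from transformers import AutoTokenizer

from tuning.config import configs
from tuning.data.setup_dataprocessor import process_dataargs

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Pretokenized training data with packing disabled: a combination the removed
# validate_data_args test asserted is supported and should not raise.
data_args = configs.DataArguments(
    training_data_path=TWITTER_COMPLAINTS_TOKENIZED_JSONL,
)
train_args = configs.TrainingArguments(
    packing=False,
    max_seq_length=1024,
    output_dir="tmp",  # not used here, but the argument is required
)

# Six return values, matching the (_, _, _, _, _, _) unpacking shown in the diff.
(train_set, eval_set, text_field, collator, max_seq_len, extra_kwargs) = process_dataargs(
    data_args, tokenizer, train_args
)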