|
33 | 33 | DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, |
34 | 34 | DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML, |
35 | 35 | DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, |
| 36 | + DATA_CONFIG_RENAME_RETAIN_COLUMNS, |
36 | 37 | DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, |
37 | 38 | ) |
38 | 39 | from tests.artifacts.testdata import ( |
@@ -1365,3 +1366,57 @@ def test_process_dataset_configs_with_sampling_error( |
1365 | 1366 | (_, _, _, _, _, _) = process_dataargs( |
1366 | 1367 | data_args=data_args, tokenizer=tokenizer, train_args=TRAIN_ARGS |
1367 | 1368 | ) |
| 1369 | + |
| 1370 | + |
@pytest.mark.parametrize(
    "datafile, rename, retain, final, datasetconfigname",
    [
        # Case 1: rename only — renamed columns replace the originals,
        # untouched columns (ID, Label) are preserved.
        (
            TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
            {"input": "instruction", "output": "response"},
            None,
            ["ID", "Label", "instruction", "response"],
            DATA_CONFIG_RENAME_RETAIN_COLUMNS,
        ),
        # Case 2: retain only — every column not listed is dropped.
        (
            TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
            None,
            ["ID", "input", "output"],
            ["ID", "input", "output"],
            DATA_CONFIG_RENAME_RETAIN_COLUMNS,
        ),
        # Case 3: rename + retain — retain list is expressed in the
        # *post-rename* names, i.e. rename is applied before retain.
        (
            TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
            {"input": "instruction", "output": "response"},
            ["Label", "instruction", "response"],
            ["Label", "instruction", "response"],
            DATA_CONFIG_RENAME_RETAIN_COLUMNS,
        ),
    ],
)
def test_rename_and_retain_dataset_columns(
    datafile, rename, retain, final, datasetconfigname
):
    """Verify process_dataset_configs applies rename/retain column handlers.

    Builds a DataPreProcessor over a single JSON dataset config with the
    given ``rename_columns`` / ``retain_columns`` settings and checks that:
      * the processed result is a ``Dataset``,
      * its column names exactly match the expected ``final`` set,
      * no rows were dropped (row count equals the source file's records).

    Args:
        datafile: Path to the JSON test-data file to process.
        rename: Mapping of old -> new column names, or None to skip renaming.
        retain: List of column names to keep, or None to keep all columns.
        final: Expected column names after processing.
        datasetconfigname: Name assigned to the DataSetConfig under test.
    """
    dataprocessor_config = DataPreProcessorConfig()
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    processor = DataPreProcessor(
        processor_config=dataprocessor_config,
        tokenizer=tokenizer,
    )
    datasetconfig = [
        DataSetConfig(
            name=datasetconfigname,
            data_paths=[datafile],
            rename_columns=rename,
            retain_columns=retain,
        )
    ]
    train_dataset = processor.process_dataset_configs(dataset_configs=datasetconfig)

    assert isinstance(train_dataset, Dataset)
    # Compare as sets: column order is not part of the contract under test.
    assert set(train_dataset.column_names) == set(final)

    # Rename/retain must not drop rows — count against the raw source file.
    # Explicit encoding avoids locale-dependent decoding of the fixture (PEP 597).
    with open(datafile, "r", encoding="utf-8") as file:
        data = json.load(file)
    assert len(train_dataset) == len(data)
0 commit comments