Commit 3fce172

PR Changes

Signed-off-by: Abhishek <[email protected]>
1 parent 4ba1c04 commit 3fce172

File tree

2 files changed (+29, -63 lines)

.pylintrc

Lines changed: 1 addition & 1 deletion
@@ -333,7 +333,7 @@ indent-string=' '
 max-line-length=100
 
 # Maximum number of lines in a module.
-max-module-lines=1400
+max-module-lines=1200
 
 # Allow the body of a class to be on the same line as the declaration if body
 # contains single statement.
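
For context: lowering max-module-lines means pylint's too-many-lines check now fires on any module over 1200 lines. A minimal sketch, run from the repository root, of how you might list modules that would trip the tightened limit (this script is not part of the commit):

import pathlib

MAX_MODULE_LINES = 1200  # mirrors the new .pylintrc value

# Print every Python module that would now trigger pylint's too-many-lines check.
for path in sorted(pathlib.Path(".").rglob("*.py")):
    n_lines = sum(1 for _ in path.open(encoding="utf-8", errors="ignore"))
    if n_lines > MAX_MODULE_LINES:
        print(f"{path}: {n_lines} lines")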

tests/data/test_data_preprocessing_utils.py

Lines changed: 28 additions & 62 deletions
@@ -23,7 +23,9 @@
 import datasets
 import pytest
 import yaml
-
+import sys
+import os
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))
 # First Party
 from tests.artifacts.predefined_data_configs import (
     DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
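
A note on the added import block: appending the repository root to sys.path lets the "from tests.artifacts..." import resolve even when the test module is executed directly rather than through pytest from the repo root. A minimal equivalent sketch, assuming the same tests/data/ layout:

import os
import sys

# Resolve the repository root two directories above this test file and make it
# importable, so "from tests.artifacts.predefined_data_configs import ..." works
# regardless of the working directory the tests are launched from.
REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../"))
if REPO_ROOT not in sys.path:
    sys.path.append(REPO_ROOT)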
@@ -498,7 +500,7 @@ def test_process_dataconfig_file(data_config_path, data_path):
 
 
 @pytest.mark.parametrize(
-    "data_config_path, list_data_path",
+    "data_config_path, data_path_list",
     [
         (
             DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
@@ -571,11 +573,11 @@ def test_process_dataconfig_file(data_config_path, data_path):
         ),
     ],
 )
-def test_process_dataconfig_multiple_files(data_config_path, list_data_path):
+def test_process_dataconfig_multiple_files(data_config_path, data_path_list):
     """Ensure that datasets with multiple files are formatted and validated correctly based on the arguments passed in config file."""
     with open(data_config_path, "r") as f:
         yaml_content = yaml.safe_load(f)
-    yaml_content["datasets"][0]["data_paths"] = list_data_path
+    yaml_content["datasets"][0]["data_paths"] = data_path_list
     datasets_name = yaml_content["datasets"][0]["name"]
 
     # Modify input_field_name and output_field_name according to dataset
@@ -635,7 +637,7 @@ def test_process_dataconfig_multiple_files(data_config_path, list_data_path):
         ),
     ],
 )
-def test_process_dataconfig_multiple_datasets_datafiles(datafiles, datasetconfigname):
+def test_process_dataconfig_multiple_datasets_datafiles_sampling(datafiles, datasetconfigname):
     """Ensure that multiple datasets with multiple files are formatted and validated correctly."""
     with open(datasetconfigname, "r") as f:
         yaml_content = yaml.safe_load(f)
@@ -651,14 +653,26 @@ def test_process_dataconfig_multiple_datasets_datafiles(datafiles, datasetconfig
     data_args = configs.DataArguments(data_config_path=temp_yaml_file_path)
 
     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
-    (train_set, _, _) = _process_dataconfig_file(data_args, tokenizer)
+    TRAIN_ARGS = configs.TrainingArguments(
+        packing=False,
+        max_seq_length=1024,
+        output_dir="tmp",
+    )
+    (train_set, eval_set, _, _, _, _) = process_dataargs(
+        data_args=data_args, tokenizer=tokenizer, train_args=TRAIN_ARGS
+    )
+
     assert isinstance(train_set, Dataset)
-    column_names = set(["input_ids", "attention_mask", "labels"])
-    assert set(train_set.column_names) == column_names
+    if eval_set:
+        assert isinstance(eval_set, Dataset)
+
+    assert set(["input_ids", "attention_mask", "labels"]).issubset(set(train_set.column_names))
+    if eval_set:
+        assert set(["input_ids", "attention_mask", "labels"]).issubset(set(eval_set.column_names))
 
 
 @pytest.mark.parametrize(
-    "data_config_path, list_data_path",
+    "data_config_path, data_path_list",
     [
         (
             DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
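
The substantive change in this hunk: the test drops the private _process_dataconfig_file helper in favor of the public process_dataargs entry point, which additionally requires TrainingArguments and returns an eval split alongside the train split. The equality check on column names is also relaxed to issubset, since the processed dataset may carry more columns than the three tokenized fields. A minimal sketch of the new call shape as the diff shows it; the import paths for configs and process_dataargs, and the MODEL_NAME constant, are assumed to match the test module's existing ones:

from transformers import AutoTokenizer

# configs, process_dataargs, MODEL_NAME: existing names in the test module;
# their exact import paths are not visible in this diff.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
train_args = configs.TrainingArguments(
    packing=False,
    max_seq_length=1024,
    output_dir="tmp",  # required by TrainingArguments, unused during preprocessing
)
# Six-tuple return per the diff; only the two dataset splits are inspected here.
train_set, eval_set, _, _, _, _ = process_dataargs(
    data_args=data_args, tokenizer=tokenizer, train_args=train_args
)
assert {"input_ids", "attention_mask", "labels"}.issubset(train_set.column_names)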
@@ -682,12 +696,12 @@ def test_process_dataconfig_multiple_datasets_datafiles(datafiles, datasetconfig
     ],
 )
 def test_process_dataconfig_multiple_files_varied_data_formats(
-    data_config_path, list_data_path
+    data_config_path, data_path_list
 ):
     """Ensure that datasets with multiple files with different formats raise assertion error when passed in config file."""
     with open(data_config_path, "r") as f:
         yaml_content = yaml.safe_load(f)
-    yaml_content["datasets"][0]["data_paths"] = list_data_path
+    yaml_content["datasets"][0]["data_paths"] = data_path_list
     datasets_name = yaml_content["datasets"][0]["name"]
 
     # Modify input_field_name and output_field_name according to dataset
@@ -719,7 +733,7 @@ def test_process_dataconfig_multiple_files_varied_data_formats(
 
 
 @pytest.mark.parametrize(
-    "data_config_path, list_data_path",
+    "data_config_path, data_path_list",
     [
         (
             DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
@@ -739,12 +753,12 @@ def test_process_dataconfig_multiple_files_varied_data_formats(
     ],
 )
 def test_process_dataconfig_multiple_files_varied_types(
-    data_config_path, list_data_path
+    data_config_path, data_path_list
 ):
     """Ensure that datasets with multiple files with different formats raise assertion error when passed in config file."""
     with open(data_config_path, "r") as f:
         yaml_content = yaml.safe_load(f)
-    yaml_content["datasets"][0]["data_paths"] = list_data_path
+    yaml_content["datasets"][0]["data_paths"] = data_path_list
     datasets_name = yaml_content["datasets"][0]["name"]
 
     # Modify input_field_name and output_field_name according to dataset
@@ -1048,51 +1062,3 @@ def test_process_dataset_configs_with_sampling_error(
     (_, _, _, _, _, _) = process_dataargs(
         data_args=data_args, tokenizer=tokenizer, train_args=TRAIN_ARGS
     )
-
-
-@pytest.mark.parametrize(
-    "datafiles, datasetconfigname",
-    [
-        (
-            [
-                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW,
-                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
-                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
-            ],
-            DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
-        ),
-    ],
-)
-def test_process_dataset_configs_with_sampling(datafiles, datasetconfigname):
-
-    data_args = configs.DataArguments()
-    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
-    TRAIN_ARGS = configs.TrainingArguments(
-        packing=False,
-        max_seq_length=1024,
-        output_dir="tmp",  # Not needed but positional
-    )
-
-    with tempfile.NamedTemporaryFile(
-        "w", delete=False, suffix=".yaml"
-    ) as temp_yaml_file:
-        with open(datasetconfigname, "r") as f:
-            data = yaml.safe_load(f)
-            datasets = data["datasets"]
-            for i in range(len(datasets)):
-                d = datasets[i]
-                d["data_paths"][0] = datafiles[i]
-            yaml.dump(data, temp_yaml_file)
-        data_args.data_config_path = temp_yaml_file.name
-
-    (train_set, eval_set, _, _, _, _) = process_dataargs(
-        data_args=data_args, tokenizer=tokenizer, train_args=TRAIN_ARGS
-    )
-
-    assert isinstance(train_set, Dataset)
-    if eval_set:
-        assert isinstance(eval_set, Dataset)
-
-    assert set(["input_ids", "labels"]).issubset(set(train_set.column_names))
-    if eval_set:
-        assert set(["input_ids", "labels"]).issubset(set(eval_set.column_names))
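
The deleted test_process_dataset_configs_with_sampling is not lost coverage: its sampling scenario moves into the renamed test_process_dataconfig_multiple_datasets_datafiles_sampling above, which performs the same process_dataargs call and subset assertions. The reusable piece of the deleted body is its YAML-patching setup; a hedged sketch of that pattern as a standalone helper (the helper name is hypothetical, the logic mirrors the removed lines):

import tempfile

import yaml


def patch_data_paths(config_path, datafiles):
    """Hypothetical helper mirroring the deleted test's setup: load a dataset
    config, point each dataset at its test file, and dump the result to a temp
    YAML whose path can be assigned to DataArguments.data_config_path."""
    with open(config_path, "r") as f:
        data = yaml.safe_load(f)
    for dataset, datafile in zip(data["datasets"], datafiles):
        dataset["data_paths"][0] = datafile
    with tempfile.NamedTemporaryFile("w", delete=False, suffix=".yaml") as tmp:
        yaml.dump(data, tmp)
    return tmp.name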
