Commit 4168c87

Perform dataset mixing via sampling probabilities in data config (#408)
Code to perform dataset sampling via sampling probabilities in data config

Signed-off-by: Dushyant Behl <[email protected]>
1 parent e6f7a22 commit 4168c87
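
At a glance, the feature is configured in two places: each dataset entry takes a sampling ratio, and the dataprocessor block takes a sampling_stopping_strategy and a sampling_seed. These knobs line up with Hugging Face datasets.interleave_datasets, so the following is a minimal sketch of the technique under the assumption that the preprocessor (whose file is not shown in this diff) uses it; the toy datasets and values are illustrative only.

# Hedged sketch: mixing datasets by sampling probability with HF datasets.
# Assumes interleave_datasets is the underlying mechanism (not confirmed by
# the files shown in this diff).
from datasets import Dataset, interleave_datasets

ds_1 = Dataset.from_dict({"text": [f"a{i}" for i in range(100)]})
ds_2 = Dataset.from_dict({"text": [f"b{i}" for i in range(100)]})
ds_3 = Dataset.from_dict({"text": [f"c{i}" for i in range(100)]})

mixed = interleave_datasets(
    [ds_1, ds_2, ds_3],
    probabilities=[0.3, 0.4, 0.3],        # per-dataset sampling ratios
    seed=42,                              # fixed seed keeps the mix reproducible
    stopping_strategy="first_exhausted",  # this commit's default is "all_exhausted"
)
print(len(mixed), mixed[0])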

File tree

7 files changed: +255, -50 lines changed

tests/artifacts/predefined_data_configs/__init__.py

Lines changed: 6 additions & 3 deletions
@@ -19,12 +19,15 @@

 ### Constants used for data
 PREDEFINED_DATA_CONFIGS = os.path.join(os.path.dirname(__file__))
-APPLY_CUSTOM_TEMPLATE_YAML = os.path.join(
+DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML = os.path.join(
     PREDEFINED_DATA_CONFIGS, "apply_custom_template.yaml"
 )
-PRETOKENIZE_JSON_DATA_YAML = os.path.join(
+DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML = os.path.join(
     PREDEFINED_DATA_CONFIGS, "pretokenized_json_data.yaml"
 )
-TOKENIZE_AND_APPLY_INPUT_MASKING_YAML = os.path.join(
+DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML = os.path.join(
     PREDEFINED_DATA_CONFIGS, "tokenize_and_apply_input_masking.yaml"
 )
+DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML = os.path.join(
+    PREDEFINED_DATA_CONFIGS, "multiple_datasets_with_sampling.yaml"
+)

tests/artifacts/predefined_data_configs/apply_custom_template.yaml

Lines changed: 1 addition & 1 deletion
@@ -11,4 +11,4 @@ datasets:
         batched: false
         fn_kwargs:
           dataset_text_field: "dataset_text_field"
-          dataset_template: "dataset_template"
+          template: "dataset_template"
tests/artifacts/predefined_data_configs/multiple_datasets_with_sampling.yaml

Lines changed: 41 additions & 0 deletions

@@ -0,0 +1,41 @@
+dataprocessor:
+  type: default
+  sampling_stopping_strategy: first_exhausted
+  seed: 66
+datasets:
+  - name: dataset_1
+    sampling: 0.3
+    data_paths:
+      - "FILE_PATH"
+    data_handlers:
+      - name: tokenize_and_apply_input_masking
+        arguments:
+          remove_columns: all
+          batched: false
+          fn_kwargs:
+            input_field_name: input
+            output_field_name: output
+  - name: dataset_2
+    sampling: 0.4
+    data_paths:
+      - "FILE_PATH"
+    data_handlers:
+      - name: tokenize_and_apply_input_masking
+        arguments:
+          remove_columns: all
+          batched: false
+          fn_kwargs:
+            input_field_name: input
+            output_field_name: output
+  - name: dataset_3
+    sampling: 0.3
+    data_paths:
+      - "FILE_PATH"
+    data_handlers:
+      - name: tokenize_and_apply_input_masking
+        arguments:
+          remove_columns: all
+          batched: false
+          fn_kwargs:
+            input_field_name: input
+            output_field_name: output
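
Note that the three sampling ratios in this config (0.3, 0.4 and 0.3) sum to 1.0; the tests further down rewrite these values to exercise invalid combinations as well.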

tests/artifacts/predefined_data_configs/tokenize_and_apply_input_masking.yaml

Lines changed: 2 additions & 2 deletions
@@ -10,5 +10,5 @@ datasets:
          remove_columns: all
          batched: false
          fn_kwargs:
-            input_field: "INPUT"
-            output_field: "OUTPUT"
+            input_field_name: "INPUT"
+            output_field_name: "OUTPUT"

tests/data/test_data_preprocessing_utils.py

Lines changed: 116 additions & 13 deletions
@@ -26,9 +26,10 @@

 # First Party
 from tests.artifacts.predefined_data_configs import (
-    APPLY_CUSTOM_TEMPLATE_YAML,
-    PRETOKENIZE_JSON_DATA_YAML,
-    TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
+    DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
+    DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
+    DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML,
+    DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
 )
 from tests.artifacts.testdata import (
     MODEL_NAME,
@@ -428,22 +429,22 @@ def test_process_data_args_throws_error_where_needed(data_args, packing):
 @pytest.mark.parametrize(
     "data_config_path, data_path",
     [
-        (APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSON),
-        (APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSONL),
-        (APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_PARQUET),
-        (PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSON),
-        (PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSONL),
-        (PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_PARQUET),
-        (
-            TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
+        (DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSON),
+        (DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSONL),
+        (DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_PARQUET),
+        (DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSON),
+        (DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSONL),
+        (DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_PARQUET),
+        (
+            DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
             TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
         ),
         (
-            TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
+            DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
             TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
         ),
         (
-            TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
+            DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
             TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
         ),
     ],
@@ -709,3 +710,105 @@ def test_process_dataset_configs(datafile, column_names, datasetconfigname):
     with open(datafile, "r") as file:
         data = json.load(file)
     assert len(train_dataset) == len(data)
+
+
+@pytest.mark.parametrize(
+    "datafiles, sampling, datasetconfigname",
+    [
+        (
+            [
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW,
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
+            ],
+            [0.3, None, 0.3],
+            DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
+        ),
+        (
+            [
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW,
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
+            ],
+            [0.3, 0.5, 0.3],
+            DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
+        ),
+    ],
+)
+def test_process_dataset_configs_with_sampling_error(
+    datafiles, sampling, datasetconfigname
+):
+
+    data_args = configs.DataArguments()
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    TRAIN_ARGS = configs.TrainingArguments(
+        packing=False,
+        max_seq_length=1024,
+        output_dir="tmp",  # Not needed but positional
+    )
+
+    with tempfile.NamedTemporaryFile(
+        "w", delete=False, suffix=".yaml"
+    ) as temp_yaml_file:
+        with open(datasetconfigname, "r") as f:
+            data = yaml.safe_load(f)
+            datasets = data["datasets"]
+            for i in range(len(datasets)):
+                d = datasets[i]
+                d["data_paths"][0] = datafiles[i]
+                d["sampling"] = sampling[i]
+            yaml.dump(data, temp_yaml_file)
+        data_args.data_config_path = temp_yaml_file.name
+
+    with pytest.raises(ValueError):
+        (_, _, _, _, _, _) = process_dataargs(
+            data_args=data_args, tokenizer=tokenizer, train_args=TRAIN_ARGS
+        )
+
+
+@pytest.mark.parametrize(
+    "datafiles, datasetconfigname",
+    [
+        (
+            [
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW,
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
+            ],
+            DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
+        ),
+    ],
+)
+def test_process_dataset_configs_with_sampling(datafiles, datasetconfigname):
+
+    data_args = configs.DataArguments()
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    TRAIN_ARGS = configs.TrainingArguments(
+        packing=False,
+        max_seq_length=1024,
+        output_dir="tmp",  # Not needed but positional
+    )
+
+    with tempfile.NamedTemporaryFile(
+        "w", delete=False, suffix=".yaml"
+    ) as temp_yaml_file:
+        with open(datasetconfigname, "r") as f:
+            data = yaml.safe_load(f)
+            datasets = data["datasets"]
+            for i in range(len(datasets)):
+                d = datasets[i]
+                d["data_paths"][0] = datafiles[i]
+            yaml.dump(data, temp_yaml_file)
+        data_args.data_config_path = temp_yaml_file.name
+
+    (train_set, eval_set, _, _, _, _) = process_dataargs(
+        data_args=data_args, tokenizer=tokenizer, train_args=TRAIN_ARGS
+    )
+
+    assert isinstance(train_set, Dataset)
+    if eval_set:
+        assert isinstance(eval_set, Dataset)
+
+    assert set(["input_ids", "labels"]).issubset(set(train_set.column_names))
+    if eval_set:
+        assert set(["input_ids", "labels"]).issubset(set(eval_set.column_names))
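
Between them, the two new tests cover both sides of the feature: the error case patches the ratios to [0.3, None, 0.3] (one dataset left without a ratio) and [0.3, 0.5, 0.3] (summing to 1.1), expecting process_dataargs to raise a ValueError, while the happy path keeps the config's valid ratios and checks that the mixed train/eval splits come back as tokenized Datasets containing input_ids and labels columns.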

tuning/data/data_config.py

Lines changed: 27 additions & 12 deletions
@@ -32,13 +32,16 @@ class DataHandlerConfig:
 class DataSetConfig:
     name: str
     data_paths: List[str]
-    sampling: Optional[Dict] = None
+    sampling: Optional[float] = None
     data_handlers: Optional[List[DataHandlerConfig]] = None


 @dataclass
 class DataPreProcessorConfig:
     type: Optional[str] = "default"
+    sampling_stopping_strategy: Optional[str] = "all_exhausted"
+    # Default seed is not None, to ensure reproducibility
+    sampling_seed: Optional[float] = 42


 @dataclass
@@ -84,17 +87,12 @@ def _validate_dataset_config(dataset_config) -> DataSetConfig:
            )
            p = _p
        c.data_paths.append(p)
-    if "sampling" in kwargs:
-        sampling_kwargs = kwargs["sampling"]
-        assert isinstance(
-            dict, sampling_kwargs
-        ), "sampling arguments should be of the type dict"
-        if "ratio" in sampling_kwargs:
-            ratio = sampling_kwargs["ratio"]
-            assert isinstance(ratio, float) and (
-                0 <= ratio <= 1.0
-            ), f"sampling ratio: {ratio} should be float and in range [0.0,1.0]"
-        c.sampling = sampling_kwargs
+    if "sampling" in kwargs and kwargs["sampling"] is not None:
+        ratio = kwargs["sampling"]
+        assert isinstance(ratio, float) and (
+            0 <= ratio <= 1.0
+        ), f"sampling ratio: {ratio} should be float and in range [0.0,1.0]"
+        c.sampling = ratio
     if "data_handlers" in kwargs:
        c.data_handlers = []
        for handler in kwargs["data_handlers"]:
@@ -106,6 +104,23 @@ def _validate_dataprocessor_config(dataprocessor_config) -> DataPreProcessorConfig:
    kwargs = dataprocessor_config
    c = DataPreProcessorConfig()
    assert isinstance(kwargs, dict), "dataprocessor in data_config needs to be a dict"
+    if "type" in kwargs:
+        assert isinstance(kwargs["type"], str), "dataprocessor type must be a string"
+        c.type = kwargs["type"]
+    if "sampling_stopping_strategy" in kwargs:
+        strategy = kwargs["sampling_stopping_strategy"]
+        assert isinstance(
+            strategy, str
+        ), "dataset sampling stopping strategy must be a string"
+        assert strategy in [
+            "first_exhausted",
+            "all_exhausted",
+        ], "allowed sampling stopping strategies are all_exhausted (default) or first_exhausted"
+        c.sampling_stopping_strategy = strategy
+    if "sampling_seed" in kwargs:
+        seed = kwargs["sampling_seed"]
+        assert isinstance(seed, int), "sampling seed should be int"
+        c.sampling_seed = seed
    return c
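
To see the new processor-level validation in action, here is a hedged sketch; it assumes tuning/data/data_config.py imports as tuning.data.data_config (module path inferred from the file path above) and pokes a private helper purely for illustration:

# Illustrative only: exercising the validator added in this commit.
from tuning.data.data_config import _validate_dataprocessor_config

proc = _validate_dataprocessor_config(
    {
        "type": "default",
        "sampling_stopping_strategy": "first_exhausted",
        "sampling_seed": 66,
    }
)
assert proc.type == "default"
assert proc.sampling_stopping_strategy == "first_exhausted"
assert proc.sampling_seed == 66

# Anything outside {"first_exhausted", "all_exhausted"} trips the membership
# assert added above, e.g.:
#   _validate_dataprocessor_config({"sampling_stopping_strategy": "round_robin"})
#   -> AssertionError

Note also the shape change on the dataset side: a dataset entry now carries the ratio directly (sampling: 0.3) where the old schema nested it as sampling: {ratio: 0.3}.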