
Commit 5e948f1

add initial implementation of dataloader v2
Signed-off-by: Dushyant Behl <[email protected]>
1 parent 7ba3434 commit 5e948f1

File tree

10 files changed: +705 -22 lines changed
Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
dataloader:
  type: default
datasets:
  - name: apply_custom_data_template
    data_paths:
      - "FILE_PATH"
    data_handlers:
      - name: apply_custom_data_formatting_template
        arguments:
          remove_columns: all
          batched: false
          fn_kwargs:
            dataset_text_field: "dataset_text_field"
            template: "dataset_template"
Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
dataloader:
  type: default
datasets:
  - name: pretokenized_dataset
    data_paths:
      - "FILE_PATH"
Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
dataloader:
  type: default
datasets:
  - name: text_dataset_input_output_masking
    data_paths:
      - "FILE_PATH"
    data_handlers:
      - name: tokenize_and_apply_instruction_masking
        arguments:
          remove_columns: all
          batched: false
          fn_kwargs:
            input_field_name: "INPUT"
            output_field_name: "OUTPUT"

tuning/data/data_config.py

Lines changed: 149 additions & 0 deletions
@@ -0,0 +1,149 @@
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Standard
from dataclasses import dataclass
from typing import Dict, List, Optional
import logging
import os

# Local
from tuning.utils.utils import load_yaml_or_json


@dataclass
class DataHandlerConfig:
    name: str
    arguments: Optional[Dict]


@dataclass
class DataSetConfig:
    # Fields default to None so the validators can populate them incrementally.
    name: Optional[str] = None
    data_paths: Optional[List[str]] = None
    sampling: Optional[Dict] = None
    splitter_arguments: Optional[Dict] = None
    data_handlers: Optional[List[DataHandlerConfig]] = None


@dataclass
class DataLoaderConfig:
    type: Optional[str] = "default"
    streaming: Optional[bool] = None


@dataclass
class DataConfig:
    dataloader: Optional[DataLoaderConfig] = None
    datasets: Optional[List[DataSetConfig]] = None


def _validate_data_handler_config(data_handler) -> DataHandlerConfig:
    kwargs = data_handler
    assert isinstance(kwargs, dict), "data_handlers in data_config needs to be a dict"
    assert "name" in kwargs and isinstance(
        kwargs["name"], str
    ), "data_handlers need to have a name with type str"
    assert "arguments" in kwargs, "data handlers need to have arguments"
    assert isinstance(
        kwargs["arguments"], dict
    ), "data handler arguments should be of the type dict"
    return DataHandlerConfig(**kwargs)


def _validate_dataset_config(dataset_config) -> DataSetConfig:
    c = DataSetConfig()
    kwargs = dataset_config
    assert isinstance(kwargs, dict), "dataset_config in data_config needs to be a dict"
    if "name" in kwargs:
        assert isinstance(kwargs["name"], str), "dataset name should be string"
        c.name = kwargs["name"]
    if "data_paths" not in kwargs:
        raise ValueError("data_paths should be specified for each dataset")
    data_paths = kwargs["data_paths"]
    # TODO: Support data_paths being a directory or directories
    assert isinstance(data_paths, list), "data_paths should be an array of files"
    c.data_paths = []
    for p in data_paths:
        assert isinstance(p, str), f"path {p} should be of the type string"
        assert os.path.exists(p), f"data_paths {p} does not exist"
        if not os.path.isabs(p):
            _p = os.path.abspath(p)
            logging.warning(
                "Provided path %s is not absolute, changing it to %s", p, _p
            )
            p = _p
        c.data_paths.append(p)
    if "sampling" in kwargs:
        sampling_kwargs = kwargs["sampling"]
        assert isinstance(
            sampling_kwargs, dict
        ), "sampling arguments should be of the type dict"
        if "ratio" in sampling_kwargs:
            ratio = sampling_kwargs["ratio"]
            assert isinstance(ratio, float) and (
                0 <= ratio <= 1.0
            ), f"sampling ratio: {ratio} should be float and in range [0.0,1.0]"
        c.sampling = sampling_kwargs
    if "splitter_arguments" in kwargs:
        splitter_kwargs = kwargs["splitter_arguments"]
        assert isinstance(
            splitter_kwargs, dict
        ), "splitter_arguments should be of the type dict"
        c.splitter_arguments = splitter_kwargs
    if "data_handlers" in kwargs:
        c.data_handlers = []
        for handler in kwargs["data_handlers"]:
            c.data_handlers.append(_validate_data_handler_config(handler))
    return c


def _validate_dataloader_config(dataloader_config) -> DataLoaderConfig:
    kwargs = dataloader_config
    c = DataLoaderConfig()
    assert isinstance(kwargs, dict), "dataloader in data_config needs to be a dict"
    if "streaming" in kwargs:
        assert isinstance(
            kwargs["streaming"], bool
        ), "streaming should be a boolean true or false"
        c.streaming = kwargs["streaming"]
    return c


def validate_data_config(dataconfig: DataConfig):
    _validate_dataloader_config(dataconfig.dataloader)
    for d in dataconfig.datasets:
        _validate_dataset_config(d)


def load_and_validate_data_config(data_config_file: str) -> DataConfig:
    raw_data = load_yaml_or_json(data_config_file)
    assert isinstance(
        raw_data, dict
    ), f"The provided data_config file is invalid: {data_config_file}"
    assert "datasets" in raw_data, "datasets should be provided in data config"
    assert isinstance(
        raw_data["datasets"], list
    ), "datasets should be provided as a list"
    data_config = DataConfig()
    data_config.datasets = []
    for d in raw_data["datasets"]:
        data_config.datasets.append(_validate_dataset_config(d))
    if "dataloader" in raw_data:
        data_config.dataloader = _validate_dataloader_config(raw_data["dataloader"])
    return data_config
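
A self-contained sketch of the new entry point: write a one-record dataset and a tiny config to temporary files, then load and validate. The paths, names, and toy record are illustrative, and PyYAML is assumed available since `load_yaml_or_json` reads YAML configs:

# Standard
import json
import tempfile

# Third Party
import yaml

# Local
from tuning.data.data_config import load_and_validate_data_config

with tempfile.NamedTemporaryFile("w", suffix=".jsonl", delete=False) as data_f:
    data_f.write(json.dumps({"INPUT": "hi", "OUTPUT": "hello"}) + "\n")

raw = {
    "dataloader": {"type": "default"},
    "datasets": [{"name": "toy", "data_paths": [data_f.name]}],
}
with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as cfg_f:
    yaml.safe_dump(raw, cfg_f)

data_config = load_and_validate_data_config(cfg_f.name)
print(data_config.datasets[0].data_paths)  # absolute path to the temp data file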

tuning/data/data_handlers.py

Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Definition of some predefined data preprocessing functions that we need.

# Standard
from typing import Dict

# Third Party
from transformers import AutoTokenizer

# Local
from tuning.utils.data_utils import custom_data_formatter
from tuning.utils.preprocessing_utils import combine_sequence


def tokenize_and_apply_instruction_masking(
    element: Dict[str, str],
    tokenizer: AutoTokenizer,
    input_field_name: str,
    output_field_name: str,
    **tokenizer_kwargs,
):
    input_text = element[input_field_name]
    output_text = element[output_field_name]

    # TODO: Eventually move the code here
    combined = combine_sequence(input_text, output_text, eos_token=tokenizer.eos_token)

    tokenized_comb_seqs = tokenizer(combined, **tokenizer_kwargs)
    tokenized_input = tokenizer(input_text, **tokenizer_kwargs)

    # Mask the prompt tokens with -100 so that loss is computed only on the
    # output portion of the combined sequence.
    masked_labels = [-100] * len(
        tokenized_input.input_ids
    ) + tokenized_comb_seqs.input_ids[len(tokenized_input.input_ids) :]

    # Any benefit of retaining the old columns?
    return {
        "input_ids": tokenized_comb_seqs.input_ids,
        "labels": masked_labels,
        "attention_mask": tokenized_comb_seqs.attention_mask,
    }


def apply_dataset_formatting(
    element: Dict[str, str], tokenizer: AutoTokenizer, dataset_text_field: str, **kwargs
):
    # Append the EOS token to the text column so generation learns to stop.
    return {
        f"{dataset_text_field}": element[f"{dataset_text_field}"] + tokenizer.eos_token
    }


def apply_custom_data_formatting_template(
    element: Dict[str, str],
    tokenizer: AutoTokenizer,
    dataset_text_field: str,
    template: str,
    **kwargs,
):
    template += tokenizer.eos_token

    # TODO: Eventually move the code here.
    return custom_data_formatter(
        element=element, formatted_dataset_field=dataset_text_field, template=template
    )


AVAILABLE_DATA_HANDLERS = {
    "tokenize_and_apply_instruction_masking": tokenize_and_apply_instruction_masking,
    "apply_dataset_formatting": apply_dataset_formatting,
    "apply_custom_data_formatting_template": apply_custom_data_formatting_template,
}
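
To make the masking behavior concrete, here is a minimal sketch of invoking the instruction-masking handler directly; the model name and field names are illustrative, and `combine_sequence` comes from the existing preprocessing utils:

# Third Party
from transformers import AutoTokenizer

# Local
from tuning.data.data_handlers import tokenize_and_apply_instruction_masking

tokenizer = AutoTokenizer.from_pretrained("gpt2")
element = {"INPUT": "Translate to French: cat", "OUTPUT": " chat"}
processed = tokenize_and_apply_instruction_masking(
    element,
    tokenizer=tokenizer,
    input_field_name="INPUT",
    output_field_name="OUTPUT",
)

# Labels over the prompt span are -100, so loss covers only the output tokens.
n_prompt = len(tokenizer(element["INPUT"]).input_ids)
assert processed["labels"][:n_prompt] == [-100] * n_prompt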
