Commit a94c197

add initial implementation of dataloader v2

1 parent 7ba3434 commit a94c197

File tree

10 files changed: +605 -22 lines changed
Lines changed: 14 additions & 0 deletions

```yaml
dataloader:
  type: default
datasets:
  - name: apply_custom_data_template
    data_paths:
      - "FILE_PATH"
    data_handlers:
      - name: apply_custom_data_formatting_template
        arguments:
          remove_columns: all
          batched: false
          fn_kwargs:
            dataset_text_field: "dataset_text_field"
            template: "dataset_template"
```
Lines changed: 6 additions & 0 deletions

```yaml
dataloader:
  type: default
datasets:
  - name: pretokenized_dataset
    data_paths:
      - "FILE_PATH"
```
Lines changed: 14 additions & 0 deletions

```yaml
dataloader:
  type: default
datasets:
  - name: text_dataset_input_output_masking
    data_paths:
      - "FILE_PATH"
    data_handlers:
      - name: tokenize_and_apply_instruction_masking
        arguments:
          remove_columns: all
          batched: false
          fn_kwargs:
            input_field_name: "INPUT"
            output_field_name: "OUTPUT"
```

tuning/data/data_config.py

Lines changed: 120 additions & 0 deletions
```python
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
from dataclasses import asdict, dataclass, field
from typing import Dict, List, Optional

from tuning.utils.utils import load_yaml_or_json


@dataclass
class DataHandlerConfig:
    name: str
    arguments: Optional[Dict] = None


@dataclass
class DataSetConfig:
    name: Optional[str] = None
    data_paths: List[str] = field(default_factory=list)
    sampling: Optional[Dict] = None
    splitter_arguments: Optional[Dict] = None
    data_handlers: Optional[List[DataHandlerConfig]] = None


@dataclass
class DataLoaderConfig:
    type: Optional[str] = "default"
    streaming: Optional[bool] = None


@dataclass
class DataConfig:
    dataloader: DataLoaderConfig = field(default_factory=DataLoaderConfig)
    datasets: List[DataSetConfig] = field(default_factory=list)


def _validate_data_handler_config(data_handler) -> DataHandlerConfig:
    kwargs = data_handler
    assert isinstance(kwargs, dict), "data_handlers in data_config needs to be a dict"
    assert "name" in kwargs and isinstance(
        kwargs["name"], str
    ), "data_handlers need to have a name with type str"
    assert "arguments" in kwargs, "data handlers need to have arguments"
    assert isinstance(
        kwargs["arguments"], dict
    ), "data handler arguments should be of the type dict"
    return DataHandlerConfig(**kwargs)


def _validate_dataset_config(dataset_config) -> DataSetConfig:
    kwargs = dataset_config
    assert isinstance(kwargs, dict), "dataset_config in data_config needs to be a dict"
    c = DataSetConfig()
    if "name" in kwargs:
        assert isinstance(kwargs["name"], str), "dataset name should be string"
        c.name = kwargs["name"]
    if "data_paths" not in kwargs:
        raise ValueError("data_paths should be specified for each dataset")
    data_paths = kwargs["data_paths"]
    # TODO: Support data_paths that point to a directory or directories.
    assert isinstance(data_paths, list), "data_paths should be an array of files"
    c.data_paths = []
    for p in data_paths:
        assert isinstance(p, str), f"path {p} should be of the type string"
        assert os.path.exists(p), f"data_paths {p} does not exist"
        if not os.path.isabs(p):
            _p = os.path.abspath(p)
            logging.warning("Provided path %s is not absolute, changing it to %s", p, _p)
            p = _p
        c.data_paths.append(p)
    if "sampling" in kwargs:
        sampling_kwargs = kwargs["sampling"]
        assert isinstance(sampling_kwargs, dict), "sampling arguments should be of the type dict"
        if "ratio" in sampling_kwargs:
            ratio = sampling_kwargs["ratio"]
            assert isinstance(ratio, float) and (
                0.0 <= ratio <= 1.0
            ), f"sampling ratio: {ratio} should be float and in range [0.0,1.0]"
        c.sampling = sampling_kwargs
    if "splitter_arguments" in kwargs:
        splitter_kwargs = kwargs["splitter_arguments"]
        assert isinstance(splitter_kwargs, dict), "splitter_arguments should be of the type dict"
        c.splitter_arguments = splitter_kwargs
    if "data_handlers" in kwargs:
        c.data_handlers = []
        for handler in kwargs["data_handlers"]:
            c.data_handlers.append(_validate_data_handler_config(handler))
    return c


def _validate_dataloader_config(dataloader_config) -> DataLoaderConfig:
    kwargs = dataloader_config
    assert isinstance(kwargs, dict), "dataloader in data_config needs to be a dict"
    c = DataLoaderConfig()
    if "streaming" in kwargs:
        assert isinstance(kwargs["streaming"], bool), "streaming should be a boolean true or false"
        c.streaming = kwargs["streaming"]
    return c


def validate_data_config(dataconfig: DataConfig):
    # Re-validate an already constructed DataConfig by converting the dataclasses
    # back into plain dicts (dropping unset fields), since the helpers above
    # operate on dicts.
    _validate_dataloader_config(
        {k: v for k, v in asdict(dataconfig.dataloader).items() if v is not None}
    )
    for d in dataconfig.datasets:
        _validate_dataset_config({k: v for k, v in asdict(d).items() if v is not None})


def load_and_validate_data_config(data_config_file: str) -> DataConfig:
    raw_data = load_yaml_or_json(data_config_file)
    assert isinstance(
        raw_data, dict
    ), f"The provided data_config file is invalid: {data_config_file}"
    assert "datasets" in raw_data, "datasets should be provided in data config"
    assert isinstance(raw_data["datasets"], list), "datasets should be provided as a list"
    data_config = DataConfig()
    data_config.datasets = []
    for d in raw_data["datasets"]:
        data_config.datasets.append(_validate_dataset_config(d))
    if "dataloader" in raw_data:
        data_config.dataloader = _validate_dataloader_config(raw_data["dataloader"])
    return data_config
```
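For configs built in code rather than read from a file, the dataclasses can be constructed directly and re-checked with validate_data_config. A hedged sketch; the path below is hypothetical and must exist on disk, since validation checks os.path.exists:

```python
# Sketch (not part of the commit): programmatic construction and validation.
from tuning.data.data_config import (
    DataConfig,
    DataLoaderConfig,
    DataSetConfig,
    validate_data_config,
)

cfg = DataConfig(
    dataloader=DataLoaderConfig(type="default", streaming=False),
    datasets=[
        DataSetConfig(
            name="pretokenized_dataset",
            data_paths=["train.jsonl"],  # hypothetical; must exist on disk
        )
    ],
)
validate_data_config(cfg)  # raises AssertionError/ValueError on bad fields
```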

tuning/data/data_handlers.py

Lines changed: 71 additions & 0 deletions
```python
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Definition of some predefined data preprocessing functions that we need.

from typing import Dict

from transformers import AutoTokenizer

from tuning.utils.data_utils import custom_data_formatter
from tuning.utils.preprocessing_utils import combine_sequence


def tokenize_and_apply_instruction_masking(
    element: Dict[str, str],
    tokenizer: AutoTokenizer,
    input_field_name: str,
    output_field_name: str,
    **tokenizer_kwargs,
):
    input_text = element[input_field_name]
    output_text = element[output_field_name]

    # TODO: Eventually move the code here
    combined = combine_sequence(input_text, output_text, eos_token=tokenizer.eos_token)

    tokenized_comb_seqs = tokenizer(combined, **tokenizer_kwargs)
    tokenized_input = tokenizer(input_text, **tokenizer_kwargs)

    # Mask the prompt tokens with -100 so loss is computed only on the
    # output portion of the combined sequence.
    masked_labels = [-100] * len(
        tokenized_input.input_ids
    ) + tokenized_comb_seqs.input_ids[len(tokenized_input.input_ids) :]

    # Any benefit of retaining the old columns?
    return {
        "input_ids": tokenized_comb_seqs.input_ids,
        "labels": masked_labels,
        "attention_mask": tokenized_comb_seqs.attention_mask,
    }


def apply_dataset_formatting(
    element: Dict[str, str],
    tokenizer: AutoTokenizer,
    dataset_text_field: str,
    **kwargs,
):
    # Append the EOS token to the raw text column.
    return {dataset_text_field: element[dataset_text_field] + tokenizer.eos_token}


def apply_custom_data_formatting_template(
    element: Dict[str, str],
    tokenizer: AutoTokenizer,
    dataset_text_field: str,
    template: str,
    **kwargs,
):
    template += tokenizer.eos_token

    # TODO: Eventually move the code here.
    return custom_data_formatter(
        element=element,
        formatted_dataset_field=dataset_text_field,
        template=template,
    )


AVAILABLE_DATA_HANDLERS = {
    "tokenize_and_apply_instruction_masking": tokenize_and_apply_instruction_masking,
    "apply_dataset_formatting": apply_dataset_formatting,
    "apply_custom_data_formatting_template": apply_custom_data_formatting_template,
}
```
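These handlers are shaped to plug into Hugging Face datasets.Dataset.map, with remove_columns, batched, and fn_kwargs coming from the data config. A minimal sketch, assuming the datasets library is installed and using a tiny in-memory dataset with hypothetical "INPUT"/"OUTPUT" columns:

```python
# Hedged sketch (not part of the commit): apply
# tokenize_and_apply_instruction_masking via Dataset.map, mirroring the
# fn_kwargs/batched/remove_columns fields from the YAML configs above.
from datasets import Dataset
from transformers import AutoTokenizer

from tuning.data.data_handlers import AVAILABLE_DATA_HANDLERS

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any tokenizer with an EOS token
raw = Dataset.from_list([{"INPUT": "What is 2+2?", "OUTPUT": "4"}])

handler = AVAILABLE_DATA_HANDLERS["tokenize_and_apply_instruction_masking"]
processed = raw.map(
    handler,
    batched=False,
    remove_columns=raw.column_names,  # the "all" setting from the config
    fn_kwargs={
        "tokenizer": tokenizer,
        "input_field_name": "INPUT",
        "output_field_name": "OUTPUT",
    },
)
print(processed[0].keys())  # input_ids, labels, attention_mask
```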
