
Commit 689ee41

feat: Expose additional data handlers as an argument in train (#409)
* Expose additional data handlers as an argument to the train function.

Signed-off-by: Dushyant Behl <[email protected]>
1 parent 4168c87 commit 689ee41
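
For context, a minimal usage sketch of the new keyword (not part of this commit; the handler name is illustrative and the argument objects are assumed to be built as usual):

# Sketch: pass custom data handlers straight into train(). The dict keys are
# the names a data config can reference; the values are the callables to run.
from tuning import sft_trainer

def run_with_custom_handler(model_args, data_args, train_args, peft_config, my_handler):
    trainer, metadata = sft_trainer.train(
        model_args,      # model / data / training / PEFT arguments prepared as before
        data_args,
        train_args,
        peft_config,
        additional_data_handlers={"my_custom_handler": my_handler},
    )
    return trainer, metadata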

File tree: 4 files changed, +177 -31 lines

tests/test_sft_trainer.py

Lines changed: 107 additions & 0 deletions
@@ -15,7 +15,10 @@
 """Unit Tests for SFT Trainer.
 """
 
+# pylint: disable=too-many-lines
+
 # Standard
+from dataclasses import asdict
 import copy
 import json
 import os
@@ -46,6 +49,13 @@
 from tuning import sft_trainer
 from tuning.config import configs, peft_config
 from tuning.config.tracker_configs import FileLoggingTrackerConfig
+from tuning.data.data_config import (
+    DataConfig,
+    DataHandlerConfig,
+    DataPreProcessorConfig,
+    DataSetConfig,
+)
+from tuning.data.data_handlers import apply_dataset_formatting
 
 MODEL_ARGS = configs.ModelArguments(
     model_name_or_path=MODEL_NAME, use_flash_attn=False, torch_dtype="float32"
@@ -1124,3 +1134,100 @@ def test_pretokenized_dataset_wrong_format():
     # is essentially swallowing a KeyError here.
     with pytest.raises(ValueError):
         sft_trainer.train(MODEL_ARGS, data_args, train_args, PEFT_PT_ARGS)
+
+
+###########################################################################
+### Tests for checking different cases for the argument additional_handlers
+### The argument `additional_handlers` in train::sft_trainer.py is used to pass
+### extra data handlers, which should be a Dict[str, callable]
+
+
+@pytest.mark.parametrize(
+    "additional_handlers",
+    [
+        "thisisnotokay",
+        [],
+        {lambda x: {"x": x}: "notokayeither"},
+        {"thisisfine": "thisisnot"},
+    ],
+)
+def test_run_with_bad_additional_data_handlers(additional_handlers):
+    """Ensure that a bad additional_handlers argument (anything that is not a
+    Dict[str, callable]) throws an error."""
+    with tempfile.TemporaryDirectory() as tempdir:
+        train_args = copy.deepcopy(TRAIN_ARGS)
+        train_args.output_dir = tempdir
+
+        with pytest.raises(
+            ValueError, match="Handlers should be of type Dict, str to callable"
+        ):
+            sft_trainer.train(
+                MODEL_ARGS,
+                DATA_ARGS,
+                train_args,
+                PEFT_PT_ARGS,
+                additional_data_handlers=additional_handlers,
+            )
+
+
+def test_run_with_additional_data_handlers_as_none():
+    """Ensure that passing additional_handlers as None works."""
+    with tempfile.TemporaryDirectory() as tempdir:
+        train_args = copy.deepcopy(TRAIN_ARGS)
+        train_args.output_dir = tempdir
+
+        sft_trainer.train(
+            MODEL_ARGS,
+            DATA_ARGS,
+            train_args,
+            PEFT_PT_ARGS,
+            additional_data_handlers=None,
+        )
+        _validate_training(tempdir)
+
+
+def test_run_by_passing_additional_data_handlers():
+    """Ensure that a good additional_handlers argument can take a
+    data handler and can successfully run an e2e training."""
+    # This is my test handler
+    TEST_HANDLER = "my_test_handler"
+
+    def test_handler(element, tokenizer, **kwargs):
+        return apply_dataset_formatting(element, tokenizer, "custom_formatted_field")
+
+    # This data config asks for the data handler to be applied to the dataset
+    preprocessor_config = DataPreProcessorConfig()
+    handler_config = DataHandlerConfig(name="my_test_handler", arguments=None)
+    dataset_config = DataSetConfig(
+        name="test_dataset",
+        data_paths=TWITTER_COMPLAINTS_DATA_JSON,
+        data_handlers=[handler_config],
+    )
+    data_config = DataConfig(
+        dataprocessor=preprocessor_config, datasets=[dataset_config]
+    )
+
+    # Dump the data config to a file; this also tests that a JSON data config works
+    with tempfile.NamedTemporaryFile(
+        "w", delete=False, suffix=".json"
+    ) as temp_data_file:
+        data_config_raw = json.dumps(asdict(data_config))
+        temp_data_file.write(data_config_raw)
+        data_config_path = temp_data_file.name
+
+    # Now launch the SFT trainer after registering the data handler
+    with tempfile.TemporaryDirectory() as tempdir:
+        train_args = copy.deepcopy(TRAIN_ARGS)
+        train_args.output_dir = tempdir
+        data_args = copy.deepcopy(DATA_ARGS)
+        data_args.data_config_path = data_config_path
+        data_args.dataset_text_field = "custom_formatted_field"
+
+        sft_trainer.train(
+            MODEL_ARGS,
+            data_args,  # use the copy that points at the JSON data config
+            train_args,
+            PEFT_PT_ARGS,
+            additional_data_handlers={TEST_HANDLER: test_handler},
+        )
+        _validate_training(tempdir)
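
The handler contract these tests rely on is a plain callable taking a dataset element and the tokenizer; `test_handler` above simply wraps the library helper `apply_dataset_formatting`. A standalone sketch with the same `(element, tokenizer, **kwargs)` shape; the "Tweet text"/"output" field names are assumptions about the raw dataset, not something this commit defines:

# Standalone sketch of a custom data handler with the same signature as
# test_handler above. Field names are illustrative only.
def complaint_to_text(element, tokenizer, **kwargs):
    element["custom_formatted_field"] = (
        f"Input: {element['Tweet text']}\nLabel: {element['output']}"
        + tokenizer.eos_token
    )
    return element

# Registered under the name the data config references, e.g.
# additional_data_handlers={"my_test_handler": complaint_to_text}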

tuning/data/data_processors.py

Lines changed: 26 additions & 12 deletions
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 # Standard
-from typing import Dict, List, Union
+from typing import Callable, Dict, List, Union
 import logging
 import os
 
@@ -35,7 +35,7 @@ class DataPreProcessor:
     tokenizer = None
     data_config: DataConfig = None
    processor_config: DataPreProcessorConfig = None
-    registered_handlers: Dict[str, callable] = None
+    registered_handlers: Dict[str, Callable] = None
 
     def __init__(
         self, processor_config: DataPreProcessorConfig, tokenizer: AutoTokenizer
@@ -46,8 +46,27 @@ def __init__(
         # Initialize other objects
         self.registered_handlers = {}
 
-    def register_data_handler(self, name: str, func: callable):
+        # Auto register available data handlers
+        for k, v in AVAILABLE_DATA_HANDLERS.items():
+            self.registered_handlers[k] = v
+
+    def register_data_handler(self, name: str, func: Callable):
+        if not isinstance(name, str) or not callable(func):
+            raise ValueError("Handlers should be of type Dict, str to callable")
+        if name in self.registered_handlers:
+            logging.warning(
+                "Handler name '%s' already exists and will be overwritten", name
+            )
         self.registered_handlers[name] = func
+        logging.info("Registered new handler %s", name)
+
+    def register_data_handlers(self, handlers: Dict[str, Callable]):
+        if handlers is None:
+            return
+        if not isinstance(handlers, Dict):
+            raise ValueError("Handlers should be of type Dict, str to callable")
+        for k, v in handlers.items():
+            self.register_data_handler(name=k, func=v)
 
     def load_dataset(
         self,
@@ -238,19 +257,14 @@ def process_dataset_configs(
         return train_dataset
 
 
-def autoregister_available_handlers(processor: DataPreProcessor):
-    if processor is None:
-        return
-    for name, func in AVAILABLE_DATA_HANDLERS.items():
-        processor.register_data_handler(name=name, func=func)
-
-
 def get_datapreprocessor(
-    processor_config: DataPreProcessorConfig, tokenizer: AutoTokenizer
+    processor_config: DataPreProcessorConfig,
+    tokenizer: AutoTokenizer,
+    additional_data_handlers: Dict[str, Callable] = None,
 ) -> DataPreProcessor:
     processor = DataPreProcessor(
         processor_config=processor_config,
         tokenizer=tokenizer,
     )
-    autoregister_available_handlers(processor)
+    processor.register_data_handlers(additional_data_handlers)
     return processor
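
With this change, the built-in AVAILABLE_DATA_HANDLERS are registered in the constructor and any extra handlers are merged in afterwards. A small sketch of using the factory above (the handler name is illustrative; the tokenizer and handler callable are assumed to exist already):

# Sketch: build a preprocessor and register extra handlers on top of the
# auto-registered AVAILABLE_DATA_HANDLERS.
from tuning.data.data_config import DataPreProcessorConfig
from tuning.data.data_processors import get_datapreprocessor

def build_processor(tokenizer, my_handler):
    return get_datapreprocessor(
        processor_config=DataPreProcessorConfig(),
        tokenizer=tokenizer,
        additional_data_handlers={"my_custom_handler": my_handler},
    )

# Passing anything that is not a Dict[str, Callable] raises
# ValueError("Handlers should be of type Dict, str to callable");
# re-using an existing name logs a warning and overwrites the old handler.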

tuning/data/setup_dataprocessor.py

Lines changed: 39 additions & 16 deletions
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 # Standard
-from typing import Union
+from typing import Callable, Dict, Union
 import logging
 
 # Third Party
@@ -55,10 +55,16 @@ def is_pretokenized_dataset(data: Union[str, Dataset, IterableDataset]):
 
 # TODO: For now assume only training dataset is passed via data config file.
 # This is very limited but is done to keep first implementation minimal
-def _process_dataconfig_file(data_args: DataArguments, tokenizer: AutoTokenizer):
+def _process_dataconfig_file(
+    data_args: DataArguments,
+    tokenizer: AutoTokenizer,
+    additional_data_handlers: Dict[str, Callable] = None,
+):
     data_config = load_and_validate_data_config(data_args.data_config_path)
     processor = get_datapreprocessor(
-        processor_config=data_config.dataprocessor, tokenizer=tokenizer
+        processor_config=data_config.dataprocessor,
+        tokenizer=tokenizer,
+        additional_data_handlers=additional_data_handlers,
     )
     train_dataset = processor.process_dataset_configs(data_config.datasets)
 
@@ -179,14 +185,16 @@ def _process_raw_data_args(
     tokenizer: AutoTokenizer,
     packing: bool,
     max_seq_length: int,
+    additional_data_handlers: Dict[str, Callable] = None,
 ):
 
     # Create a data processor with default processor config
     default_processor_config = DataPreProcessorConfig()
     data_processor = get_datapreprocessor(
-        processor_config=default_processor_config, tokenizer=tokenizer
+        processor_config=default_processor_config,
+        tokenizer=tokenizer,
+        additional_data_handlers=additional_data_handlers,
     )
-
     assert isinstance(
         data_args.training_data_path, str
     ), "Training data path has to be set and str"
@@ -259,7 +267,10 @@ def _process_raw_data_args(
 # If no data config file is specified, process the remaining data arguments
 # to determine the use case based on their presence, as explained in _process_raw_data_args.
 def process_dataargs(
-    data_args: DataArguments, tokenizer: AutoTokenizer, train_args: TrainingArguments
+    data_args: DataArguments,
+    tokenizer: AutoTokenizer,
+    train_args: TrainingArguments,
+    additional_data_handlers: Dict[str, Callable] = None,
 ):
     """
     Args:
@@ -268,11 +279,17 @@
         train_args: TrainingArguments
             Training arguments passed to the library
            Used for packing and max_seq_length
+        additional_data_handlers: A Dict of [str, callable] data handlers
+            which need to be registered with the data preprocessor
     Returns:
         Tuple(Dataset, Dataset, str, DataCollator, int, Dict)
-            tuple containing train_dataset, eval_dataset, dataset_text_field,
-            data_collator, max_seq_length and dataset_kwargs
-
+            tuple containing
+            train_dataset (Dataset/IterableDataset),
+            eval_dataset (Dataset/IterableDataset),
+            dataset_text_field (str),
+            data_collator (DataCollator),
+            max_seq_length (int) and
+            dataset_kwargs (Dict)
     """
 
     max_seq_length = min(train_args.max_seq_length, tokenizer.model_max_length)
@@ -290,26 +307,32 @@
 
     if data_args.data_config_path:
         train_dataset, eval_dataset, dataset_text_field = _process_dataconfig_file(
-            data_args, tokenizer
+            data_args, tokenizer, additional_data_handlers
         )
     else:
         train_dataset, eval_dataset, dataset_text_field = _process_raw_data_args(
-            data_args, tokenizer, train_args.packing, max_seq_length
+            data_args,
+            tokenizer,
+            train_args.packing,
+            max_seq_length,
+            additional_data_handlers,
         )
 
+    # Note: This check should not be removed.
+    # It's important to recompute this post handling to
+    # check if we already tokenized the dataset or not.
+    is_tokenized_dataset = is_pretokenized_dataset(train_dataset or eval_dataset)
+
     data_collator = get_data_collator(
         train_args.packing,
         data_args.response_template,
         tokenizer,
-        # Note: This check should not be removed.
-        # Its important to recompute this post handling to
-        # check if we already tokenized the dataset or not.
-        is_pretokenized_dataset(train_dataset),
+        is_tokenized_dataset,
         max_seq_length,
     )
 
     dataset_kwargs = {}
-    if is_pretokenized_dataset(train_dataset or eval_dataset):
+    if is_tokenized_dataset:
         dataset_kwargs["skip_prepare_dataset"] = True
 
     return (
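
Since the updated docstring enumerates the returned tuple, a direct call to this function (outside train()) might look like the sketch below; the argument objects and handler callable are assumed to be prepared by the caller, and the handler name is illustrative:

# Sketch: calling process_dataargs directly with extra handlers and unpacking
# the six values the docstring above lists.
from tuning.data.setup_dataprocessor import process_dataargs

def preprocess(data_args, tokenizer, train_args, my_handler):
    (
        train_dataset,
        eval_dataset,
        dataset_text_field,
        data_collator,
        max_seq_length,
        dataset_kwargs,
    ) = process_dataargs(
        data_args,
        tokenizer,
        train_args,
        additional_data_handlers={"my_custom_handler": my_handler},
    )
    return train_dataset, eval_dataset, data_collator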

tuning/sft_trainer.py

Lines changed: 5 additions & 3 deletions
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 # Standard
-from typing import Dict, List, Optional, Union
+from typing import Callable, Dict, List, Optional, Union
 import dataclasses
 import json
 import logging
@@ -85,6 +85,7 @@ def train(
     attention_and_distributed_packing_config: Optional[
         AttentionAndDistributedPackingConfig
     ] = None,
+    additional_data_handlers: Optional[Dict[str, Callable]] = None,
 ) -> tuple[SFTTrainer, dict]:
     """Call the SFTTrainer
 
@@ -113,7 +114,8 @@
             Should be used in combination with quantized_lora_config. Also currently
             fused_lora and fast_kernels must used together (may change in future). \
         attention_and_distributed_packing_config: Used for padding-free attention and multipack.
-
+        additional_data_handlers: Dict[str, Callable] of any extra data handlers \
+            to be registered with the data preprocessor
     Returns:
         Tuple: Instance of SFTTrainer , some metadata in a dict
         Metadata contains information on number of added tokens while tuning.
@@ -297,7 +299,7 @@
         data_collator,
         max_seq_length,
         dataset_kwargs,
-    ) = process_dataargs(data_args, tokenizer, train_args)
+    ) = process_dataargs(data_args, tokenizer, train_args, additional_data_handlers)
     additional_metrics["data_preprocessing_time"] = (
         time.time() - data_preprocessing_time
     )
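
Taken together, the new keyword travels one hop per file touched in this commit. A comment-only trace of that path, simplified and with call sites paraphrased from the diffs above:

# sft_trainer.train(..., additional_data_handlers={"name": handler})
#   -> process_dataargs(data_args, tokenizer, train_args, additional_data_handlers)
#        -> _process_dataconfig_file(data_args, tokenizer, additional_data_handlers)  # when data_config_path is set
#        -> _process_raw_data_args(..., additional_data_handlers)                      # otherwise
#             (both branches call)
#        -> get_datapreprocessor(processor_config, tokenizer, additional_data_handlers)
#             -> DataPreProcessor.register_data_handlers(additional_data_handlers)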
