Commit cad3a2d

Expose additional data handlers as an argument to the train function.
Signed-off-by: Dushyant Behl <dushyantbehl@in.ibm.com>
1 parent 4168c87 commit cad3a2d

File tree

  tuning/data/data_processors.py
  tuning/data/setup_dataprocessor.py
  tuning/sft_trainer.py

3 files changed: +55 -28 lines changed

tuning/data/data_processors.py

Lines changed: 15 additions & 11 deletions

@@ -13,7 +13,7 @@
 # limitations under the License.

 # Standard
-from typing import Dict, List, Union
+from typing import Callable, Dict, List, Union
 import logging
 import os

@@ -35,7 +35,7 @@ class DataPreProcessor:
     tokenizer = None
     data_config: DataConfig = None
     processor_config: DataPreProcessorConfig = None
-    registered_handlers: Dict[str, callable] = None
+    registered_handlers: Dict[str, Callable] = None

     def __init__(
         self, processor_config: DataPreProcessorConfig, tokenizer: AutoTokenizer
@@ -46,9 +46,20 @@ def __init__(
         # Initialize other objects
         self.registered_handlers = {}

-    def register_data_handler(self, name: str, func: callable):
+    def register_data_handler(self, name: str, func: Callable):
+        assert isinstance(name, str), "Handler name should be of str type"
+        assert callable(func), "Handler should be a callable routine"
         self.registered_handlers[name] = func

+    def register_data_handlers(self, handlers: Dict[str, Callable]):
+        if handlers is None:
+            return
+        assert isinstance(
+            handlers, Dict
+        ), "Handlers should be of type Dict[str:Callable]"
+        for k, v in handlers.items():
+            self.register_data_handler(name=k, func=v)
+
     def load_dataset(
         self,
         datasetconfig: DataSetConfig,
@@ -238,19 +249,12 @@ def process_dataset_configs(
         return train_dataset


-def autoregister_available_handlers(processor: DataPreProcessor):
-    if processor is None:
-        return
-    for name, func in AVAILABLE_DATA_HANDLERS.items():
-        processor.register_data_handler(name=name, func=func)
-
-
 def get_datapreprocessor(
     processor_config: DataPreProcessorConfig, tokenizer: AutoTokenizer
 ) -> DataPreProcessor:
     processor = DataPreProcessor(
         processor_config=processor_config,
         tokenizer=tokenizer,
     )
-    autoregister_available_handlers(processor)
+    processor.register_data_handlers(AVAILABLE_DATA_HANDLERS)
     return processor
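
For context, here is a minimal sketch of how the new register_data_handlers entry point could be exercised directly on a DataPreProcessor. The handler name and body are illustrative assumptions rather than part of this commit, and the DataPreProcessorConfig import path is assumed:

# Sketch only: exercises DataPreProcessor.register_data_handlers as added above.
# The handler signature and the config import path are assumptions.
from transformers import AutoTokenizer

from tuning.data.data_config import DataPreProcessorConfig  # assumed module path
from tuning.data.data_processors import get_datapreprocessor


def uppercase_output(element, **kwargs):
    # Hypothetical handler: whatever transformation the caller wants applied
    # when a dataset config refers to it by the name "uppercase_output".
    return element


tokenizer = AutoTokenizer.from_pretrained("gpt2")

# get_datapreprocessor already registers AVAILABLE_DATA_HANDLERS (per this diff);
# register_data_handlers layers caller-supplied handlers on top of those.
processor = get_datapreprocessor(
    processor_config=DataPreProcessorConfig(), tokenizer=tokenizer
)
processor.register_data_handlers({"uppercase_output": uppercase_output})

Registering a name twice simply overwrites the earlier entry, since handlers are kept in a plain dict keyed by name.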

tuning/data/setup_dataprocessor.py

Lines changed: 35 additions & 14 deletions

@@ -13,7 +13,7 @@
 # limitations under the License.

 # Standard
-from typing import Union
+from typing import Callable, Dict, Union
 import logging

 # Third Party
@@ -55,11 +55,16 @@ def is_pretokenized_dataset(data: Union[str, Dataset, IterableDataset]):

 # TODO: For now assume only training dataset is passed via data config file.
 # This is very limited but is done to keep first implementation minimal
-def _process_dataconfig_file(data_args: DataArguments, tokenizer: AutoTokenizer):
+def _process_dataconfig_file(
+    data_args: DataArguments,
+    tokenizer: AutoTokenizer,
+    additional_data_handlers: Dict[str, Callable] = None,
+):
     data_config = load_and_validate_data_config(data_args.data_config_path)
     processor = get_datapreprocessor(
         processor_config=data_config.dataprocessor, tokenizer=tokenizer
     )
+    processor.register_data_handlers(additional_data_handlers)
     train_dataset = processor.process_dataset_configs(data_config.datasets)

     return (train_dataset, None, data_args.dataset_text_field)
@@ -179,14 +184,15 @@ def _process_raw_data_args(
     tokenizer: AutoTokenizer,
     packing: bool,
     max_seq_length: int,
+    additional_data_handlers: Dict[str, Callable] = None,
 ):

     # Create a data processor with default processor config
     default_processor_config = DataPreProcessorConfig()
     data_processor = get_datapreprocessor(
         processor_config=default_processor_config, tokenizer=tokenizer
     )
-
+    data_processor.register_data_handlers(additional_data_handlers)
     assert isinstance(
         data_args.training_data_path, str
     ), "Training data path has to be set and str"
@@ -259,7 +265,10 @@ def _process_raw_data_args(
 # If no data config file is specified, process the remaining data arguments
 # to determine the use case based on their presence, as explained in _process_raw_data_args.
 def process_dataargs(
-    data_args: DataArguments, tokenizer: AutoTokenizer, train_args: TrainingArguments
+    data_args: DataArguments,
+    tokenizer: AutoTokenizer,
+    train_args: TrainingArguments,
+    additional_data_handlers: Dict[str, Callable] = None,
 ):
     """
     Args:
@@ -268,11 +277,17 @@ def process_dataargs(
         train_args: TrainingArguments
             Training arguments passed to the library
             Used for packing and max_seq_length
+        additional_data_handlers: A Dict of [str, callable] data handlers
+            which need to be registered with the data preprocessor
     Returns:
         Tuple(Dataset, Dataset, str, DataCollator, int, Dict)
-            tuple containing train_dataset, eval_dataset, dataset_text_field,
-            data_collator, max_seq_length and dataset_kwargs
-
+            tuple containing
+            train_dataset (Dataset/IterableDataset),
+            eval_dataset (Dataset/IterableDataset),
+            dataset_text_field (str),
+            data_collator (DataCollator)
+            max_seq_length(int) and
+            dataset_kwargs (Dict)
     """

     max_seq_length = min(train_args.max_seq_length, tokenizer.model_max_length)
@@ -290,26 +305,32 @@ def process_dataargs(

     if data_args.data_config_path:
         train_dataset, eval_dataset, dataset_text_field = _process_dataconfig_file(
-            data_args, tokenizer
+            data_args, tokenizer, additional_data_handlers
         )
     else:
         train_dataset, eval_dataset, dataset_text_field = _process_raw_data_args(
-            data_args, tokenizer, train_args.packing, max_seq_length
+            data_args,
+            tokenizer,
+            train_args.packing,
+            max_seq_length,
+            additional_data_handlers,
         )

+    # Note: This check should not be removed.
+    # Its important to recompute this post handling to
+    # check if we already tokenized the dataset or not.
+    is_tokenized_dataset = is_pretokenized_dataset(train_dataset or eval_dataset)
+
     data_collator = get_data_collator(
         train_args.packing,
         data_args.response_template,
         tokenizer,
-        # Note: This check should not be removed.
-        # Its important to recompute this post handling to
-        # check if we already tokenized the dataset or not.
-        is_pretokenized_dataset(train_dataset),
+        is_tokenized_dataset,
         max_seq_length,
     )

     dataset_kwargs = {}
-    if is_pretokenized_dataset(train_dataset or eval_dataset):
+    if is_tokenized_dataset:
         dataset_kwargs["skip_prepare_dataset"] = True

     return (
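
As a usage sketch, the new keyword threads through process_dataargs as shown below. Import paths, constructor fields beyond those this diff reads (data_config_path, response_template, packing, max_seq_length), and all file paths are assumptions or placeholders:

# Sketch only: passing extra handlers through the updated process_dataargs.
# Import paths and placeholder values are assumptions, not taken from this commit.
from transformers import AutoTokenizer

from tuning.config.configs import DataArguments, TrainingArguments  # assumed path
from tuning.data.setup_dataprocessor import process_dataargs


def add_eos(element, **kwargs):
    # Hypothetical handler that a data config file can reference by name.
    return element


tokenizer = AutoTokenizer.from_pretrained("gpt2")
data_args = DataArguments(data_config_path="data_config.yaml")  # placeholder path
train_args = TrainingArguments(output_dir="out", max_seq_length=1024, packing=False)

(
    train_dataset,
    eval_dataset,
    dataset_text_field,
    data_collator,
    max_seq_length,
    dataset_kwargs,
) = process_dataargs(
    data_args,
    tokenizer,
    train_args,
    additional_data_handlers={"add_eos": add_eos},
)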

tuning/sft_trainer.py

Lines changed: 5 additions & 3 deletions

@@ -13,7 +13,7 @@
 # limitations under the License.

 # Standard
-from typing import Dict, List, Optional, Union
+from typing import Callable, Dict, List, Optional, Union
 import dataclasses
 import json
 import logging
@@ -85,6 +85,7 @@ def train(
     attention_and_distributed_packing_config: Optional[
         AttentionAndDistributedPackingConfig
     ] = None,
+    additional_data_handlers: Optional[Dict[str, Callable]] = None,
 ) -> tuple[SFTTrainer, dict]:
     """Call the SFTTrainer

@@ -113,7 +114,8 @@ def train(
            Should be used in combination with quantized_lora_config. Also currently
            fused_lora and fast_kernels must used together (may change in future). \
        attention_and_distributed_packing_config: Used for padding-free attention and multipack.
-
+        additional_data_handlers: Dict [str:Callable] of any extra data handlers \
+            to be registered with the data preprocessor
    Returns:
        Tuple: Instance of SFTTrainer , some metadata in a dict
        Metadata contains information on number of added tokens while tuning.
@@ -297,7 +299,7 @@ def train(
         data_collator,
         max_seq_length,
         dataset_kwargs,
-    ) = process_dataargs(data_args, tokenizer, train_args)
+    ) = process_dataargs(data_args, tokenizer, train_args, additional_data_handlers)
     additional_metrics["data_preprocessing_time"] = (
         time.time() - data_preprocessing_time
     )
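
Finally, a hedged sketch of what a caller of sft_trainer.train gains from this change. Argument names other than additional_data_handlers, the config import path, and all values are assumptions or placeholders rather than something this diff establishes:

# Sketch only: registering extra data handlers when invoking train().
# Config classes, field names, and values here are assumed placeholders.
from tuning import sft_trainer
from tuning.config import configs  # assumed module path


def mask_pii(element, **kwargs):
    # Hypothetical user-supplied handler; it becomes available under the
    # name "mask_pii" alongside the built-in AVAILABLE_DATA_HANDLERS.
    return element


trainer, metadata = sft_trainer.train(
    model_args=configs.ModelArguments(model_name_or_path="facebook/opt-125m"),
    data_args=configs.DataArguments(data_config_path="my_data_config.yaml"),
    train_args=configs.TrainingArguments(output_dir="out", num_train_epochs=1),
    additional_data_handlers={"mask_pii": mask_pii},
)

Because process_dataargs now forwards this dict to the data preprocessor before any dataset configs are processed, the extra handlers can be referenced exactly like the built-in ones.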
