 # limitations under the License.

 # Standard
-from typing import Union
+from typing import Callable, Dict, Union
 import logging

 # Third Party
@@ -55,11 +55,16 @@ def is_pretokenized_dataset(data: Union[str, Dataset, IterableDataset]):

 # TODO: For now assume only training dataset is passed via data config file.
 # This is very limited but is done to keep first implementation minimal
-def _process_dataconfig_file(data_args: DataArguments, tokenizer: AutoTokenizer):
+def _process_dataconfig_file(
+    data_args: DataArguments,
+    tokenizer: AutoTokenizer,
+    additional_data_handlers: Dict[str, Callable] = None,
+):
     data_config = load_and_validate_data_config(data_args.data_config_path)
     processor = get_datapreprocessor(
         processor_config=data_config.dataprocessor, tokenizer=tokenizer
     )
+    processor.register_data_handlers(additional_data_handlers)
     train_dataset = processor.process_dataset_configs(data_config.datasets)

     return (train_dataset, None, data_args.dataset_text_field)
@@ -179,14 +184,15 @@ def _process_raw_data_args(
     tokenizer: AutoTokenizer,
     packing: bool,
     max_seq_length: int,
+    additional_data_handlers: Dict[str, Callable] = None,
 ):

     # Create a data processor with default processor config
     default_processor_config = DataPreProcessorConfig()
     data_processor = get_datapreprocessor(
         processor_config=default_processor_config, tokenizer=tokenizer
     )
-
+    data_processor.register_data_handlers(additional_data_handlers)
     assert isinstance(
         data_args.training_data_path, str
     ), "Training data path has to be set and str"
@@ -259,7 +265,10 @@ def _process_raw_data_args(
 # If no data config file is specified, process the remaining data arguments
 # to determine the use case based on their presence, as explained in _process_raw_data_args.
 def process_dataargs(
-    data_args: DataArguments, tokenizer: AutoTokenizer, train_args: TrainingArguments
+    data_args: DataArguments,
+    tokenizer: AutoTokenizer,
+    train_args: TrainingArguments,
+    additional_data_handlers: Dict[str, Callable] = None,
 ):
     """
     Args:
@@ -268,11 +277,17 @@ def process_dataargs(
         train_args: TrainingArguments
             Training arguments passed to the library
             Used for packing and max_seq_length
+        additional_data_handlers: A Dict of [str, callable] data handlers
+            which need to be registered with the data preprocessor
     Returns:
         Tuple(Dataset, Dataset, str, DataCollator, int, Dict)
-            tuple containing train_dataset, eval_dataset, dataset_text_field,
-            data_collator, max_seq_length and dataset_kwargs
-
+            tuple containing
+            train_dataset (Dataset/IterableDataset),
+            eval_dataset (Dataset/IterableDataset),
+            dataset_text_field (str),
+            data_collator (DataCollator),
+            max_seq_length (int) and
+            dataset_kwargs (Dict)
     """

     max_seq_length = min(train_args.max_seq_length, tokenizer.model_max_length)
@@ -290,26 +305,32 @@ def process_dataargs(

     if data_args.data_config_path:
         train_dataset, eval_dataset, dataset_text_field = _process_dataconfig_file(
-            data_args, tokenizer
+            data_args, tokenizer, additional_data_handlers
         )
     else:
         train_dataset, eval_dataset, dataset_text_field = _process_raw_data_args(
-            data_args, tokenizer, train_args.packing, max_seq_length
+            data_args,
+            tokenizer,
+            train_args.packing,
+            max_seq_length,
+            additional_data_handlers,
         )

+    # Note: This check should not be removed.
+    # Its important to recompute this post handling to
+    # check if we already tokenized the dataset or not.
+    is_tokenized_dataset = is_pretokenized_dataset(train_dataset or eval_dataset)
+
     data_collator = get_data_collator(
         train_args.packing,
         data_args.response_template,
         tokenizer,
-        # Note: This check should not be removed.
-        # Its important to recompute this post handling to
-        # check if we already tokenized the dataset or not.
-        is_pretokenized_dataset(train_dataset),
+        is_tokenized_dataset,
         max_seq_length,
     )

     dataset_kwargs = {}
-    if is_pretokenized_dataset(train_dataset or eval_dataset):
+    if is_tokenized_dataset:
         dataset_kwargs["skip_prepare_dataset"] = True

     return (
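For context, a minimal usage sketch of the new additional_data_handlers keyword, not taken from this diff. The import paths, the placeholder model and config names, and the element-wise handler signature are assumptions; only process_dataargs, its six-element return tuple, and the additional_data_handlers argument come from the change above.

# Standard
from typing import Dict

# Third Party
from transformers import AutoTokenizer

# Local (module paths are assumptions, not confirmed by this diff)
from tuning.config.configs import DataArguments, TrainingArguments
from tuning.data.setup_dataprocessor import process_dataargs


def add_suffix(element: Dict, **kwargs):
    # Hypothetical element-wise handler; appends a marker to an assumed "output" column.
    element["output"] = element["output"] + " <END>"
    return element


tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granite-3.0-2b-base")  # placeholder model
data_args = DataArguments(data_config_path="my_data_config.yaml")  # placeholder config file
train_args = TrainingArguments(output_dir="out", max_seq_length=1024)

# The new keyword threads the dict down to the data preprocessor, which registers
# each callable under its string name before the dataset configs are processed.
(
    train_dataset,
    eval_dataset,
    dataset_text_field,
    data_collator,
    max_seq_length,
    dataset_kwargs,
) = process_dataargs(
    data_args,
    tokenizer,
    train_args,
    additional_data_handlers={"add_suffix": add_suffix},
)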