Commit 82495d3

feat: resume functionality
Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
1 parent 2f404c6 commit 82495d3

File tree

tuning/data/data_processors.py
tuning/data/setup_dataprocessor.py
tuning/sft_trainer.py

3 files changed: +99 -56 lines changed

tuning/data/data_processors.py

Lines changed: 94 additions & 47 deletions
@@ -452,35 +452,9 @@ def split_dataset(
         )
         return split_datasets
 
-    def _process_datasets_for_odm(
-        self,
-        processed_datasets: List[
-            Tuple[DataSetConfig, Union[DatasetDict, IterableDatasetDict]]
-        ],
-    ) -> Tuple[
-        Dict[str, Union[Dataset, IterableDataset]],
-        Dict[str, Union[Dataset, IterableDataset]],
-    ]:
-        train_split = "train"
-        eval_split = "test"
-        train_datasets_dict = {}
-        eval_datasets_dict = {}
-        for d, raw in processed_datasets:
-            if train_split in raw:
-                train_datasets_dict[d.name] = raw[train_split]
-            if eval_split in raw:
-                eval_datasets_dict[d.name] = raw[eval_split]
-        return train_datasets_dict, eval_datasets_dict
-
     def _process_dataset_configs(
-        self, dataset_configs: List[DataSetConfig], odm_config=None
-    ) -> Union[
-        Tuple[Union[Dataset, IterableDataset], Union[Dataset, IterableDataset]],
-        Tuple[
-            Dict[str, Union[Dataset, IterableDataset]],
-            Dict[str, Union[Dataset, IterableDataset]],
-        ],
-    ]:
+        self, dataset_configs: List[DataSetConfig]
+    ) -> Tuple[Union[Dataset, IterableDataset], Union[Dataset, IterableDataset]]:
 
         if not dataset_configs:
             raise ValueError(
@@ -530,13 +504,7 @@ def _process_dataset_configs(
 
             # Append the processed datasets to the final dict
             processed_datasets.append((d, raw_datasets))
-        if odm_config:
-            logger.info(
-                "Sampling probabilities are ignored if provided"
-                "and are not used for concatenation. Instead"
-                "online data mixing plugin handles it."
-            )
-            return self._process_datasets_for_odm(processed_datasets)
+
         train_datasets = []
        train_sampling_probabilities = []
         validation_datasets = []
@@ -623,14 +591,8 @@ def _process_dataset_configs(
         return train_dataset, eval_dataset
 
     def process_dataset_configs(
-        self, dataset_configs: List[DataSetConfig], odm_config=None
-    ) -> Union[
-        Tuple[Union[Dataset, IterableDataset], Union[Dataset, IterableDataset]],
-        Tuple[
-            Dict[str, Union[Dataset, IterableDataset]],
-            Dict[str, Union[Dataset, IterableDataset]],
-        ],
-    ]:
+        self, dataset_configs: List[DataSetConfig]
+    ) -> Tuple[Union[Dataset, IterableDataset], Union[Dataset, IterableDataset]]:
         train_dataset = eval_dataset = None
 
         # Use partial state as recommended by HF documentation for process control
@@ -643,23 +605,108 @@ def process_dataset_configs(
         # as we want to reuse HF cache and not redo computation on all nodes
         # For rationale see https://github.com/huggingface/trl/pull/3106
         with state.main_process_first():
-            train_dataset, eval_dataset = self._process_dataset_configs(
-                dataset_configs, odm_config
-            )
+            train_dataset, eval_dataset = self._process_dataset_configs(dataset_configs)
 
         logger.info("Processed train dataset {}".format(train_dataset))
         logger.info("Processed eval dataset {}".format(eval_dataset))
 
         return train_dataset, eval_dataset
 
 
+class ODMDataPreProcessor(DataPreProcessor):
+    def _process_datasets_for_odm(
+        self,
+        processed_datasets: List[
+            Tuple[DataSetConfig, Union[DatasetDict, IterableDatasetDict]]
+        ],
+    ) -> Tuple[
+        Dict[str, Union[Dataset, IterableDataset]],
+        Dict[str, Union[Dataset, IterableDataset]],
+    ]:
+        train_split = "train"
+        eval_split = "test"
+        train_datasets_dict = {}
+        eval_datasets_dict = {}
+        for d, raw in processed_datasets:
+            if train_split in raw:
+                train_datasets_dict[d.name] = raw[train_split]
+            if eval_split in raw:
+                eval_datasets_dict[d.name] = raw[eval_split]
+        return train_datasets_dict, eval_datasets_dict
+
+    def _process_dataset_configs(
+        self, dataset_configs: List[DataSetConfig]
+    ) -> Tuple[
+        Dict[str, Union[Dataset, IterableDataset]],
+        Dict[str, Union[Dataset, IterableDataset]],
+    ]:
+
+        if not dataset_configs:
+            raise ValueError(
+                "No dataset configs provided. Provided Dataset configs is None."
+            )
+
+        train_split = "train" # default
+        eval_split = "test"
+
+        processed_datasets = []
+
+        logger.info("Starting DataPreProcessor...")
+        # Now Iterate over the multiple datasets provided to us to process
+        for d in dataset_configs:
+            logger.info("Loading the dataset - %s", d.name)
+
+            # In future the streaming etc go as kwargs of this function
+            loaded_dataset = self.load_dataset(d, self.processor_config.streaming)
+            logger.info("Loaded raw dataset : %s", str(loaded_dataset))
+
+            if d.split is not None:
+                loaded_dataset = self.split_dataset(d, loaded_dataset)
+
+            # Create a raw dataset which is a Dict container to house all Datasets
+            raw_datasets = (
+                IterableDatasetDict()
+                if isinstance(loaded_dataset, (IterableDataset, IterableDatasetDict))
+                else DatasetDict()
+            )
+
+            splits_to_keep = [train_split, eval_split]
+            if isinstance(loaded_dataset, (Dataset, IterableDataset)):
+                # Assume all is train split
+                raw_datasets[train_split] = loaded_dataset
+            else:
+                for k, v in loaded_dataset.items():
+                    if k in splits_to_keep:
+                        raw_datasets[k] = v
+
+            if d.data_handlers: # Execute the datahandlers
+                for data_handler_config in d.data_handlers:
+                    raw_datasets = self._execute_data_handlers(
+                        raw_datasets=raw_datasets,
+                        data_handler_config=data_handler_config,
+                        datasetName=d.name,
+                    )
+
+            # Append the processed datasets to the final dict
+            processed_datasets.append((d, raw_datasets))
+        logger.info(
+            "Sampling probabilities are ignored if provided"
+            "and are not used for concatenation. Instead"
+            "online data mixing plugin handles it."
+        )
+        return self._process_datasets_for_odm(processed_datasets)
+
+
 def get_datapreprocessor(
     processor_config: DataPreProcessorConfig,
     tokenizer: AutoTokenizer,
     processor: AutoProcessor = None,
     additional_data_handlers: Dict[str, DataHandler] = None,
 ) -> DataPreProcessor:
-    data_processor = DataPreProcessor(
+    data_processor_cls = DataPreProcessor
+    if processor_config.type == "odm":
+        data_processor_cls = ODMDataPreProcessor
+    data_processor = data_processor_cls(
         processor_config=processor_config,
         tokenizer=tokenizer,
         processor=processor,
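
The net effect in this file: ODM handling moves out of the base DataPreProcessor into a dedicated ODMDataPreProcessor subclass, and get_datapreprocessor picks the class from processor_config.type. Below is a minimal sketch of that dispatch pattern and the two return shapes; the stub bodies and the bare config class are placeholders (the real classes also carry tokenizer, processor, and data-handler state), with only the `type` field taken from the diff.

# Sketch of the class dispatch introduced in get_datapreprocessor above.
from dataclasses import dataclass
from typing import Dict, List, Tuple, Union

from datasets import Dataset, IterableDataset


@dataclass
class DataPreProcessorConfig:
    type: str = "default"


class DataPreProcessor:
    def __init__(self, processor_config: DataPreProcessorConfig):
        self.processor_config = processor_config

    def _process_dataset_configs(
        self, dataset_configs: List
    ) -> Tuple[Union[Dataset, IterableDataset], Union[Dataset, IterableDataset]]:
        # Base behavior: concatenate/interleave everything into one
        # train dataset and one eval dataset.
        raise NotImplementedError


class ODMDataPreProcessor(DataPreProcessor):
    def _process_dataset_configs(
        self, dataset_configs: List
    ) -> Tuple[
        Dict[str, Union[Dataset, IterableDataset]],
        Dict[str, Union[Dataset, IterableDataset]],
    ]:
        # ODM behavior: keep one entry per dataset so the online data
        # mixing plugin can sample across them at train time.
        raise NotImplementedError


def get_datapreprocessor(processor_config: DataPreProcessorConfig) -> DataPreProcessor:
    # Same selection logic as the diff: the config type, not a per-call
    # argument, decides which preprocessor is built.
    cls = ODMDataPreProcessor if processor_config.type == "odm" else DataPreProcessor
    return cls(processor_config)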

tuning/data/setup_dataprocessor.py

Lines changed: 2 additions & 6 deletions
@@ -73,7 +73,6 @@ def process_dataconfig_file(
     processor: AutoProcessor = None,
     is_multipack: bool = False,
     is_padding_free: bool = False,
-    odm_config: ODMConfig = None,
 ):
     """
     Args:
@@ -155,7 +154,7 @@ def process_dataconfig_file(
         tokenizer.chat_template = data_processor.processor_config.chat_template
 
     train_dataset, eval_dataset = data_processor.process_dataset_configs(
-        data_config.datasets, odm_config=odm_config
+        data_config.datasets
     )
 
     return (train_dataset, eval_dataset, data_args.dataset_text_field)
@@ -348,7 +347,6 @@ def _process_raw_data_args(
     additional_data_handlers: Dict[str, DataHandler] = None,
     is_padding_free: bool = False,
     processor: AutoProcessor = None,
-    odm_config: ODMConfig = None,
 ):
 
     if data_args.data_config_path is not None:
@@ -448,7 +446,7 @@
     dataset_configs.append(eval_dataset_config)
 
     train_dataset, eval_dataset = data_processor.process_dataset_configs(
-        dataset_configs, odm_config=odm_config
+        dataset_configs
     )
 
     return (train_dataset, eval_dataset, dataset_text_field)
@@ -635,7 +633,6 @@ def process_dataargs(
             processor,
             is_multipack,
             is_padding_free,
-            odm_config=odm_config,
         )
     else:
         train_dataset, eval_dataset, dataset_text_field = _process_raw_data_args(
@@ -646,7 +643,6 @@
             additional_data_handlers,
             is_padding_free,
             processor,
-            odm_config=odm_config,
         )
 
     if train_args.eval_strategy != "no" and eval_dataset is None:
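
With odm_config removed from every signature in this file, callers no longer thread a separate argument through the data pipeline; ODM behavior is selected once by the preprocessor's configured type. A hedged call-site sketch follows: tokenizer and data_config are assumed to be in scope, and the dataprocessor attribute name is an assumption, not confirmed by the diff.

# Assumed call-site shape after this commit: no odm_config threading.
data_processor = get_datapreprocessor(
    processor_config=data_config.dataprocessor,  # e.g. type == "odm"
    tokenizer=tokenizer,
)
# Base DataPreProcessor returns a single (train, eval) dataset pair;
# ODMDataPreProcessor instead returns dicts keyed by dataset name.
train_dataset, eval_dataset = data_processor.process_dataset_configs(
    data_config.datasets
)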

tuning/sft_trainer.py

Lines changed: 3 additions & 3 deletions
@@ -130,9 +130,9 @@ def train(
     )
 
     resume_from_checkpoint = None
-    if training_args.output_dir:
-        os.makedirs(training_args.output_dir, exist_ok=True)
-        logger.info("using the output directory at %s", training_args.output_dir)
+    if train_args.output_dir:
+        os.makedirs(train_args.output_dir, exist_ok=True)
+        logger.info("using the output directory at %s", train_args.output_dir)
 
     # Check if resume flag is not passed (None), or if flag is true and
     # output_dir has checkpoints then get last checkpoint from output_dir
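
The fix itself is a rename: the surrounding code in train() works with train_args, so the training_args references on the resume path pointed at the wrong name. For context, the pattern the trailing comment describes is commonly implemented with the transformers helper get_last_checkpoint; below is a hedged sketch with the flag handling simplified, so the exact fms-hf-tuning logic may differ.

# Sketch of the resume-from-checkpoint pattern; get_last_checkpoint is a
# real transformers utility, the surrounding logic here is illustrative.
import os
from typing import Optional

from transformers import TrainingArguments
from transformers.trainer_utils import get_last_checkpoint


def resolve_resume_checkpoint(train_args: TrainingArguments) -> Optional[str]:
    """Return a checkpoint path to resume from, or None to start fresh."""
    if not train_args.output_dir:
        return None
    os.makedirs(train_args.output_dir, exist_ok=True)
    flag = train_args.resume_from_checkpoint
    # Flag not passed (None) or passed as true: pick up the most recent
    # checkpoint-* directory under output_dir, if any exists.
    if flag is None or str(flag).lower() == "true":
        return get_last_checkpoint(train_args.output_dir)
    # Otherwise treat the flag as an explicit checkpoint path.
    return flag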
