Commit 452c13b

feat: Adopt resumption feature of online data mixing (#617)
* feat: resume functionality
* fix: refactor code
* fix: refactor sampling weight

Signed-off-by: Mehant Kammakomati <[email protected]>
1 parent 4ec1340 commit 452c13b

File tree

7 files changed: +156 -116 lines

tests/artifacts/predefined_data_configs/multiple_datasets_with_odm.yaml

Lines changed: 3 additions & 3 deletions
@@ -13,7 +13,7 @@ datasets:
     split:
       train: 0.8
       validation: 0.2 # validation set is also used in ODM reward computation when reward_type is validation_loss.
-    sampling: 0.3 # ignored
+    sampling: 0.3 # used as starting weights for online data mixing
     data_paths:
       - "FILE_PATH"
     data_handlers:
@@ -28,7 +28,7 @@ datasets:
     split:
       train: 0.6
       validation: 0.2 # validation set is also used in ODM reward computation when reward_type is validation_loss.
-    sampling: 0.4 # ignored
+    sampling: 0.4 # used as starting weights for online data mixing
     data_paths:
       - "FILE_PATH"
     data_handlers:
@@ -43,7 +43,7 @@ datasets:
     split:
       train: 0.4
       validation: 0.1 # validation set is also used in ODM reward computation when reward_type is validation_loss.
-    sampling: 0.3 # ignored
+    sampling: 0.3 # used as starting weights for online data mixing
     data_paths:
       - "FILE_PATH"
     data_handlers:
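With the ODM plugin active, these per-dataset sampling values are no longer ignored; they seed the mixer's initial sampling distribution. A standalone Python sketch of the intended semantics (dataset names are hypothetical placeholders, not from this config):

import math

# Starting weights as declared per dataset in the YAML above.
starting_weights = {"dataset_1": 0.3, "dataset_2": 0.4, "dataset_3": 0.3}

# Mirrors the constraint enforced by _validate_sampling_ratios in
# tuning/data/data_processors.py: weights must cover every train split and
# sum to 1 (the repo uses an exact != 1 check; isclose is safer here).
assert math.isclose(sum(starting_weights.values()), 1.0)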

tests/data/test_data_preprocessing.py

Lines changed: 11 additions & 7 deletions
@@ -768,7 +768,7 @@ def test_process_dataconfig_file_with_streaming(data_config_path, data_path):
         output_dir="tmp", # Not needed but positional
     )

-    (train_set, _, _) = process_dataconfig_file(data_args, TRAIN_ARGS, tokenizer)
+    (train_set, _, _, _) = process_dataconfig_file(data_args, TRAIN_ARGS, tokenizer)
     assert isinstance(train_set, IterableDataset)
     if datasets_name == "text_dataset_input_output_masking":
         column_names = set(["input_ids", "attention_mask", "labels"])
@@ -1017,7 +1017,7 @@ def test_process_dataconfig_file(data_config_path, data_path):
         output_dir="tmp", # Not needed but positional
     )

-    (train_set, _, _) = process_dataconfig_file(data_args, TRAIN_ARGS, tokenizer)
+    (train_set, _, _, _) = process_dataconfig_file(data_args, TRAIN_ARGS, tokenizer)
     assert isinstance(train_set, Dataset)
     if datasets_name == "text_dataset_input_output_masking":
         column_names = set(["input_ids", "attention_mask", "labels"])
@@ -1107,7 +1107,7 @@ def test_process_datahandler_eos_token(data_config_path, data_path, add_eos_token
         output_dir="tmp", # Not needed but positional
     )

-    (train_set, _, _) = process_dataconfig_file(data_args, TRAIN_ARGS, tokenizer)
+    (train_set, _, _, _) = process_dataconfig_file(data_args, TRAIN_ARGS, tokenizer)
     assert isinstance(train_set, Dataset)
     if datasets_name == "text_dataset_input_output_masking":
         column_names = set(["input_ids", "attention_mask", "labels"])
@@ -1258,7 +1258,7 @@ def test_process_dataconfig_multiple_files(data_config_path, data_path_list):
         output_dir="tmp", # Not needed but positional
     )

-    (train_set, _, _) = process_dataconfig_file(data_args, TRAIN_ARGS, tokenizer)
+    (train_set, _, _, _) = process_dataconfig_file(data_args, TRAIN_ARGS, tokenizer)
     assert isinstance(train_set, Dataset)
     if datasets_name == "text_dataset_input_output_masking":
         column_names = set(["input_ids", "attention_mask", "labels"])
@@ -1330,7 +1330,7 @@ def test_process_dataconfig_multiple_files_folders_with_globbing(
         output_dir="tmp", # Not needed but positional
     )

-    (train_set, _, _) = process_dataconfig_file(data_args, TRAIN_ARGS, tokenizer)
+    (train_set, _, _, _) = process_dataconfig_file(data_args, TRAIN_ARGS, tokenizer)
     assert isinstance(train_set, Dataset)
     assert set(["input_ids", "attention_mask", "labels"]).issubset(
         set(train_set.column_names)
@@ -1831,7 +1831,9 @@ def test_process_dataset_configs(datafile, column_names, datasetconfigname):
         tokenizer=tokenizer,
     )
     datasetconfig = [DataSetConfig(name=datasetconfigname, data_paths=[datafile])]
-    train_dataset, _ = processor.process_dataset_configs(dataset_configs=datasetconfig)
+    train_dataset, _, _ = processor.process_dataset_configs(
+        dataset_configs=datasetconfig
+    )

     assert isinstance(train_dataset, Dataset)
     assert set(train_dataset.column_names) == column_names
@@ -1953,7 +1955,9 @@ def test_rename_and_select_dataset_columns(
             name=datasetconfigname, data_paths=data_paths, data_handlers=handlers
         )
     ]
-    train_dataset, _ = processor.process_dataset_configs(dataset_configs=datasetconfig)
+    train_dataset, _, _ = processor.process_dataset_configs(
+        dataset_configs=datasetconfig
+    )

     assert isinstance(train_dataset, Dataset)
     assert set(train_dataset.column_names) == set(final)
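Both data-processing entry points now hand back one extra element, hence the added "_" in each unpacking. For process_dataset_configs the third slot is the sampling-weights mapping introduced below (None for the base preprocessor). A stub sketch of the widened contract, using stand-in types rather than repo imports:

from typing import Dict, Optional, Tuple

def process_dataset_configs_stub() -> Tuple[str, Optional[str], Optional[Dict[str, float]]]:
    # Stand-in for DataPreProcessor.process_dataset_configs after this change:
    # (train_dataset, eval_dataset, sampling_weights); base class returns None weights.
    return "train_ds", None, None

train_dataset, _, _ = process_dataset_configs_stub()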

tests/test_sft_trainer.py

Lines changed: 8 additions & 0 deletions
@@ -2263,8 +2263,12 @@ def test_online_data_mixing_plugin_sample_training(
         data = yaml.safe_load(f)
     data["dataprocessor"]["odm"]["reward_type"] = reward_type
     data["datasets"] = data["datasets"][:2]
+    sampling_weights = [0.4, 0.6]
+    i = 0
     for d, df in zip(data["datasets"], datafiles):
         d["data_paths"] = [df]
+        d["sampling"] = sampling_weights[i]
+        i += 1
     yaml.dump(data, temp_yaml_file)
     data_formatting_args.data_config_path = temp_yaml_file.name

@@ -2342,9 +2346,13 @@ def test_online_data_mixing_plugin_sample_training_no_validation_split(
         data = yaml.safe_load(f)
     data["datasets"] = data["datasets"][:2]
     data["dataprocessor"]["odm"]["reward_type"] = reward_type
+    i = 0
+    sampling_weights = [0.4, 0.6]
     for d, df in zip(data["datasets"], datafiles):
         d["data_paths"] = [df]
+        d["sampling"] = sampling_weights[i]
         del d["split"]
+        i += 1
     yaml.dump(data, temp_yaml_file)
     data_formatting_args.data_config_path = temp_yaml_file.name

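The manual index i works; an equivalent, slightly tighter formulation would zip the weights in directly (an illustrative alternative, not what the test does):

sampling_weights = [0.4, 0.6]
datasets = [{"name": "a"}, {"name": "b"}]  # stand-in for data["datasets"]
datafiles = ["a.jsonl", "b.jsonl"]  # hypothetical file names

for d, df, w in zip(datasets, datafiles, sampling_weights):
    d["data_paths"] = [df]
    d["sampling"] = w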

tuning/config/acceleration_configs/odm.py

Lines changed: 2 additions & 0 deletions
@@ -14,6 +14,7 @@

 # Standard
 from dataclasses import dataclass
+from typing import Union

 # Local
 from .utils import ensure_nested_dataclasses_initialized, parsable_dataclass
@@ -27,6 +28,7 @@ class ODM:
     reward_type: str = None
     gamma: float = 0.1
     eta: float = 0.1
+    resume_from_checkpoint: Union[bool, str] = False


 @dataclass
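The new field accepts either a boolean or a checkpoint path, presumably following the usual HF Trainer resume_from_checkpoint convention (True = latest checkpoint, str = explicit path); that reading is inferred from the type, not stated in the diff. A minimal standalone mirror of the dataclass:

from dataclasses import dataclass
from typing import Union

@dataclass
class ODM:  # simplified mirror of the repo's ODM config above
    reward_type: str = None
    gamma: float = 0.1
    eta: float = 0.1
    resume_from_checkpoint: Union[bool, str] = False

# Resume the data-mixing state alongside training, if the inference above holds.
odm = ODM(reward_type="validation_loss", resume_from_checkpoint=True)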

tuning/data/data_processors.py

Lines changed: 78 additions & 69 deletions
@@ -452,36 +452,9 @@ def split_dataset(
         )
         return split_datasets

-    def _process_datasets_for_odm(
-        self,
-        processed_datasets: List[
-            Tuple[DataSetConfig, Union[DatasetDict, IterableDatasetDict]]
-        ],
-    ) -> Tuple[
-        Dict[str, Union[Dataset, IterableDataset]],
-        Dict[str, Union[Dataset, IterableDataset]],
-    ]:
-        train_split = "train"
-        eval_split = "test"
-        train_datasets_dict = {}
-        eval_datasets_dict = {}
-        for d, raw in processed_datasets:
-            if train_split in raw:
-                train_datasets_dict[d.name] = raw[train_split]
-            if eval_split in raw:
-                eval_datasets_dict[d.name] = raw[eval_split]
-        return train_datasets_dict, eval_datasets_dict
-
-    def _process_dataset_configs(
-        self, dataset_configs: List[DataSetConfig], odm_config=None
-    ) -> Union[
-        Tuple[Union[Dataset, IterableDataset], Union[Dataset, IterableDataset]],
-        Tuple[
-            Dict[str, Union[Dataset, IterableDataset]],
-            Dict[str, Union[Dataset, IterableDataset]],
-        ],
-    ]:
+    def _prepare_processed_datasets(
+        self, dataset_configs: List[DataSetConfig]
+    ) -> List[Tuple[DataSetConfig, Union[IterableDataset, Dataset]]]:
         if not dataset_configs:
             raise ValueError(
                 "No dataset configs provided. Provided Dataset configs is None."
@@ -530,13 +503,35 @@ def _process_dataset_configs(

             # Append the processed datasets to the final dict
             processed_datasets.append((d, raw_datasets))
-        if odm_config:
-            logger.info(
-                "Sampling probabilities are ignored if provided"
-                "and are not used for concatenation. Instead"
-                "online data mixing plugin handles it."
-            )
-            return self._process_datasets_for_odm(processed_datasets)
+        return processed_datasets
+
+    def _validate_sampling_ratios(self, sampling_ratios: List[float], train_datasets):
+        if len(sampling_ratios) > 0:
+            if len(sampling_ratios) < len(train_datasets):
+                raise ValueError(
+                    "Sampling probability should be specified for all datasets with train split"
+                )
+            if len(sampling_ratios) > len(train_datasets):
+                raise ValueError(
+                    "Sampling probability should only be specified for datasets with train split"
+                )
+            if sum(p for p in sampling_ratios) != 1:
+                raise ValueError(
+                    "Sampling probabilities for train datasets don't sum to 1"
+                )
+            return True
+
+    def _process_dataset_configs(
+        self, dataset_configs: List[DataSetConfig]
+    ) -> Tuple[
+        Union[Dataset, IterableDataset],
+        Union[Dataset, IterableDataset],
+        Dict[str, float],
+    ]:
+        train_split = "train"  # default
+        eval_split = "test"
+        processed_datasets = self._prepare_processed_datasets(dataset_configs)
+
         train_datasets = []
         train_sampling_probabilities = []
         validation_datasets = []
@@ -557,25 +552,9 @@ def _process_dataset_configs(
             )

         # quick check to see if we are sampling and if we need to throw error.
-        if len(train_sampling_probabilities) > 0:
-            if len(train_sampling_probabilities) < len(train_datasets):
-                raise ValueError(
-                    "Sampling probability should be specified for all datasets with train split"
-                )
-            if len(train_sampling_probabilities) > len(train_datasets):
-                raise ValueError(
-                    "Sampling probability should only be specified for datasets with train split"
-                )
-            if sum(p for p in train_sampling_probabilities) != 1:
-                raise ValueError(
-                    "Sampling probabilities for train datasets don't sum to 1"
-                )
-            sample_datasets = True
-            logger.info(
-                "Sampling ratios are specified; only train datasets will be interleaved."
-            )
-        else:
-            sample_datasets = False
+        sample_datasets = self._validate_sampling_ratios(
+            train_sampling_probabilities, train_datasets
+        )

         # Ensure again datasets are aligned before interleaving or concatenating
         maybe_align_datasets(train_datasets)
@@ -620,16 +599,14 @@ def _process_dataset_configs(
         if eval_dataset and isinstance(eval_dataset, IterableDataset):
             eval_dataset = resolve_iterable_dataset_features(eval_dataset)

-        return train_dataset, eval_dataset
+        return train_dataset, eval_dataset, None

     def process_dataset_configs(
-        self, dataset_configs: List[DataSetConfig], odm_config=None
-    ) -> Union[
-        Tuple[Union[Dataset, IterableDataset], Union[Dataset, IterableDataset]],
-        Tuple[
-            Dict[str, Union[Dataset, IterableDataset]],
-            Dict[str, Union[Dataset, IterableDataset]],
-        ],
+        self, dataset_configs: List[DataSetConfig]
+    ) -> Tuple[
+        Union[Dataset, IterableDataset],
+        Union[Dataset, IterableDataset],
+        Dict[str, float],
     ]:
         train_dataset = eval_dataset = None
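The base class fills the new third slot with None because it interleaves or concatenates eagerly and has no per-dataset weights left to report; only the ODM subclass (below) returns a populated mapping. A stub of the two shapes, with stand-in types:

from typing import Dict, Optional, Tuple

def base_process() -> Tuple[str, Optional[str], Optional[Dict[str, float]]]:
    # DataPreProcessor: datasets already mixed, so no weights to carry forward.
    return "train_ds", "eval_ds", None

def odm_process() -> Tuple[Dict[str, str], Dict[str, str], Dict[str, float]]:
    # ODMDataPreProcessor: per-dataset splits plus their starting weights.
    return {"d1": "d1_train"}, {"d1": "d1_eval"}, {"d1": 1.0}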

@@ -643,14 +620,43 @@
         # as we want to reuse HF cache and not redo computation on all nodes
         # For rationale see https://github.com/huggingface/trl/pull/3106
         with state.main_process_first():
-            train_dataset, eval_dataset = self._process_dataset_configs(
-                dataset_configs, odm_config
-            )
+            (
+                train_dataset,
+                eval_dataset,
+                sampling_weights,
+            ) = self._process_dataset_configs(dataset_configs)

         logger.info("Processed train dataset {}".format(train_dataset))
         logger.info("Processed eval dataset {}".format(eval_dataset))

-        return train_dataset, eval_dataset
+        return train_dataset, eval_dataset, sampling_weights
+
+
+class ODMDataPreProcessor(DataPreProcessor):
+    def _process_dataset_configs(
+        self, dataset_configs: List[DataSetConfig]
+    ) -> Tuple[
+        Dict[str, Union[Dataset, IterableDataset]],
+        Dict[str, Union[Dataset, IterableDataset]],
+        Dict[str, float],
+    ]:
+        processed_datasets = self._prepare_processed_datasets(dataset_configs)
+        train_split = "train"
+        eval_split = "test"
+        train_datasets_dict = {}
+        eval_datasets_dict = {}
+        sampling_weights_dict = {}
+        for d, raw in processed_datasets:
+            if d.sampling is not None and d.sampling > 0.0:
+                sampling_weights_dict[d.name] = d.sampling
+            if train_split in raw:
+                train_datasets_dict[d.name] = raw[train_split]
+            if eval_split in raw:
+                eval_datasets_dict[d.name] = raw[eval_split]
+        self._validate_sampling_ratios(
+            sampling_weights_dict.values(), train_datasets_dict.values()
+        )
+        return train_datasets_dict, eval_datasets_dict, sampling_weights_dict

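One small detail worth noting: _validate_sampling_ratios is called here with dict views, which is fine because len() and sum() both operate on dict.values() without a list() conversion:

weights = {"dataset_1": 0.4, "dataset_2": 0.6}
assert len(weights.values()) == 2
assert sum(weights.values()) == 1.0  # 0.4 + 0.6 is exactly 1.0 in binary floats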
def get_datapreprocessor(
@@ -659,7 +665,10 @@ def get_datapreprocessor(
     processor: AutoProcessor = None,
     additional_data_handlers: Dict[str, DataHandler] = None,
 ) -> DataPreProcessor:
-    data_processor = DataPreProcessor(
+    data_processor_cls = DataPreProcessor
+    if processor_config.type == "odm":
+        data_processor_cls = ODMDataPreProcessor
+    data_processor = data_processor_cls(
         processor_config=processor_config,
         tokenizer=tokenizer,
         processor=processor,
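The factory now dispatches on the processor config's type; a minimal mirror of that selection logic:

class DataPreProcessor:  # stand-ins for the repo classes
    pass

class ODMDataPreProcessor(DataPreProcessor):
    pass

def pick_preprocessor_cls(processor_type: str):
    # Mirrors get_datapreprocessor: "odm" selects the ODM subclass.
    return ODMDataPreProcessor if processor_type == "odm" else DataPreProcessor

assert pick_preprocessor_cls("odm") is ODMDataPreProcessor
assert pick_preprocessor_cls("default") is DataPreProcessor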
