@@ -13,7 +13,7 @@ datasets:
split:
train: 0.8
validation: 0.2 # validation set is also used in ODM reward computation when reward_type is validation_loss.
sampling: 0.3 # ignored
sampling: 0.3 # used as starting weights for online data mixing
data_paths:
- "FILE_PATH"
data_handlers:
@@ -28,7 +28,7 @@ datasets:
split:
train: 0.6
validation: 0.2 # validation set is also used in ODM reward computation when reward_type is validation_loss.
sampling: 0.4 # ignored
sampling: 0.4 # used as starting weights for online data mixing
data_paths:
- "FILE_PATH"
data_handlers:
@@ -43,7 +43,7 @@ datasets:
split:
train: 0.4
validation: 0.1 # validation set is also used in ODM reward computation when reward_type is validation_loss.
sampling: 0.3 # ignored
sampling: 0.3 # used as starting weights for online data mixing
data_paths:
- "FILE_PATH"
data_handlers:
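For context, here is a minimal sketch of how the per-dataset sampling values in a config like the one above could be collected as the starting weights for online data mixing. The config layout follows the diff; the loader code, the name key on each dataset entry, and the file name are illustrative assumptions.

# Sketch only: collect per-dataset "sampling" values, which now seed the online
# data mixing weights instead of being ignored. The "name" key is assumed.
import math
import yaml

with open("data_config.yaml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

starting_weights = {
    d.get("name", f"dataset_{i}"): d.get("sampling", 0.0)
    for i, d in enumerate(config["datasets"])
}

# _validate_sampling_ratios (added further down in this PR) requires the weights
# of datasets with a train split to sum to 1; a tolerant check is used here.
assert math.isclose(sum(starting_weights.values()), 1.0)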
18 changes: 11 additions & 7 deletions tests/data/test_data_preprocessing.py
@@ -768,7 +768,7 @@ def test_process_dataconfig_file_with_streaming(data_config_path, data_path):
output_dir="tmp", # Not needed but positional
)

(train_set, _, _) = process_dataconfig_file(data_args, TRAIN_ARGS, tokenizer)
(train_set, _, _, _) = process_dataconfig_file(data_args, TRAIN_ARGS, tokenizer)
assert isinstance(train_set, IterableDataset)
if datasets_name == "text_dataset_input_output_masking":
column_names = set(["input_ids", "attention_mask", "labels"])
@@ -1017,7 +1017,7 @@ def test_process_dataconfig_file(data_config_path, data_path):
output_dir="tmp", # Not needed but positional
)

(train_set, _, _) = process_dataconfig_file(data_args, TRAIN_ARGS, tokenizer)
(train_set, _, _, _) = process_dataconfig_file(data_args, TRAIN_ARGS, tokenizer)
assert isinstance(train_set, Dataset)
if datasets_name == "text_dataset_input_output_masking":
column_names = set(["input_ids", "attention_mask", "labels"])
@@ -1107,7 +1107,7 @@ def test_process_datahandler_eos_token(data_config_path, data_path, add_eos_toke
output_dir="tmp", # Not needed but positional
)

(train_set, _, _) = process_dataconfig_file(data_args, TRAIN_ARGS, tokenizer)
(train_set, _, _, _) = process_dataconfig_file(data_args, TRAIN_ARGS, tokenizer)
assert isinstance(train_set, Dataset)
if datasets_name == "text_dataset_input_output_masking":
column_names = set(["input_ids", "attention_mask", "labels"])
@@ -1258,7 +1258,7 @@ def test_process_dataconfig_multiple_files(data_config_path, data_path_list):
output_dir="tmp", # Not needed but positional
)

(train_set, _, _) = process_dataconfig_file(data_args, TRAIN_ARGS, tokenizer)
(train_set, _, _, _) = process_dataconfig_file(data_args, TRAIN_ARGS, tokenizer)
assert isinstance(train_set, Dataset)
if datasets_name == "text_dataset_input_output_masking":
column_names = set(["input_ids", "attention_mask", "labels"])
@@ -1330,7 +1330,7 @@ def test_process_dataconfig_multiple_files_folders_with_globbing(
output_dir="tmp", # Not needed but positional
)

(train_set, _, _) = process_dataconfig_file(data_args, TRAIN_ARGS, tokenizer)
(train_set, _, _, _) = process_dataconfig_file(data_args, TRAIN_ARGS, tokenizer)
assert isinstance(train_set, Dataset)
assert set(["input_ids", "attention_mask", "labels"]).issubset(
set(train_set.column_names)
@@ -1831,7 +1831,7 @@ def test_process_dataset_configs(datafile, column_names, datasetconfigname):
tokenizer=tokenizer,
)
datasetconfig = [DataSetConfig(name=datasetconfigname, data_paths=[datafile])]
train_dataset, _ = processor.process_dataset_configs(dataset_configs=datasetconfig)
train_dataset, _, _ = processor.process_dataset_configs(
dataset_configs=datasetconfig
)

assert isinstance(train_dataset, Dataset)
assert set(train_dataset.column_names) == column_names
@@ -1953,7 +1955,9 @@ def test_rename_and_select_dataset_columns(
name=datasetconfigname, data_paths=data_paths, data_handlers=handlers
)
]
train_dataset, _ = processor.process_dataset_configs(dataset_configs=datasetconfig)
train_dataset, _, _ = processor.process_dataset_configs(
dataset_configs=datasetconfig
)

assert isinstance(train_dataset, Dataset)
assert set(train_dataset.column_names) == set(final)
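As the tests above show, the return shapes have widened: process_dataconfig_file now yields four values and process_dataset_configs yields three. A short usage sketch; the tests only name the first element, and the interpretation of the third return of process_dataset_configs follows the tuning/data/data_processors.py changes further down.

# Only the train split is asserted on in these tests; the remaining positions
# stay anonymous here as well.
train_set, _, _, _ = process_dataconfig_file(data_args, TRAIN_ARGS, tokenizer)

# process_dataset_configs now returns a third value: the sampling weights.
# It is None for the default DataPreProcessor and a Dict[str, float] keyed by
# dataset name for the ODM preprocessor (see data_processors.py below).
train_dataset, eval_dataset, sampling_weights = processor.process_dataset_configs(
    dataset_configs=datasetconfig
)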
8 changes: 8 additions & 0 deletions tests/test_sft_trainer.py
@@ -2263,8 +2263,12 @@ def test_online_data_mixing_plugin_sample_training(
data = yaml.safe_load(f)
data["dataprocessor"]["odm"]["reward_type"] = reward_type
data["datasets"] = data["datasets"][:2]
sampling_weights = [0.4, 0.6]
i = 0
for d, df in zip(data["datasets"], datafiles):
d["data_paths"] = [df]
d["sampling"] = sampling_weights[i]
i += 1
yaml.dump(data, temp_yaml_file)
data_formatting_args.data_config_path = temp_yaml_file.name

@@ -2342,9 +2346,13 @@ def test_online_data_mixing_plugin_sample_training_no_validation_split(
data = yaml.safe_load(f)
data["datasets"] = data["datasets"][:2]
data["dataprocessor"]["odm"]["reward_type"] = reward_type
i = 0
sampling_weights = [0.4, 0.6]
for d, df in zip(data["datasets"], datafiles):
d["data_paths"] = [df]
d["sampling"] = sampling_weights[i]
del d["split"]
i += 1
yaml.dump(data, temp_yaml_file)
data_formatting_args.data_config_path = temp_yaml_file.name

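As a side note, the manual counter (i = 0 / i += 1) in the two tests above could be folded into the zip; a small alternative sketch, not the code in this PR:

sampling_weights = [0.4, 0.6]
for d, df, w in zip(data["datasets"], datafiles, sampling_weights):
    d["data_paths"] = [df]
    d["sampling"] = w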
2 changes: 2 additions & 0 deletions tuning/config/acceleration_configs/odm.py
@@ -14,6 +14,7 @@

# Standard
from dataclasses import dataclass
from typing import Union

# Local
from .utils import ensure_nested_dataclasses_initialized, parsable_dataclass
@@ -27,6 +28,7 @@ class ODM:
reward_type: str = None
gamma: float = 0.1
eta: float = 0.1
resume_from_checkpoint: Union[bool, str] = False


@dataclass
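A hedged sketch of constructing the extended ODM dataclass; the field names follow the diff above, while the concrete values and the reading of the string form as a checkpoint path are illustrative assumptions.

# reward_type "validation_loss" matches the comment in the data config above;
# resume_from_checkpoint defaults to False and is assumed to take a checkpoint
# path when given as a string.
odm = ODM(
    reward_type="validation_loss",
    gamma=0.1,
    eta=0.1,
    resume_from_checkpoint="/path/to/odm/checkpoint",
)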
147 changes: 78 additions & 69 deletions tuning/data/data_processors.py
@@ -452,36 +452,9 @@ def split_dataset(
)
return split_datasets

def _process_datasets_for_odm(
self,
processed_datasets: List[
Tuple[DataSetConfig, Union[DatasetDict, IterableDatasetDict]]
],
) -> Tuple[
Dict[str, Union[Dataset, IterableDataset]],
Dict[str, Union[Dataset, IterableDataset]],
]:
train_split = "train"
eval_split = "test"
train_datasets_dict = {}
eval_datasets_dict = {}
for d, raw in processed_datasets:
if train_split in raw:
train_datasets_dict[d.name] = raw[train_split]
if eval_split in raw:
eval_datasets_dict[d.name] = raw[eval_split]
return train_datasets_dict, eval_datasets_dict

def _process_dataset_configs(
self, dataset_configs: List[DataSetConfig], odm_config=None
) -> Union[
Tuple[Union[Dataset, IterableDataset], Union[Dataset, IterableDataset]],
Tuple[
Dict[str, Union[Dataset, IterableDataset]],
Dict[str, Union[Dataset, IterableDataset]],
],
]:

def _prepare_processed_datasets(
self, dataset_configs: List[DataSetConfig]
) -> List[Tuple[DataSetConfig, Union[IterableDataset, Dataset]]]:
if not dataset_configs:
raise ValueError(
"No dataset configs provided. Provided Dataset configs is None."
@@ -530,13 +503,35 @@ def _process_dataset_configs(

# Append the processed datasets to the final dict
processed_datasets.append((d, raw_datasets))
if odm_config:
logger.info(
"Sampling probabilities are ignored if provided"
"and are not used for concatenation. Instead"
"online data mixing plugin handles it."
)
return self._process_datasets_for_odm(processed_datasets)
return processed_datasets

def _validate_sampling_ratios(self, sampling_ratios: List[float], train_datasets):
if len(sampling_ratios) > 0:
if len(sampling_ratios) < len(train_datasets):
raise ValueError(
"Sampling probability should be specified for all datasets with train split"
)
if len(sampling_ratios) > len(train_datasets):
raise ValueError(
"Sampling probability should only be specified for datasets with train split"
)
if sum(p for p in sampling_ratios) != 1:
raise ValueError(
"Sampling probabilities for train datasets don't sum to 1"
)
return True

def _process_dataset_configs(
self, dataset_configs: List[DataSetConfig]
) -> Tuple[
Union[Dataset, IterableDataset],
Union[Dataset, IterableDataset],
Dict[str, float],
]:
train_split = "train" # default
eval_split = "test"
processed_datasets = self._prepare_processed_datasets(dataset_configs)

train_datasets = []
train_sampling_probabilities = []
validation_datasets = []
@@ -557,25 +552,9 @@ def _process_dataset_configs(
)

# quick check to see if we are sampling and if we need to throw error.
if len(train_sampling_probabilities) > 0:
if len(train_sampling_probabilities) < len(train_datasets):
raise ValueError(
"Sampling probability should be specified for all datasets with train split"
)
if len(train_sampling_probabilities) > len(train_datasets):
raise ValueError(
"Sampling probability should only be specified for datasets with train split"
)
if sum(p for p in train_sampling_probabilities) != 1:
raise ValueError(
"Sampling probabilities for train datasets don't sum to 1"
)
sample_datasets = True
logger.info(
"Sampling ratios are specified; only train datasets will be interleaved."
)
else:
sample_datasets = False
sample_datasets = self._validate_sampling_ratios(
train_sampling_probabilities, train_datasets
)

# Ensure again datasets are aligned before interleaving or concatenating
maybe_align_datasets(train_datasets)
@@ -620,16 +599,14 @@ def _process_dataset_configs(
if eval_dataset and isinstance(eval_dataset, IterableDataset):
eval_dataset = resolve_iterable_dataset_features(eval_dataset)

return train_dataset, eval_dataset
return train_dataset, eval_dataset, None

def process_dataset_configs(
self, dataset_configs: List[DataSetConfig], odm_config=None
) -> Union[
Tuple[Union[Dataset, IterableDataset], Union[Dataset, IterableDataset]],
Tuple[
Dict[str, Union[Dataset, IterableDataset]],
Dict[str, Union[Dataset, IterableDataset]],
],
self, dataset_configs: List[DataSetConfig]
) -> Tuple[
Union[Dataset, IterableDataset],
Union[Dataset, IterableDataset],
Dict[str, float],
]:
train_dataset = eval_dataset = None

@@ -643,14 +620,43 @@ def process_dataset_configs(
# as we want to reuse HF cache and not redo computation on all nodes
# For rationale see https://github.com/huggingface/trl/pull/3106
with state.main_process_first():
train_dataset, eval_dataset = self._process_dataset_configs(
dataset_configs, odm_config
)
(
train_dataset,
eval_dataset,
sampling_weights,
) = self._process_dataset_configs(dataset_configs)

logger.info("Processed train dataset {}".format(train_dataset))
logger.info("Processed eval dataset {}".format(eval_dataset))

return train_dataset, eval_dataset
return train_dataset, eval_dataset, sampling_weights


class ODMDataPreProcessor(DataPreProcessor):
def _process_dataset_configs(
self, dataset_configs: List[DataSetConfig]
) -> Tuple[
Dict[str, Union[Dataset, IterableDataset]],
Dict[str, Union[Dataset, IterableDataset]],
Dict[str, float],
]:
processed_datasets = self._prepare_processed_datasets(dataset_configs)
train_split = "train"
eval_split = "test"
train_datasets_dict = {}
eval_datasets_dict = {}
sampling_weights_dict = {}
for d, raw in processed_datasets:
if d.sampling is not None and d.sampling > 0.0:
sampling_weights_dict[d.name] = d.sampling
if train_split in raw:
train_datasets_dict[d.name] = raw[train_split]
if eval_split in raw:
eval_datasets_dict[d.name] = raw[eval_split]
self._validate_sampling_ratios(
sampling_weights_dict.values(), train_datasets_dict.values()
)
return train_datasets_dict, eval_datasets_dict, sampling_weights_dict


def get_datapreprocessor(
@@ -659,7 +665,10 @@ def get_datapreprocessor(
processor: AutoProcessor = None,
additional_data_handlers: Dict[str, DataHandler] = None,
) -> DataPreProcessor:
data_processor = DataPreProcessor(
data_processor_cls = DataPreProcessor
if processor_config.type == "odm":
data_processor_cls = ODMDataPreProcessor
data_processor = data_processor_cls(
processor_config=processor_config,
tokenizer=tokenizer,
processor=processor,
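Tying the data_processors.py changes together, a rough usage sketch of the new dispatch and of the shapes each preprocessor returns; the call-site variable names are illustrative.

# get_datapreprocessor now picks the preprocessor class from processor_config.type;
# "odm" selects ODMDataPreProcessor, everything else keeps DataPreProcessor.
processor = get_datapreprocessor(
    processor_config=dataprocessor_config,
    tokenizer=tokenizer,
)
train, val, weights = processor.process_dataset_configs(dataset_configs=dataset_configs)

# DataPreProcessor:    train/val are single (Iterable)Datasets, weights is None.
# ODMDataPreProcessor: train/val are dicts keyed by dataset name and weights is a
#                      Dict[str, float] of starting probabilities taken from each
#                      dataset's "sampling" entry.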