diff --git a/tests/artifacts/predefined_data_configs/multiple_datasets_with_odm.yaml b/tests/artifacts/predefined_data_configs/multiple_datasets_with_odm.yaml index 6b5bda698..74c72a5fd 100644 --- a/tests/artifacts/predefined_data_configs/multiple_datasets_with_odm.yaml +++ b/tests/artifacts/predefined_data_configs/multiple_datasets_with_odm.yaml @@ -13,7 +13,7 @@ datasets: split: train: 0.8 validation: 0.2 # validation set is also used in ODM reward computation when reward_type is validation_loss. - sampling: 0.3 # ignored + sampling: 0.3 # used as starting weights for online data mixing data_paths: - "FILE_PATH" data_handlers: @@ -28,7 +28,7 @@ datasets: split: train: 0.6 validation: 0.2 # validation set is also used in ODM reward computation when reward_type is validation_loss. - sampling: 0.4 # ignored + sampling: 0.4 # used as starting weights for online data mixing data_paths: - "FILE_PATH" data_handlers: @@ -43,7 +43,7 @@ datasets: split: train: 0.4 validation: 0.1 # validation set is also used in ODM reward computation when reward_type is validation_loss. - sampling: 0.3 # ignored + sampling: 0.3 # used as starting weights for online data mixing data_paths: - "FILE_PATH" data_handlers: diff --git a/tests/data/test_data_preprocessing.py b/tests/data/test_data_preprocessing.py index 1ddfaaf89..a1072d2ec 100644 --- a/tests/data/test_data_preprocessing.py +++ b/tests/data/test_data_preprocessing.py @@ -768,7 +768,7 @@ def test_process_dataconfig_file_with_streaming(data_config_path, data_path): output_dir="tmp", # Not needed but positional ) - (train_set, _, _) = process_dataconfig_file(data_args, TRAIN_ARGS, tokenizer) + (train_set, _, _, _) = process_dataconfig_file(data_args, TRAIN_ARGS, tokenizer) assert isinstance(train_set, IterableDataset) if datasets_name == "text_dataset_input_output_masking": column_names = set(["input_ids", "attention_mask", "labels"]) @@ -1017,7 +1017,7 @@ def test_process_dataconfig_file(data_config_path, data_path): output_dir="tmp", # Not needed but positional ) - (train_set, _, _) = process_dataconfig_file(data_args, TRAIN_ARGS, tokenizer) + (train_set, _, _, _) = process_dataconfig_file(data_args, TRAIN_ARGS, tokenizer) assert isinstance(train_set, Dataset) if datasets_name == "text_dataset_input_output_masking": column_names = set(["input_ids", "attention_mask", "labels"]) @@ -1107,7 +1107,7 @@ def test_process_datahandler_eos_token(data_config_path, data_path, add_eos_toke output_dir="tmp", # Not needed but positional ) - (train_set, _, _) = process_dataconfig_file(data_args, TRAIN_ARGS, tokenizer) + (train_set, _, _, _) = process_dataconfig_file(data_args, TRAIN_ARGS, tokenizer) assert isinstance(train_set, Dataset) if datasets_name == "text_dataset_input_output_masking": column_names = set(["input_ids", "attention_mask", "labels"]) @@ -1258,7 +1258,7 @@ def test_process_dataconfig_multiple_files(data_config_path, data_path_list): output_dir="tmp", # Not needed but positional ) - (train_set, _, _) = process_dataconfig_file(data_args, TRAIN_ARGS, tokenizer) + (train_set, _, _, _) = process_dataconfig_file(data_args, TRAIN_ARGS, tokenizer) assert isinstance(train_set, Dataset) if datasets_name == "text_dataset_input_output_masking": column_names = set(["input_ids", "attention_mask", "labels"]) @@ -1330,7 +1330,7 @@ def test_process_dataconfig_multiple_files_folders_with_globbing( output_dir="tmp", # Not needed but positional ) - (train_set, _, _) = process_dataconfig_file(data_args, TRAIN_ARGS, tokenizer) + (train_set, _, _, _) = process_dataconfig_file(data_args, TRAIN_ARGS, tokenizer) assert isinstance(train_set, Dataset) assert set(["input_ids", "attention_mask", "labels"]).issubset( set(train_set.column_names) @@ -1831,7 +1831,9 @@ def test_process_dataset_configs(datafile, column_names, datasetconfigname): tokenizer=tokenizer, ) datasetconfig = [DataSetConfig(name=datasetconfigname, data_paths=[datafile])] - train_dataset, _ = processor.process_dataset_configs(dataset_configs=datasetconfig) + train_dataset, _, _ = processor.process_dataset_configs( + dataset_configs=datasetconfig + ) assert isinstance(train_dataset, Dataset) assert set(train_dataset.column_names) == column_names @@ -1953,7 +1955,9 @@ def test_rename_and_select_dataset_columns( name=datasetconfigname, data_paths=data_paths, data_handlers=handlers ) ] - train_dataset, _ = processor.process_dataset_configs(dataset_configs=datasetconfig) + train_dataset, _, _ = processor.process_dataset_configs( + dataset_configs=datasetconfig + ) assert isinstance(train_dataset, Dataset) assert set(train_dataset.column_names) == set(final) diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index aee20293e..f9f1d3810 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -2263,8 +2263,12 @@ def test_online_data_mixing_plugin_sample_training( data = yaml.safe_load(f) data["dataprocessor"]["odm"]["reward_type"] = reward_type data["datasets"] = data["datasets"][:2] + sampling_weights = [0.4, 0.6] + i = 0 for d, df in zip(data["datasets"], datafiles): d["data_paths"] = [df] + d["sampling"] = sampling_weights[i] + i += 1 yaml.dump(data, temp_yaml_file) data_formatting_args.data_config_path = temp_yaml_file.name @@ -2342,9 +2346,13 @@ def test_online_data_mixing_plugin_sample_training_no_validation_split( data = yaml.safe_load(f) data["datasets"] = data["datasets"][:2] data["dataprocessor"]["odm"]["reward_type"] = reward_type + i = 0 + sampling_weights = [0.4, 0.6] for d, df in zip(data["datasets"], datafiles): d["data_paths"] = [df] + d["sampling"] = sampling_weights[i] del d["split"] + i += 1 yaml.dump(data, temp_yaml_file) data_formatting_args.data_config_path = temp_yaml_file.name diff --git a/tuning/config/acceleration_configs/odm.py b/tuning/config/acceleration_configs/odm.py index f5c2d9f8c..497fc6048 100644 --- a/tuning/config/acceleration_configs/odm.py +++ b/tuning/config/acceleration_configs/odm.py @@ -14,6 +14,7 @@ # Standard from dataclasses import dataclass +from typing import Union # Local from .utils import ensure_nested_dataclasses_initialized, parsable_dataclass @@ -27,6 +28,7 @@ class ODM: reward_type: str = None gamma: float = 0.1 eta: float = 0.1 + resume_from_checkpoint: Union[bool, str] = False @dataclass diff --git a/tuning/data/data_processors.py b/tuning/data/data_processors.py index f5cc4b672..20936f387 100644 --- a/tuning/data/data_processors.py +++ b/tuning/data/data_processors.py @@ -452,36 +452,9 @@ def split_dataset( ) return split_datasets - def _process_datasets_for_odm( - self, - processed_datasets: List[ - Tuple[DataSetConfig, Union[DatasetDict, IterableDatasetDict]] - ], - ) -> Tuple[ - Dict[str, Union[Dataset, IterableDataset]], - Dict[str, Union[Dataset, IterableDataset]], - ]: - train_split = "train" - eval_split = "test" - train_datasets_dict = {} - eval_datasets_dict = {} - for d, raw in processed_datasets: - if train_split in raw: - train_datasets_dict[d.name] = raw[train_split] - if eval_split in raw: - eval_datasets_dict[d.name] = raw[eval_split] - return train_datasets_dict, eval_datasets_dict - - def _process_dataset_configs( - self, dataset_configs: List[DataSetConfig], odm_config=None - ) -> Union[ - Tuple[Union[Dataset, IterableDataset], Union[Dataset, IterableDataset]], - Tuple[ - Dict[str, Union[Dataset, IterableDataset]], - Dict[str, Union[Dataset, IterableDataset]], - ], - ]: - + def _prepare_processed_datasets( + self, dataset_configs: List[DataSetConfig] + ) -> List[Tuple[DataSetConfig, Union[IterableDataset, Dataset]]]: if not dataset_configs: raise ValueError( "No dataset configs provided. Provided Dataset configs is None." @@ -530,13 +503,35 @@ def _process_dataset_configs( # Append the processed datasets to the final dict processed_datasets.append((d, raw_datasets)) - if odm_config: - logger.info( - "Sampling probabilities are ignored if provided" - "and are not used for concatenation. Instead" - "online data mixing plugin handles it." - ) - return self._process_datasets_for_odm(processed_datasets) + return processed_datasets + + def _validate_sampling_ratios(self, sampling_ratios: List[float], train_datasets): + if len(sampling_ratios) > 0: + if len(sampling_ratios) < len(train_datasets): + raise ValueError( + "Sampling probability should be specified for all datasets with train split" + ) + if len(sampling_ratios) > len(train_datasets): + raise ValueError( + "Sampling probability should only be specified for datasets with train split" + ) + if sum(p for p in sampling_ratios) != 1: + raise ValueError( + "Sampling probabilities for train datasets don't sum to 1" + ) + return True + + def _process_dataset_configs( + self, dataset_configs: List[DataSetConfig] + ) -> Tuple[ + Union[Dataset, IterableDataset], + Union[Dataset, IterableDataset], + Dict[str, float], + ]: + train_split = "train" # default + eval_split = "test" + processed_datasets = self._prepare_processed_datasets(dataset_configs) + train_datasets = [] train_sampling_probabilities = [] validation_datasets = [] @@ -557,25 +552,9 @@ def _process_dataset_configs( ) # quick check to see if we are sampling and if we need to throw error. - if len(train_sampling_probabilities) > 0: - if len(train_sampling_probabilities) < len(train_datasets): - raise ValueError( - "Sampling probability should be specified for all datasets with train split" - ) - if len(train_sampling_probabilities) > len(train_datasets): - raise ValueError( - "Sampling probability should only be specified for datasets with train split" - ) - if sum(p for p in train_sampling_probabilities) != 1: - raise ValueError( - "Sampling probabilities for train datasets don't sum to 1" - ) - sample_datasets = True - logger.info( - "Sampling ratios are specified; only train datasets will be interleaved." - ) - else: - sample_datasets = False + sample_datasets = self._validate_sampling_ratios( + train_sampling_probabilities, train_datasets + ) # Ensure again datasets are aligned before interleaving or concatenating maybe_align_datasets(train_datasets) @@ -620,16 +599,14 @@ def _process_dataset_configs( if eval_dataset and isinstance(eval_dataset, IterableDataset): eval_dataset = resolve_iterable_dataset_features(eval_dataset) - return train_dataset, eval_dataset + return train_dataset, eval_dataset, None def process_dataset_configs( - self, dataset_configs: List[DataSetConfig], odm_config=None - ) -> Union[ - Tuple[Union[Dataset, IterableDataset], Union[Dataset, IterableDataset]], - Tuple[ - Dict[str, Union[Dataset, IterableDataset]], - Dict[str, Union[Dataset, IterableDataset]], - ], + self, dataset_configs: List[DataSetConfig] + ) -> Tuple[ + Union[Dataset, IterableDataset], + Union[Dataset, IterableDataset], + Dict[str, float], ]: train_dataset = eval_dataset = None @@ -643,14 +620,43 @@ def process_dataset_configs( # as we want to reuse HF cache and not redo computation on all nodes # For rationale see https://github.com/huggingface/trl/pull/3106 with state.main_process_first(): - train_dataset, eval_dataset = self._process_dataset_configs( - dataset_configs, odm_config - ) + ( + train_dataset, + eval_dataset, + sampling_weights, + ) = self._process_dataset_configs(dataset_configs) logger.info("Processed train dataset {}".format(train_dataset)) logger.info("Processed eval dataset {}".format(eval_dataset)) - return train_dataset, eval_dataset + return train_dataset, eval_dataset, sampling_weights + + +class ODMDataPreProcessor(DataPreProcessor): + def _process_dataset_configs( + self, dataset_configs: List[DataSetConfig] + ) -> Tuple[ + Dict[str, Union[Dataset, IterableDataset]], + Dict[str, Union[Dataset, IterableDataset]], + Dict[str, float], + ]: + processed_datasets = self._prepare_processed_datasets(dataset_configs) + train_split = "train" + eval_split = "test" + train_datasets_dict = {} + eval_datasets_dict = {} + sampling_weights_dict = {} + for d, raw in processed_datasets: + if d.sampling is not None and d.sampling > 0.0: + sampling_weights_dict[d.name] = d.sampling + if train_split in raw: + train_datasets_dict[d.name] = raw[train_split] + if eval_split in raw: + eval_datasets_dict[d.name] = raw[eval_split] + self._validate_sampling_ratios( + sampling_weights_dict.values(), train_datasets_dict.values() + ) + return train_datasets_dict, eval_datasets_dict, sampling_weights_dict def get_datapreprocessor( @@ -659,7 +665,10 @@ def get_datapreprocessor( processor: AutoProcessor = None, additional_data_handlers: Dict[str, DataHandler] = None, ) -> DataPreProcessor: - data_processor = DataPreProcessor( + data_processor_cls = DataPreProcessor + if processor_config.type == "odm": + data_processor_cls = ODMDataPreProcessor + data_processor = data_processor_cls( processor_config=processor_config, tokenizer=tokenizer, processor=processor, diff --git a/tuning/data/setup_dataprocessor.py b/tuning/data/setup_dataprocessor.py index 589c7a791..bc2f84315 100644 --- a/tuning/data/setup_dataprocessor.py +++ b/tuning/data/setup_dataprocessor.py @@ -14,7 +14,7 @@ # Standard from pathlib import Path -from typing import Dict, Union +from typing import Dict, List, Union import logging # Third Party @@ -73,7 +73,6 @@ def process_dataconfig_file( processor: AutoProcessor = None, is_multipack: bool = False, is_padding_free: bool = False, - odm_config: ODMConfig = None, ): """ Args: @@ -90,11 +89,12 @@ def process_dataconfig_file( is_multipack: A bool representing is Multipack plugin is enabled. Defauts to False. Returns: - Tuple(Dataset, Dataset, str) + Tuple(Dataset, Dataset, str, Dict[str, float]) tuple containing train_dataset (Dataset/IterableDataset), eval_dataset (Dataset/IterableDataset), dataset_text_field (str), + sampling weights """ data_config = load_and_validate_data_config(data_args.data_config_path) @@ -154,11 +154,13 @@ def process_dataconfig_file( ) tokenizer.chat_template = data_processor.processor_config.chat_template - train_dataset, eval_dataset = data_processor.process_dataset_configs( - data_config.datasets, odm_config=odm_config - ) + ( + train_dataset, + eval_dataset, + sampling_weights, + ) = data_processor.process_dataset_configs(data_config.datasets) - return (train_dataset, eval_dataset, data_args.dataset_text_field) + return (train_dataset, eval_dataset, data_args.dataset_text_field, sampling_weights) # Data Format 1: Pretokenized Data @@ -348,7 +350,6 @@ def _process_raw_data_args( additional_data_handlers: Dict[str, DataHandler] = None, is_padding_free: bool = False, processor: AutoProcessor = None, - odm_config: ODMConfig = None, ): if data_args.data_config_path is not None: @@ -447,11 +448,13 @@ def _process_raw_data_args( if is_eval_dataset_present: dataset_configs.append(eval_dataset_config) - train_dataset, eval_dataset = data_processor.process_dataset_configs( - dataset_configs, odm_config=odm_config - ) + ( + train_dataset, + eval_dataset, + sampling_weights, + ) = data_processor.process_dataset_configs(dataset_configs) - return (train_dataset, eval_dataset, dataset_text_field) + return (train_dataset, eval_dataset, dataset_text_field, sampling_weights) def dump_dataset( @@ -505,6 +508,7 @@ def setup_train_dataset_for_odm( train_dataset: Dict = None, reward_dataset: Dict = None, # eval_dataset is used for reward computation max_seq_length: str = None, + sampling_weights: List[float] = None, # cold start sampling weights for ODM ): # pylint: disable=import-outside-toplevel if not is_fms_accelerate_available(plugins="odm"): @@ -549,7 +553,7 @@ def setup_train_dataset_for_odm( collators, reward_dataset, eval_collators, - None, + sampling_weights, gamma=odm_config.odm.gamma, eta=odm_config.odm.eta, output_dir=train_args.output_dir, @@ -627,7 +631,12 @@ def process_dataargs( ) if data_args.data_config_path: - train_dataset, eval_dataset, dataset_text_field = process_dataconfig_file( + ( + train_dataset, + eval_dataset, + dataset_text_field, + sampling_weights, + ) = process_dataconfig_file( data_args, train_args, tokenizer, @@ -635,10 +644,14 @@ def process_dataargs( processor, is_multipack, is_padding_free, - odm_config=odm_config, ) else: - train_dataset, eval_dataset, dataset_text_field = _process_raw_data_args( + ( + train_dataset, + eval_dataset, + dataset_text_field, + sampling_weights, + ) = _process_raw_data_args( data_args, tokenizer, train_args.packing, @@ -646,7 +659,6 @@ def process_dataargs( additional_data_handlers, is_padding_free, processor, - odm_config=odm_config, ) if train_args.eval_strategy != "no" and eval_dataset is None: @@ -700,6 +712,7 @@ def process_dataargs( train_dataset, eval_dataset, max_seq_length, + sampling_weights, ) else: # Note: This check should not be removed. diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index 04b920c8a..330606946 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -129,12 +129,36 @@ def train( logger_name="sft_trainer_train", level=train_args.log_level ) + resume_from_checkpoint = None + if train_args.output_dir: + os.makedirs(train_args.output_dir, exist_ok=True) + logger.info("using the output directory at %s", train_args.output_dir) + + # Check if resume flag is not passed (None), or if flag is true and + # output_dir has checkpoints then get last checkpoint from output_dir + if ( + train_args.resume_from_checkpoint is None + or train_args.resume_from_checkpoint.lower() == "true" + ): + resume_from_checkpoint = get_last_checkpoint(train_args.output_dir) + else: + # `train_args.resume_from_checkpoint` gives string values + # Check if flag is false OR flag has checkpoint value for resuming tuning + resume_from_checkpoint = ( + train_args.resume_from_checkpoint + if train_args.resume_from_checkpoint.lower() != "false" + else False + ) + # TODO: use of load_and_validate_data_config here is not clean way # rather we should move this logic to process_dataargs odm_config = None if data_args.data_config_path: _dataconfig = load_and_validate_data_config(data_args.data_config_path) if _dataconfig.dataprocessor.type == "odm": + _dataconfig.dataprocessor.odm[ + "resume_from_checkpoint" + ] = resume_from_checkpoint odm_config = ODMConfig(odm=ODM(**_dataconfig.dataprocessor.odm)) USE_ALORA = False @@ -504,23 +528,6 @@ def train( ): trainer.add_callback(clb) - resume_from_checkpoint = None - # Check if resume flag is not passed (None), or if flag is true and - # output_dir has checkpoints then get last checkpoint from output_dir - if ( - training_args.resume_from_checkpoint is None - or training_args.resume_from_checkpoint.lower() == "true" - ): - resume_from_checkpoint = get_last_checkpoint(training_args.output_dir) - else: - # `training_args.resume_from_checkpoint` gives string values - # Check if flag is false OR flag has checkpoint value for resuming tuning - resume_from_checkpoint = ( - training_args.resume_from_checkpoint - if training_args.resume_from_checkpoint.lower() != "false" - else False - ) - trainer.train(resume_from_checkpoint) additional_metadata = {} additional_metadata["added_tokens_info"] = added_tokens_dict @@ -794,9 +801,6 @@ def main(): "failed while parsing extra metadata. pass a valid json %s", repr(e) ) - if training_args.output_dir: - os.makedirs(training_args.output_dir, exist_ok=True) - logger.info("using the output directory at %s", training_args.output_dir) try: trainer, additional_train_info, tc_callback = train( model_args=model_args,