diff --git a/plugins/online-data-mixing/pyproject.toml b/plugins/online-data-mixing/pyproject.toml
index bd97cce3..fee8e13f 100644
--- a/plugins/online-data-mixing/pyproject.toml
+++ b/plugins/online-data-mixing/pyproject.toml
@@ -22,7 +22,7 @@ classifiers=[
     "Programming Language :: Python :: 3.11",
 ]
 
-dependencies = ["datasets"]
+dependencies = ["datasets", "torchdata"]
 
 [tool.hatch.build.targets.wheel]
 only-include = ["src/fms_acceleration_odm"]
diff --git a/plugins/online-data-mixing/src/fms_acceleration_odm/__init__.py b/plugins/online-data-mixing/src/fms_acceleration_odm/__init__.py
index 8d6e919c..fcb6980a 100644
--- a/plugins/online-data-mixing/src/fms_acceleration_odm/__init__.py
+++ b/plugins/online-data-mixing/src/fms_acceleration_odm/__init__.py
@@ -14,6 +14,7 @@
 
 # Local
+from .callback import DataloaderSavingCallback
 from .framework_plugin_odm import OnlineDataMixingAccelerationPlugin
 from .odm import OnlineMixingDataset, Reward, compute_reward
 from .patch import patch_hf_trainer_evaluate
diff --git a/plugins/online-data-mixing/src/fms_acceleration_odm/callback.py b/plugins/online-data-mixing/src/fms_acceleration_odm/callback.py
new file mode 100644
index 00000000..f7430f80
--- /dev/null
+++ b/plugins/online-data-mixing/src/fms_acceleration_odm/callback.py
@@ -0,0 +1,38 @@
+# fms-hf-tuning patch
+# Standard
+from logging import getLogger
+import os
+
+# Third Party
+from transformers import TrainerCallback
+import torch
+
+logger = getLogger(__name__)
+
+
+class DataloaderSavingCallback(TrainerCallback):
+    def __init__(self, accelerator):
+        super().__init__()
+        self.accelerator = accelerator
+
+    def on_save(self, args, state, control, **kwargs):
+        if not self.accelerator.is_main_process:
+            return
+        # Third Party
+        # pylint: disable=import-outside-toplevel
+        from torchdata.stateful_dataloader import StatefulDataLoader
+
+        checkpoint_path = os.path.join(
+            args.output_dir, f"checkpoint-{state.global_step}"
+        )
+        # It is assumed that at most one of the prepared dataloaders is
+        # stateful; if one is, it is the training dataloader.
+        for i, _ in enumerate(self.accelerator._dataloaders):
+            if isinstance(
+                self.accelerator._dataloaders[i].base_dataloader, StatefulDataLoader
+            ):
+                torch.save(
+                    self.accelerator._dataloaders[i].state_dict(),
+                    os.path.join(checkpoint_path, "odm_dl_state_dict.bin"),
+                )
+                break
diff --git a/plugins/online-data-mixing/src/fms_acceleration_odm/framework_plugin_odm.py b/plugins/online-data-mixing/src/fms_acceleration_odm/framework_plugin_odm.py
index 9f1296a5..4fd4e40a 100644
--- a/plugins/online-data-mixing/src/fms_acceleration_odm/framework_plugin_odm.py
+++ b/plugins/online-data-mixing/src/fms_acceleration_odm/framework_plugin_odm.py
@@ -22,6 +22,7 @@
 import torch
 
 # Local
+from .callback import DataloaderSavingCallback
 from .patch import patch_hf_trainer_evaluate
 
 
@@ -36,6 +37,11 @@ def __init__(self, configurations: Dict[str, Dict]):
             default=1,
         )
 
+        self._resume_from_checkpoint = self._check_config_and_maybe_check_values(
+            key="training.odm.odm.resume_from_checkpoint",
+            default=False,
+        )
+
     # data_config file should be there
     @property
     def requires_augmentation(self):
@@ -55,15 +61,18 @@ def augmentation(
         train_args.eval_steps = 1
         train_args.eval_strategy = "steps"
 
-        # update_interval information has to be made available in the evaluate HF patch
-        # function and this seems to be the only reasonable way to do so
+        # update_interval and resume_from_checkpoint information has to be made
+        # available in the evaluate HF patch function and this seems to be
+        # the only reasonable way to do so
         model.ta_update_interval = self._update_interval
+        model.resume_from_checkpoint = self._resume_from_checkpoint
+
         return model, modifiable_args
 
     def get_callbacks_and_ready_for_train(
         self, model: torch.nn.Module = None, accelerator=None
     ):
-        callbacks = []
+        callbacks = [DataloaderSavingCallback(accelerator)]
         patch_hf_trainer_evaluate()
         return callbacks
diff --git a/plugins/online-data-mixing/src/fms_acceleration_odm/odm/dataset.py b/plugins/online-data-mixing/src/fms_acceleration_odm/odm/dataset.py
index 5d01f2bc..54cdf395 100644
--- a/plugins/online-data-mixing/src/fms_acceleration_odm/odm/dataset.py
+++ b/plugins/online-data-mixing/src/fms_acceleration_odm/odm/dataset.py
@@ -9,6 +9,7 @@
 # Third Party
 from datasets import DatasetDict
 from torch.utils.data import DataLoader, IterableDataset
+from torchdata.stateful_dataloader import StatefulDataLoader
 from tqdm import tqdm
 import torch
 
@@ -97,20 +98,23 @@ def __init__(
         self.eval_collators_dict = eval_collators_dict
         self.eval_dataset_dict = eval_dataset_dict
         self.eval_dataset_dict_dl = {}
+        # iterators over the per-category dataloaders
+        self.train_dataset_dict_dl_iter = {}
+        # the dataloaders themselves, kept to reset exhausted iterators
        self.train_dataset_dict_dl = {}
+        self.dataset_dict = dataset_dict
 
         # prepare torch dataloaders for each of the dataset.
-        for k, _ in dataset_dict.items():
-            dataset_dict[k] = DataLoader(
-                dataset_dict[k],
+        for k, _ in self.dataset_dict.items():
+            self.train_dataset_dict_dl[k] = StatefulDataLoader(
+                self.dataset_dict[k],
                 1,
                 shuffle=False,
-                num_workers=1,
+                num_workers=0,
                 collate_fn=collators_dict[k] if collators_dict else None,
             )
-            self.train_dataset_dict_dl[k] = iter(dataset_dict[k])
+            self.train_dataset_dict_dl_iter[k] = iter(self.train_dataset_dict_dl[k])
         self.eval_batch_size = eval_batch_size
-        self.dataset_dict = dataset_dict
-        self.category_list = sorted(self.train_dataset_dict_dl.keys())
+        self.category_list = sorted(self.train_dataset_dict_dl_iter.keys())
         self.id2cat = dict(enumerate(self.category_list))
         self.cat2id = {c: i for i, c in enumerate(self.category_list)}
         self.total_categories = len(self.category_list)
@@ -172,7 +176,6 @@ def log_to_file(self, data: dict):
             f.write(json.dumps(self.log) + "\n")
 
     def __iter__(self):
-        self.produced = 0
         return self
 
     def __next__(self):
@@ -182,17 +185,17 @@ def __next__(self):
         )[0]
         sample = None
         try:
-            sample = next(self.train_dataset_dict_dl[self.id2cat[self.arm_idx]])
+            sample = next(self.train_dataset_dict_dl_iter[self.id2cat[self.arm_idx]])
         except StopIteration:
             logger.info(
                 "{id} dataset exhausted so the iterator is reset.".format(
                     id=self.id2cat[self.arm_idx]
                 )
             )
-            self.train_dataset_dict_dl[self.id2cat[self.arm_idx]] = iter(
-                self.dataset_dict[self.id2cat[self.arm_idx]]
+            self.train_dataset_dict_dl_iter[self.id2cat[self.arm_idx]] = iter(
+                self.train_dataset_dict_dl[self.id2cat[self.arm_idx]]
             )
-            sample = next(self.train_dataset_dict_dl[self.id2cat[self.arm_idx]])
+            sample = next(self.train_dataset_dict_dl_iter[self.id2cat[self.arm_idx]])
 
         self.curr_cat_count[self.arm_idx] += 1
         self.produced += 1
@@ -231,6 +234,44 @@ def __next__(self):
             )
         return sample
 
+    def load_state_dict(self, state_dict):
+        """Restore the dataset's state from the provided state dict."""
+        torch.set_rng_state(state_dict["rng"])
+        train_dataset_dict_dl_sd = state_dict.pop("train_dataset_dict_dl_sd")
+        random.setstate(state_dict.pop("random_state"))
+        for k, v in state_dict.items():
+            if hasattr(self, k):
+                setattr(self, k, v)
+        self.reward_type = Reward[state_dict["reward_type"].upper()]
+        for k, _ in train_dataset_dict_dl_sd.items():
+            self.train_dataset_dict_dl_iter[k].load_state_dict(
+                train_dataset_dict_dl_sd[k]
+            )
+
+    def state_dict(self):
+        """Collect all the state that the stateful dataloader has to persist."""
+        return {
+            "rng": torch.get_rng_state(),
+            "gamma": self.gamma,
+            "eta": self.eta,
+            "sampling_interval": self.sampling_interval,
+            "train_dataset_dict_dl_sd": {
+                k: v.state_dict() for k, v in self.train_dataset_dict_dl_iter.items()
+            },
+            "eval_batch_size": self.eval_batch_size,
+            "category_list": self.category_list,
+            "id2cat": self.id2cat,
+            "cat2id": self.cat2id,
+            "total_categories": self.total_categories,
+            "sampling_weights": self.sampling_weights,
+            "sampling_ratio": self.sampling_ratio,
+            "curr_cat_count": self.curr_cat_count,
+            "produced": self.produced,
+            "arm_idx": self.arm_idx,
+            "reward_type": str(self.reward_type),
+            "random_state": random.getstate(),
+        }
+
     def _reset_eval_dataloaders(self):
         """Helper function to reset eval dataloaders since they would
         be exhausted in the previous evaluation loop.
@@ -244,8 +285,8 @@ def _reset_eval_dataloaders(self):
                 DataLoader(
                     self.eval_dataset_dict[k],
                     self.eval_batch_size,
-                    shuffle=False,
-                    num_workers=1,
+                    shuffle=True,
+                    num_workers=0,
                     collate_fn=(
                         self.eval_collators_dict[k]
                         if self.eval_collators_dict
@@ -398,14 +439,14 @@ def update_sampling_weights(self, model, accelerator, state):
         if accelerator:
             rewards = accelerator.reduce(rewards, reduction="sum")
             count = accelerator.reduce(count, reduction="sum")
-            if accelerator.is_main_process:
+        if accelerator and accelerator.is_main_process:
             self._update_weights(count, rewards)
-                self.log_to_file(
-                    {
-                        "current_sampling_weights": self.sampling_weights.tolist(),
-                        "current_sampling_ratio": self.sampling_ratio,
-                        "rewards": rewards.tolist(),
-                        "count": count.tolist(),
-                        "action": "update",
-                    }
-                )
+            self.log_to_file(
+                {
+                    "current_sampling_weights": self.sampling_weights.tolist(),
+                    "current_sampling_ratio": self.sampling_ratio,
+                    "rewards": rewards.tolist(),
+                    "count": count.tolist(),
+                    "action": "update",
+                }
+            )
diff --git a/plugins/online-data-mixing/src/fms_acceleration_odm/patch.py b/plugins/online-data-mixing/src/fms_acceleration_odm/patch.py
index 7e52cec1..a8d40e23 100644
--- a/plugins/online-data-mixing/src/fms_acceleration_odm/patch.py
+++ b/plugins/online-data-mixing/src/fms_acceleration_odm/patch.py
@@ -1,9 +1,7 @@
 # fms-hf-tuning patch
 # Standard
 from logging import getLogger
-
-# Third Party
-from transformers import Trainer
+import os
 
 logger = getLogger(__name__)
 
@@ -12,11 +10,16 @@ def patch_hf_trainer_evaluate():
     # Third Party
     # pylint: disable=import-outside-toplevel
     from fms_acceleration.model_patcher import patch_target_module
+    from transformers import Trainer
 
     Trainer._evaluate = _evaluate
+    Trainer._get_dataloader = _get_dataloader
+    Trainer.get_train_dataloader = get_train_dataloader
     patch_target_module("transformers.trainer.Trainer", Trainer)
+    patch_target_module("transformers.trainer.skip_first_batches", skip_first_batches)
 
 
+# code adapted from transformers; this modified copy patches the original function
 def _evaluate(self, trial, ignore_keys_for_eval, skip_scheduler=False):
     # Standard
     # pylint: disable=import-outside-toplevel
@@ -105,3 +108,126 @@ def _evaluate(self, trial, ignore_keys_for_eval, skip_scheduler=False):
     self.train_dataset.update_sampling_weights(model, self.accelerator, self.state)
 
     return metrics
+
+
+# code adapted from transformers; this modified copy patches the original function
+def _get_dataloader(
+    self,
+    dataset,
+    description,
+    batch_size,
+    sampler_fn=None,
+    is_training=False,
+    dataloader_key=None,
+):
+    """Create a [`~torch.utils.data.DataLoader`] from the given dataset."""
+    # Standard
+    # pylint: disable=import-outside-toplevel
+    from functools import partial
+
+    # Third Party
+    # pylint: disable=import-outside-toplevel
+    from torch.utils.data import DataLoader
+    from torchdata.stateful_dataloader import StatefulDataLoader
+    from transformers import is_datasets_available
+    from transformers.trainer_utils import seed_worker
+    import torch
+
+    if is_datasets_available():
+        # Third Party
+        # pylint: disable=import-outside-toplevel
+        import datasets
+
+    data_collator = self.data_collator
+    if is_datasets_available() and isinstance(dataset, datasets.Dataset):
+        dataset = self._remove_unused_columns(dataset, description=description)
+    else:
+        data_collator = self._get_collator_with_removed_columns(
+            self.data_collator, description=description
+        )
+
+    dataloader_params = {
+        "batch_size": batch_size,
+        "collate_fn": data_collator,
+        "num_workers": self.args.dataloader_num_workers,
+        "pin_memory": self.args.dataloader_pin_memory,
+        "persistent_workers": self.args.dataloader_persistent_workers,
+    }
+
+    if not isinstance(dataset, torch.utils.data.IterableDataset):
+        if sampler_fn is not None:
+            dataloader_params["sampler"] = sampler_fn(dataset)
+        dataloader_params["drop_last"] = self.args.dataloader_drop_last
+        dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
+        if is_training:
+            dataloader_params["worker_init_fn"] = partial(
+                seed_worker,
+                num_workers=self.args.dataloader_num_workers,
+                rank=self.args.process_index,
+            )
+    if is_training:
+        self.accelerator.dataloader_config.use_stateful_dataloader = True
+        dataloader = self.accelerator.prepare(
+            StatefulDataLoader(dataset, **dataloader_params)
+        )
+    else:
+        dataloader = self.accelerator.prepare(DataLoader(dataset, **dataloader_params))
+
+    # Store the prepared dataloader for subsequent evaluations if using persistent workers.
+    if dataloader_key is not None and self.args.dataloader_persistent_workers:
+        if hasattr(self, "_eval_dataloaders"):
+            self._eval_dataloaders[dataloader_key] = dataloader
+        else:
+            self._eval_dataloaders = {dataloader_key: dataloader}
+
+    return dataloader
+
+
+# code adapted from transformers; this modified copy patches the original function
+def get_train_dataloader(self):
+    # Third Party
+    # pylint: disable=import-outside-toplevel
+    from torchdata.stateful_dataloader import StatefulDataLoader
+    from transformers.trainer_utils import get_last_checkpoint
+    import torch
+
+    if self.train_dataset is None:
+        raise ValueError("Trainer: training requires a train_dataset.")
+
+    dataloader = self._get_dataloader(
+        dataset=self.train_dataset,
+        description="Training",
+        batch_size=self._train_batch_size,
+        sampler_fn=self._get_train_sampler,
+        is_training=True,
+    )
+    resume_from_checkpoint = self.model.resume_from_checkpoint
+    if resume_from_checkpoint:
+        # code adapted from transformers and modified
+        if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:
+            resume_from_checkpoint = get_last_checkpoint(self.args.output_dir)
+            if resume_from_checkpoint is None:
+                raise ValueError(
+                    f"No valid checkpoint found in output directory ({self.args.output_dir})"
+                )
+            self.model.resume_from_checkpoint = resume_from_checkpoint
+
+        # load the saved state into the stateful training dataloader
+        dataloader_state_dict_name = "odm_dl_state_dict.bin"
+        output_dataloader_state_dict_file = os.path.join(
+            resume_from_checkpoint, dataloader_state_dict_name
+        )
+        for i, _ in enumerate(self.accelerator._dataloaders):
+            if isinstance(
+                self.accelerator._dataloaders[i].base_dataloader, StatefulDataLoader
+            ):
+                self.accelerator._dataloaders[i].load_state_dict(
+                    torch.load(output_dataloader_state_dict_file)
+                )
+                break
+    return dataloader
+
+
+# No-op replacement for transformers' skip_first_batches: the restored
+# StatefulDataLoader already resumes from its saved position, so already-seen
+# batches must not be skipped a second time on resume.
+def skip_first_batches(dataloader, num_batches=0):
+    return dataloader
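
For reviewers unfamiliar with torchdata's StatefulDataLoader, the save/resume round trip that DataloaderSavingCallback and the patched get_train_dataloader build on reduces to the sketch below. This is a minimal illustration under assumed defaults (single process, no shuffling, num_workers=0); RangeDataset and the dl_state.bin file name are hypothetical and not part of the plugin.

import torch
from torch.utils.data import Dataset
from torchdata.stateful_dataloader import StatefulDataLoader


class RangeDataset(Dataset):
    """Hypothetical ten-item dataset used only for this illustration."""

    def __len__(self):
        return 10

    def __getitem__(self, idx):
        return idx


loader = StatefulDataLoader(RangeDataset(), batch_size=2, num_workers=0)

# Consume two batches, then snapshot the loader's position; this mirrors what
# DataloaderSavingCallback.on_save persists as odm_dl_state_dict.bin.
it = iter(loader)
_ = next(it)  # tensor([0, 1])
_ = next(it)  # tensor([2, 3])
torch.save(loader.state_dict(), "dl_state.bin")

# A freshly constructed loader resumes mid-epoch once the snapshot is loaded,
# mirroring what the patched get_train_dataloader does on resume_from_checkpoint.
resumed = StatefulDataLoader(RangeDataset(), batch_size=2, num_workers=0)
resumed.load_state_dict(torch.load("dl_state.bin"))
print(next(iter(resumed)))  # tensor([4, 5]) rather than tensor([0, 1])

Because load_state_dict restores the exact batch position, resuming must not also skip batches the way stock transformers does, which is why the patch replaces skip_first_batches with a no-op.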