Skip to content

Commit 30c2c55

Browse files
authored
feat: Resume functionality for online data mixing (#155)
* feat: resume functionality

---------

Signed-off-by: Mehant Kammakomati <[email protected]>
1 parent 767bc4b commit 30c2c55

File tree

6 files changed

+246
-31
lines changed

6 files changed

+246
-31
lines changed

plugins/online-data-mixing/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ classifiers=[
2222
"Programming Language :: Python :: 3.11",
2323
]
2424

25-
dependencies = ["datasets"]
25+
dependencies = ["datasets", "torchdata"]
2626

2727
[tool.hatch.build.targets.wheel]
2828
only-include = ["src/fms_acceleration_odm"]

plugins/online-data-mixing/src/fms_acceleration_odm/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515

1616
# Local
17+
from .callback import DataloaderSavingCallback
1718
from .framework_plugin_odm import OnlineDataMixingAccelerationPlugin
1819
from .odm import OnlineMixingDataset, Reward, compute_reward
1920
from .patch import patch_hf_trainer_evaluate
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# fms-hf-tuning patch
2+
# Standard
3+
from logging import getLogger
4+
import os
5+
6+
# Third Party
7+
from transformers import TrainerCallback
8+
import torch
9+
10+
logger = getLogger(__name__)
11+
12+
13+
class DataloaderSavingCallback(TrainerCallback):
    """Trainer callback that persists the stateful training dataloader's
    state alongside each HF checkpoint, enabling exact resume of the
    online-data-mixing sampling state.

    Args:
        accelerator: the ``accelerate.Accelerator`` used for training.
            NOTE(review): this callback reads the private
            ``accelerator._dataloaders`` attribute, which is not public
            API and may break across accelerate versions — confirm on
            upgrade.
    """

    def __init__(self, accelerator):
        super().__init__()
        self.accelerator = accelerator

    def on_save(self, args, state, control, **kwargs):
        """On each checkpoint save (main process only), locate the
        stateful dataloader among the accelerator's prepared dataloaders
        and save its state dict into the checkpoint directory."""
        # Only one rank writes the state file to avoid clobbering.
        if not self.accelerator.is_main_process:
            return
        # Third Party
        # Imported lazily so the plugin loads even when torchdata is absent.
        # pylint: disable=import-outside-toplevel
        from torchdata.stateful_dataloader import StatefulDataLoader

        # HF Trainer writes checkpoints as <output_dir>/checkpoint-<step>.
        checkpoint_path = os.path.join(
            args.output_dir, f"checkpoint-{state.global_step}"
        )
        # It is assumed that at most one of the prepared dataloaders is
        # stateful, and if so it is the training dataloader — save the
        # first match and stop.
        for dataloader in self.accelerator._dataloaders:
            if isinstance(dataloader.base_dataloader, StatefulDataLoader):
                torch.save(
                    dataloader.state_dict(),
                    os.path.join(checkpoint_path, "odm_dl_state_dict.bin"),
                )
                break

plugins/online-data-mixing/src/fms_acceleration_odm/framework_plugin_odm.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import torch
2323

2424
# Local
25+
from .callback import DataloaderSavingCallback
2526
from .patch import patch_hf_trainer_evaluate
2627

2728

@@ -36,6 +37,11 @@ def __init__(self, configurations: Dict[str, Dict]):
3637
default=1,
3738
)
3839

40+
self._resume_from_checkpoint = self._check_config_and_maybe_check_values(
41+
key="training.odm.odm.resume_from_checkpoint",
42+
default=False,
43+
)
44+
3945
# data_config file should be there
4046
@property
4147
def requires_augmentation(self):
@@ -55,15 +61,18 @@ def augmentation(
5561
train_args.eval_steps = 1
5662
train_args.eval_strategy = "steps"
5763

58-
# update_interval information has to be made available in the evaluate HF patch
59-
# function and this seems to be the only reasonable way to do so
64+
# update_interval and resume_from_checkpoint information has to be made
65+
# available in the evaluate HF patch function and this seems to be
66+
# the only reasonable way to do so
6067
model.ta_update_interval = self._update_interval
68+
model.resume_from_checkpoint = self._resume_from_checkpoint
69+
6170
return model, modifiable_args
6271

6372
def get_callbacks_and_ready_for_train(
6473
self, model: torch.nn.Module = None, accelerator=None
6574
):
66-
callbacks = []
75+
callbacks = [DataloaderSavingCallback(accelerator)]
6776
patch_hf_trainer_evaluate()
6877
return callbacks
6978

plugins/online-data-mixing/src/fms_acceleration_odm/odm/dataset.py

Lines changed: 65 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
# Third Party
1010
from datasets import DatasetDict
1111
from torch.utils.data import DataLoader, IterableDataset
12+
from torchdata.stateful_dataloader import StatefulDataLoader
1213
from tqdm import tqdm
1314
import torch
1415

@@ -97,20 +98,23 @@ def __init__(
9798
self.eval_collators_dict = eval_collators_dict
9899
self.eval_dataset_dict = eval_dataset_dict
99100
self.eval_dataset_dict_dl = {}
101+
# iterators of the dataloaders
102+
self.train_dataset_dict_dl_iter = {}
103+
# to reset iterators to dataloaders
100104
self.train_dataset_dict_dl = {}
105+
self.dataset_dict = dataset_dict
101106
# prepare torch dataloaders for each of the dataset.
102-
for k, _ in dataset_dict.items():
103-
dataset_dict[k] = DataLoader(
104-
dataset_dict[k],
107+
for k, _ in self.dataset_dict.items():
108+
self.train_dataset_dict_dl[k] = StatefulDataLoader(
109+
self.dataset_dict[k],
105110
1,
106111
shuffle=False,
107-
num_workers=1,
112+
num_workers=0,
108113
collate_fn=collators_dict[k] if collators_dict else None,
109114
)
110-
self.train_dataset_dict_dl[k] = iter(dataset_dict[k])
115+
self.train_dataset_dict_dl_iter[k] = iter(self.train_dataset_dict_dl[k])
111116
self.eval_batch_size = eval_batch_size
112-
self.dataset_dict = dataset_dict
113-
self.category_list = sorted(self.train_dataset_dict_dl.keys())
117+
self.category_list = sorted(self.train_dataset_dict_dl_iter.keys())
114118
self.id2cat = dict(enumerate(self.category_list))
115119
self.cat2id = {c: i for i, c in enumerate(self.category_list)}
116120
self.total_categories = len(self.category_list)
@@ -172,7 +176,6 @@ def log_to_file(self, data: dict):
172176
f.write(json.dumps(self.log) + "\n")
173177

174178
def __iter__(self):
175-
self.produced = 0
176179
return self
177180

178181
def __next__(self):
@@ -182,17 +185,17 @@ def __next__(self):
182185
)[0]
183186
sample = None
184187
try:
185-
sample = next(self.train_dataset_dict_dl[self.id2cat[self.arm_idx]])
188+
sample = next(self.train_dataset_dict_dl_iter[self.id2cat[self.arm_idx]])
186189
except StopIteration:
187190
logger.info(
188191
"{id} dataset exhausted so the iterator is reset.".format(
189192
id=self.id2cat[self.arm_idx]
190193
)
191194
)
192-
self.train_dataset_dict_dl[self.id2cat[self.arm_idx]] = iter(
193-
self.dataset_dict[self.id2cat[self.arm_idx]]
195+
self.train_dataset_dict_dl_iter[self.id2cat[self.arm_idx]] = iter(
196+
self.train_dataset_dict_dl[self.id2cat[self.arm_idx]]
194197
)
195-
sample = next(self.train_dataset_dict_dl[self.id2cat[self.arm_idx]])
198+
sample = next(self.train_dataset_dict_dl_iter[self.id2cat[self.arm_idx]])
196199

197200
self.curr_cat_count[self.arm_idx] += 1
198201
self.produced += 1
@@ -231,6 +234,44 @@ def __next__(self):
231234
)
232235
return sample
233236

237+
def load_state_dict(self, state_dict):
    """Restore the mixing dataset from a state dict produced by ``state_dict``.

    Restores, in order: the torch RNG state, the per-category stateful
    dataloader iterator states, Python's ``random`` state, every plain
    attribute present in the state dict, and the ``Reward`` enum member.

    Args:
        state_dict: mapping previously returned by ``state_dict``.
            Mutated in place (two keys are popped).
    """
    torch.set_rng_state(state_dict["rng"])
    train_dataset_dict_dl_sd = state_dict.pop("train_dataset_dict_dl_sd")
    random.setstate(state_dict.pop("random_state"))
    # Bulk-restore simple attributes; keys without a matching attribute
    # (e.g. "rng", "reward_type") are skipped or handled explicitly below.
    for key, value in state_dict.items():
        if hasattr(self, key):
            setattr(self, key, value)
    # BUG FIX: the value may have been serialized via str(enum_member),
    # which for a standard Enum yields "Reward.NAME"; upper-casing that
    # whole string made Reward["REWARD.NAME"] raise KeyError. Take the
    # last dotted component so both "NAME" and "Reward.NAME" round-trip.
    reward_name = state_dict["reward_type"].rsplit(".", 1)[-1]
    self.reward_type = Reward[reward_name.upper()]
    # Restore each per-category dataloader iterator's position.
    for category, dl_state in train_dataset_dict_dl_sd.items():
        self.train_dataset_dict_dl_iter[category].load_state_dict(dl_state)
250+
251+
def state_dict(self):
    """Return all state the stateful dataloader must persist to resume.

    Captures the torch and ``random`` RNG states, the bandit/sampling
    bookkeeping, and the state of every per-category stateful dataloader
    iterator. Consumed by ``load_state_dict``.

    Returns:
        dict: JSON-unfriendly but ``torch.save``-able mapping of state.
    """
    return {
        "rng": torch.get_rng_state(),
        "gamma": self.gamma,
        "eta": self.eta,
        "sampling_interval": self.sampling_interval,
        # Per-category positions of the underlying StatefulDataLoaders.
        "train_dataset_dict_dl_sd": {
            k: v.state_dict() for k, v in self.train_dataset_dict_dl_iter.items()
        },
        "eval_batch_size": self.eval_batch_size,
        "category_list": self.category_list,
        "id2cat": self.id2cat,
        "cat2id": self.cat2id,
        "total_categories": self.total_categories,
        "sampling_weights": self.sampling_weights,
        "sampling_ratio": self.sampling_ratio,
        "curr_cat_count": self.curr_cat_count,
        "produced": self.produced,
        "arm_idx": self.arm_idx,
        # BUG FIX: use the bare enum member name rather than
        # str(self.reward_type), which for a standard Enum yields
        # "Reward.NAME" and cannot be looked up via Reward[...] after
        # upper-casing. The member name upper-cases cleanly on load.
        "reward_type": self.reward_type.name,
        "random_state": random.getstate(),
    }
274+
234275
def _reset_eval_dataloaders(self):
235276
"""Helper function to reset eval dataloaders since
236277
they would be exhausted in the previous evaluation loop.
@@ -244,8 +285,8 @@ def _reset_eval_dataloaders(self):
244285
DataLoader(
245286
self.eval_dataset_dict[k],
246287
self.eval_batch_size,
247-
shuffle=False,
248-
num_workers=1,
288+
shuffle=True,
289+
num_workers=0,
249290
collate_fn=(
250291
self.eval_collators_dict[k]
251292
if self.eval_collators_dict
@@ -398,14 +439,14 @@ def update_sampling_weights(self, model, accelerator, state):
398439
if accelerator:
399440
rewards = accelerator.reduce(rewards, reduction="sum")
400441
count = accelerator.reduce(count, reduction="sum")
401-
if accelerator.is_main_process:
442+
if accelerator and accelerator.is_main_process:
402443
self._update_weights(count, rewards)
403-
self.log_to_file(
404-
{
405-
"current_sampling_weights": self.sampling_weights.tolist(),
406-
"current_sampling_ratio": self.sampling_ratio,
407-
"rewards": rewards.tolist(),
408-
"count": count.tolist(),
409-
"action": "update",
410-
}
411-
)
444+
self.log_to_file(
445+
{
446+
"current_sampling_weights": self.sampling_weights.tolist(),
447+
"current_sampling_ratio": self.sampling_ratio,
448+
"rewards": rewards.tolist(),
449+
"count": count.tolist(),
450+
"action": "update",
451+
}
452+
)

0 commit comments

Comments
 (0)