Merged
Commits
79 commits
a4b6309
feat: resume functionality
kmehant Sep 29, 2025
793f2eb
feat: resume functionality
kmehant Sep 29, 2025
1e2f21a
feat: resume functionality
kmehant Sep 29, 2025
3a8aebc
feat: resume functionality
kmehant Sep 29, 2025
a57c776
feat: resume functionality
kmehant Sep 29, 2025
9faab0f
feat: resume functionality
kmehant Sep 29, 2025
0cacc0c
feat: resume functionality
kmehant Sep 29, 2025
88c5c55
feat: resume functionality
kmehant Sep 29, 2025
450a670
feat: resume functionality
kmehant Sep 29, 2025
60a83c6
feat: resume functionality
kmehant Sep 29, 2025
eff5a47
feat: resume functionality
kmehant Sep 29, 2025
96eba03
feat: resume functionality
kmehant Sep 29, 2025
1560cc8
feat: resume functionality
kmehant Sep 29, 2025
875b31f
feat: resume functionality
kmehant Sep 30, 2025
9b747c3
feat: resume functionality
kmehant Sep 30, 2025
de5e456
feat: resume functionality
kmehant Sep 30, 2025
cb312b6
feat: resume functionality
kmehant Sep 30, 2025
8ea93b1
feat: resume functionality
kmehant Sep 30, 2025
e76ac4f
feat: resume functionality
kmehant Sep 30, 2025
327e301
feat: resume functionality
kmehant Sep 30, 2025
4c5fffc
feat: resume functionality
kmehant Sep 30, 2025
b55d7e7
feat: resume functionality
kmehant Sep 30, 2025
d768755
feat: resume functionality
kmehant Sep 30, 2025
a9576bd
feat: resume functionality
kmehant Sep 30, 2025
507547f
feat: resume functionality
kmehant Sep 30, 2025
537e5a4
feat: resume functionalityclear
kmehant Sep 30, 2025
9c126dd
feat: resume functionalityclear
kmehant Sep 30, 2025
37a8c2a
feat: resume functionalityclear
kmehant Sep 30, 2025
b01f022
feat: resume functionalityclear
kmehant Sep 30, 2025
79a2544
feat: resume functionalityclear
kmehant Sep 30, 2025
01f9ad5
feat: resume functionalityclear
kmehant Sep 30, 2025
177ccfd
feat: resume functionalityclear
kmehant Sep 30, 2025
dc0d018
feat: resume functionalityclear
kmehant Sep 30, 2025
6b3a9e3
feat: resume functionalityclear
kmehant Sep 30, 2025
bc704cc
feat: resume functionalityclear
kmehant Sep 30, 2025
d9ed4b0
feat: resume functionality
kmehant Sep 30, 2025
a07cb6e
feat: resume functionality
kmehant Sep 30, 2025
704edb1
feat: resume functionality
kmehant Sep 30, 2025
5f10600
feat: resume functionality
kmehant Sep 30, 2025
94012ce
feat: resume functionality
kmehant Sep 30, 2025
6bbd52b
feat: resume functionality
kmehant Sep 30, 2025
ec6a312
feat: resume functionality
kmehant Sep 30, 2025
21c3e50
feat: resume functionality
kmehant Sep 30, 2025
d7f67e9
feat: resume functionality
kmehant Sep 30, 2025
f1b3a9d
feat: resume functionality
kmehant Sep 30, 2025
d9b24e2
feat: resume functionality
kmehant Sep 30, 2025
925a070
feat: resume functionality
kmehant Sep 30, 2025
3f80532
feat: resume functionality
kmehant Sep 30, 2025
d5b8534
feat: resume functionality
kmehant Sep 30, 2025
63e2f0a
feat: resume functionality
kmehant Sep 30, 2025
1f3f6df
feat: resume functionality
kmehant Oct 6, 2025
ca279aa
feat: resume functionality
kmehant Oct 6, 2025
8ad95bb
feat: resume functionality
kmehant Oct 6, 2025
6abb3e6
feat: resume functionality
kmehant Oct 6, 2025
a4b533b
feat: resume functionality
kmehant Oct 6, 2025
ac89e0f
feat: resume functionality
kmehant Oct 6, 2025
7bce87f
feat: resume functionality
kmehant Oct 6, 2025
efdf799
feat: resume functionality
kmehant Oct 7, 2025
9c3935f
feat: resume functionality
kmehant Oct 7, 2025
4ca691f
feat: resume functionality
kmehant Oct 7, 2025
ea3c515
feat: resume functionality
kmehant Oct 7, 2025
ecc2d62
feat: resume functionality
kmehant Oct 7, 2025
f5a4e81
feat: resume functionality
kmehant Oct 7, 2025
fe9d848
feat: resume functionality
kmehant Oct 7, 2025
20fac59
feat: resume functionality
kmehant Oct 7, 2025
79d0b0f
feat: resume functionality
kmehant Oct 7, 2025
e610899
feat: resume functionality
kmehant Oct 7, 2025
3fe2eb4
feat: resume functionality
kmehant Oct 7, 2025
3d8622e
feat: resume functionality
kmehant Oct 7, 2025
00ce95c
feat: resume functionality
kmehant Oct 7, 2025
415a814
feat: resume functionality
kmehant Oct 7, 2025
a28e559
feat: resume functionality
kmehant Oct 7, 2025
3d8bf68
feat: resume functionality
kmehant Oct 7, 2025
9f48263
feat: resume functionality
kmehant Oct 7, 2025
4f1381f
feat: resume functionality
kmehant Oct 7, 2025
890dd61
feat: resume functionality
kmehant Oct 7, 2025
8362658
feat: resume functionality
kmehant Oct 7, 2025
f822d6f
feat: resume functionality
kmehant Oct 7, 2025
f94740e
feat: resume functionality
kmehant Oct 7, 2025
2 changes: 1 addition & 1 deletion plugins/online-data-mixing/pyproject.toml
@@ -22,7 +22,7 @@ classifiers=[
"Programming Language :: Python :: 3.11",
]

dependencies = ["datasets"]
dependencies = ["datasets", "torchdata"]

[tool.hatch.build.targets.wheel]
only-include = ["src/fms_acceleration_odm"]
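For context, the new torchdata dependency provides StatefulDataLoader, a drop-in replacement for torch.utils.data.DataLoader whose iteration position can be checkpointed and restored. A minimal round-trip sketch, not part of the PR (the Squares dataset is invented for illustration):

# Minimal sketch of the StatefulDataLoader save/resume round trip.
from torch.utils.data import Dataset
from torchdata.stateful_dataloader import StatefulDataLoader

class Squares(Dataset):  # toy dataset, invented for illustration
    def __len__(self):
        return 10

    def __getitem__(self, i):
        return i * i

dl = StatefulDataLoader(Squares(), batch_size=1, num_workers=0)
it = iter(dl)
next(it), next(it)                 # consume the first two batches
state = dl.state_dict()            # capture the iteration position

dl2 = StatefulDataLoader(Squares(), batch_size=1, num_workers=0)
dl2.load_state_dict(state)         # must be called before iterating
print(next(iter(dl2)))             # tensor([4]): resumes at the third sample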
1 change: 1 addition & 0 deletions plugins/online-data-mixing/src/fms_acceleration_odm/__init__.py
@@ -14,6 +14,7 @@


# Local
+from .callback import DataloaderSavingCallback
from .framework_plugin_odm import OnlineDataMixingAccelerationPlugin
from .odm import OnlineMixingDataset, Reward, compute_reward
from .patch import patch_hf_trainer_evaluate
38 changes: 38 additions & 0 deletions plugins/online-data-mixing/src/fms_acceleration_odm/callback.py
@@ -0,0 +1,38 @@
# fms-hf-tuning patch
# Standard
from logging import getLogger
import os

# Third Party
from transformers import TrainerCallback
import torch

logger = getLogger(__name__)


class DataloaderSavingCallback(TrainerCallback):
def __init__(self, accelerator):
super().__init__()
self.accelerator = accelerator

def on_save(self, args, state, control, **kwargs):
if not self.accelerator.is_main_process:
return
# Third Party
# pylint: disable=import-outside-toplevel
from torchdata.stateful_dataloader import StatefulDataLoader

checkpoint_path = os.path.join(
args.output_dir, f"checkpoint-{state.global_step}"
)
# It is assumed that one of the datasets would be stateful
# if stateful then it would be training dataset
for i, _ in enumerate(self.accelerator._dataloaders):
if isinstance(
self.accelerator._dataloaders[i].base_dataloader, StatefulDataLoader
):
torch.save(
self.accelerator._dataloaders[i].state_dict(),
os.path.join(checkpoint_path, "odm_dl_state_dict.bin"),
)
break
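The callback above covers only the save side; restoring the saved state presumably happens in the patched trainer path (patch.py, which this diff does not show). A hypothetical sketch of the loading counterpart, mirroring the save loop (the helper name and its call site are invented):

# Hypothetical restore-side helper, invented for illustration; the PR's
# actual restore logic lives in code not shown in this diff.
import os

import torch
from torchdata.stateful_dataloader import StatefulDataLoader

def maybe_restore_dataloader_state(accelerator, checkpoint_path):
    state_path = os.path.join(checkpoint_path, "odm_dl_state_dict.bin")
    if not os.path.isfile(state_path):
        return
    state = torch.load(state_path)
    # mirror the save loop: the first stateful dataloader is assumed
    # to be the training one
    for dl in accelerator._dataloaders:
        if isinstance(dl.base_dataloader, StatefulDataLoader):
            dl.load_state_dict(state)
            break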
15 changes: 12 additions & 3 deletions plugins/online-data-mixing/src/fms_acceleration_odm/framework_plugin_odm.py
@@ -22,6 +22,7 @@
import torch

# Local
+from .callback import DataloaderSavingCallback
from .patch import patch_hf_trainer_evaluate


@@ -36,6 +37,11 @@ def __init__(self, configurations: Dict[str, Dict]):
default=1,
)

+        self._resume_from_checkpoint = self._check_config_and_maybe_check_values(
+            key="training.odm.odm.resume_from_checkpoint",
+            default=False,
+        )

# data_config file should be there
@property
def requires_augmentation(self):
@@ -55,15 +61,18 @@ def augmentation(
train_args.eval_steps = 1
train_args.eval_strategy = "steps"

-        # update_interval information has to be made available in the evaluate HF patch
-        # function and this seems to be the only reasonable way to do so
+        # update_interval and resume_from_checkpoint information has to be made
+        # available in the evaluate HF patch function and this seems to be
+        # the only reasonable way to do so
        model.ta_update_interval = self._update_interval
+        model.resume_from_checkpoint = self._resume_from_checkpoint

return model, modifiable_args

def get_callbacks_and_ready_for_train(
self, model: torch.nn.Module = None, accelerator=None
):
-        callbacks = []
+        callbacks = [DataloaderSavingCallback(accelerator)]
patch_hf_trainer_evaluate()
return callbacks

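The new key sits next to the existing update_interval knob. A hypothetical configuration fragment matching the dotted lookup path above (the exact fms-acceleration config file layout may differ):

# Hypothetical dict form of the plugin configuration; the dotted key
# "training.odm.odm.resume_from_checkpoint" resolves through this tree.
configurations = {
    "training": {
        "odm": {
            "odm": {
                "update_interval": 1,            # existing knob
                "resume_from_checkpoint": True,  # added by this PR
            }
        }
    }
}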
89 changes: 65 additions & 24 deletions plugins/online-data-mixing/src/fms_acceleration_odm/odm/dataset.py
@@ -9,6 +9,7 @@
# Third Party
from datasets import DatasetDict
from torch.utils.data import DataLoader, IterableDataset
+from torchdata.stateful_dataloader import StatefulDataLoader
from tqdm import tqdm
import torch

@@ -97,20 +98,23 @@ def __init__(
        self.eval_collators_dict = eval_collators_dict
        self.eval_dataset_dict = eval_dataset_dict
        self.eval_dataset_dict_dl = {}
+        # iterators of the dataloaders
+        self.train_dataset_dict_dl_iter = {}
+        # to reset iterators to dataloaders
        self.train_dataset_dict_dl = {}
+        self.dataset_dict = dataset_dict
        # prepare torch dataloaders for each of the dataset.
-        for k, _ in dataset_dict.items():
-            dataset_dict[k] = DataLoader(
-                dataset_dict[k],
+        for k, _ in self.dataset_dict.items():
+            self.train_dataset_dict_dl[k] = StatefulDataLoader(
+                self.dataset_dict[k],
                1,
                shuffle=False,
-                num_workers=1,
+                num_workers=0,
                collate_fn=collators_dict[k] if collators_dict else None,
            )
-            self.train_dataset_dict_dl[k] = iter(dataset_dict[k])
+            self.train_dataset_dict_dl_iter[k] = iter(self.train_dataset_dict_dl[k])
        self.eval_batch_size = eval_batch_size
-        self.dataset_dict = dataset_dict
-        self.category_list = sorted(self.train_dataset_dict_dl.keys())
+        self.category_list = sorted(self.train_dataset_dict_dl_iter.keys())
        self.id2cat = dict(enumerate(self.category_list))
        self.cat2id = {c: i for i, c in enumerate(self.category_list)}
        self.total_categories = len(self.category_list)
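The core refactor in this hunk is splitting the per-category StatefulDataLoader objects (train_dataset_dict_dl) from their live iterators (train_dataset_dict_dl_iter): the loaders must stay around so their resumable state can be captured and exhausted categories can be rewound. A standalone sketch of the pattern (the names and the `datasets` dict are invented):

# Sketch of the loader/iterator split; `datasets` maps category -> Dataset.
from torchdata.stateful_dataloader import StatefulDataLoader

loaders = {
    k: StatefulDataLoader(ds, 1, shuffle=False, num_workers=0)
    for k, ds in datasets.items()
}
iters = {k: iter(dl) for k, dl in loaders.items()}

def draw(category):
    try:
        return next(iters[category])
    except StopIteration:
        # category exhausted: rewind by taking a fresh iterator,
        # exactly as __next__ does below
        iters[category] = iter(loaders[category])
        return next(iters[category])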
@@ -172,7 +176,6 @@ def log_to_file(self, data: dict):
f.write(json.dumps(self.log) + "\n")

def __iter__(self):
-        self.produced = 0
return self

def __next__(self):
@@ -182,17 +185,17 @@ def __next__(self):
)[0]
sample = None
try:
-            sample = next(self.train_dataset_dict_dl[self.id2cat[self.arm_idx]])
+            sample = next(self.train_dataset_dict_dl_iter[self.id2cat[self.arm_idx]])
except StopIteration:
logger.info(
"{id} dataset exhausted so the iterator is reset.".format(
id=self.id2cat[self.arm_idx]
)
)
-            self.train_dataset_dict_dl[self.id2cat[self.arm_idx]] = iter(
-                self.dataset_dict[self.id2cat[self.arm_idx]]
+            self.train_dataset_dict_dl_iter[self.id2cat[self.arm_idx]] = iter(
+                self.train_dataset_dict_dl[self.id2cat[self.arm_idx]]
            )
-            sample = next(self.train_dataset_dict_dl[self.id2cat[self.arm_idx]])
+            sample = next(self.train_dataset_dict_dl_iter[self.id2cat[self.arm_idx]])

self.curr_cat_count[self.arm_idx] += 1
self.produced += 1
@@ -231,6 +234,44 @@ def __next__(self):
)
return sample

+    def load_state_dict(self, state_dict):
+        """Load the dataloader with the provided state dict"""
+        torch.set_rng_state(state_dict["rng"])
+        train_dataset_dict_dl_sd = state_dict.pop("train_dataset_dict_dl_sd")
+        random.setstate(state_dict.pop("random_state"))
+        for k, v in state_dict.items():
+            if hasattr(self, k):
+                setattr(self, k, v)
+        self.reward_type = Reward[state_dict["reward_type"].upper()]
+        for k, _ in train_dataset_dict_dl_sd.items():
+            self.train_dataset_dict_dl_iter[k].load_state_dict(
+                train_dataset_dict_dl_sd[k]
+            )
+
+    def state_dict(self):
+        """Populate all the state that has to be stored by the stateful dataloader"""
+        return {
+            "rng": torch.get_rng_state(),
+            "gamma": self.gamma,
+            "eta": self.eta,
+            "sampling_interval": self.sampling_interval,
+            "train_dataset_dict_dl_sd": {
+                k: v.state_dict() for k, v in self.train_dataset_dict_dl_iter.items()
+            },
+            "eval_batch_size": self.eval_batch_size,
+            "category_list": self.category_list,
+            "id2cat": self.id2cat,
+            "cat2id": self.cat2id,
+            "total_categories": self.total_categories,
+            "sampling_weights": self.sampling_weights,
+            "sampling_ratio": self.sampling_ratio,
+            "curr_cat_count": self.curr_cat_count,
+            "produced": self.produced,
+            "arm_idx": self.arm_idx,
+            "reward_type": str(self.reward_type),
+            "random_state": random.getstate(),
+        }

def _reset_eval_dataloaders(self):
"""Helper function to reset eval dataloaders since
they would be exhausted in the previous evaluation loop.
@@ -244,8 +285,8 @@ def _reset_eval_dataloaders(self):
DataLoader(
self.eval_dataset_dict[k],
self.eval_batch_size,
-                    shuffle=False,
-                    num_workers=1,
+                    shuffle=True,
+                    num_workers=0,
collate_fn=(
self.eval_collators_dict[k]
if self.eval_collators_dict
@@ -398,14 +439,14 @@ def update_sampling_weights(self, model, accelerator, state):
if accelerator:
rewards = accelerator.reduce(rewards, reduction="sum")
count = accelerator.reduce(count, reduction="sum")
-        if accelerator.is_main_process:
+        if accelerator and accelerator.is_main_process:
            self._update_weights(count, rewards)
-        self.log_to_file(
-            {
-                "current_sampling_weights": self.sampling_weights.tolist(),
-                "current_sampling_ratio": self.sampling_ratio,
-                "rewards": rewards.tolist(),
-                "count": count.tolist(),
-                "action": "update",
-            }
-        )
+            self.log_to_file(
+                {
+                    "current_sampling_weights": self.sampling_weights.tolist(),
+                    "current_sampling_ratio": self.sampling_ratio,
+                    "rewards": rewards.tolist(),
+                    "count": count.tolist(),
+                    "action": "update",
+                }
+            )
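Taken together, the two halves of the feature compose through torchdata's stateful protocol: StatefulDataLoader also calls state_dict()/load_state_dict() on the dataset it wraps when the dataset defines them, so checkpointing the outer training dataloader (the callback.py hunk above) should capture the mixing state serialized here as well. A rough end-to-end sketch (names invented, trainer wiring omitted):

# Rough round-trip sketch: the outer loader's checkpoint carries the
# OnlineMixingDataset state (weights, counts, RNG) defined above.
from torchdata.stateful_dataloader import StatefulDataLoader

train_dl = StatefulDataLoader(online_mixing_dataset, batch_size=8)
for step, batch in enumerate(train_dl):
    if step == 100:
        ckpt = train_dl.state_dict()   # includes the dataset's state_dict()
        break

resumed_dl = StatefulDataLoader(online_mixing_dataset, batch_size=8)
resumed_dl.load_state_dict(ckpt)       # restores loader and mixing state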