Skip to content

Commit ecf4918

Browse files
authored
fix: ODM support for padding free collator (#165)
* Converting sample labels to list
  Signed-off-by: romitjain <[email protected]>
* Deterministic sampling
  Signed-off-by: romitjain <[email protected]>
* Moved update weights out of main process update
  Signed-off-by: romitjain <[email protected]>

---------

Signed-off-by: romitjain <[email protected]>
1 parent d451073 commit ecf4918

File tree

2 files changed

+35
-4
lines changed

2 files changed

+35
-4
lines changed

plugins/online-data-mixing/src/fms_acceleration_odm/odm/auto_categorizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,11 @@
2525
import math
2626

2727
# Third Party
28-
import torch
2928
from datasets import Dataset, DatasetDict
3029
from sentence_transformers import SentenceTransformer
3130
from sklearn.cluster import KMeans
3231
import numpy as np
32+
import torch
3333

3434
logger = getLogger(__name__)
3535

plugins/online-data-mixing/src/fms_acceleration_odm/odm/dataset.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def __init__(
3636
output_dir="odm",
3737
reward_type=Reward.ENTROPY,
3838
auto_categorize_config: Optional[dict | AutoCategorizeConfig] = None,
39+
seed: Optional[int] = 42,
3940
):
4041
"""Mixes datasets with sampling ratios learnt using
4142
Multi Armed Bandit (MAB) EXP3 and rewards defined.
@@ -69,6 +70,8 @@ def __init__(
6970
configuration overrides for the auto-categorizer such as text column,
7071
embedding model, cluster count etc. This will only be used if the `dataset_dict`
7172
has only one key.
73+
seed (Optional[int], optional): Base seed for the dataset-level RNG so all
74+
distributed ranks iterate over the exact same sample order. Defaults to 42.
7275
"""
7376
self.auto_categorize = len(dataset_dict.keys()) == 1
7477
self._auto_categorize_config = self._build_auto_categorize_config(
@@ -190,6 +193,12 @@ def __init__(
190193
"action": "", # one of sample or update
191194
}
192195

196+
# Local RNG so every process can deterministically sample identical streams.
197+
self.seed = 42 if seed is None else seed
198+
self._rng = random.Random(self.seed)
199+
self._current_epoch = 0
200+
self._rng_state_restored = False
201+
193202
def log_to_file(self, data: dict):
194203
"""helper function to log the state to the file
195204
@@ -203,9 +212,17 @@ def log_to_file(self, data: dict):
203212
def __iter__(self):
204213
return self
205214

215+
def set_epoch(self, epoch: int):
216+
"""Ensures every process observes the same RNG state per epoch."""
217+
self._current_epoch = epoch
218+
if self._rng_state_restored:
219+
self._rng_state_restored = False
220+
return
221+
self._rng.seed(self.seed + epoch)
222+
206223
def __next__(self):
207224
if self.produced % self.sampling_interval == 0:
208-
self.arm_idx = random.choices(
225+
self.arm_idx = self._rng.choices(
209226
range(self.total_categories), weights=self.sampling_ratio, k=1
210227
)[0]
211228
sample = None
@@ -243,7 +260,7 @@ def __next__(self):
243260
else torch.ones_like(sample["input_ids"][0])
244261
),
245262
"labels": (
246-
sample["labels"][0]
263+
sample["labels"][0].tolist()
247264
if "labels" in sample
248265
else sample["input_ids"][0]
249266
),
@@ -264,6 +281,16 @@ def load_state_dict(self, state_dict):
264281
torch.set_rng_state(state_dict["rng"])
265282
train_dataset_dict_dl_sd = state_dict.pop("train_dataset_dict_dl_sd")
266283
random.setstate(state_dict.pop("random_state"))
284+
dataset_rng_state = state_dict.pop("online_mixing_rng_state", None)
285+
saved_seed = state_dict.pop("seed", None)
286+
saved_epoch = state_dict.pop("_current_epoch", None)
287+
if saved_seed is not None:
288+
self.seed = saved_seed
289+
if saved_epoch is not None:
290+
self._current_epoch = saved_epoch
291+
if dataset_rng_state is not None:
292+
self._rng.setstate(dataset_rng_state)
293+
self._rng_state_restored = True
267294
for k, v in state_dict.items():
268295
if hasattr(self, k):
269296
setattr(self, k, v)
@@ -295,6 +322,9 @@ def state_dict(self):
295322
"arm_idx": self.arm_idx,
296323
"reward_type": str(self.reward_type),
297324
"random_state": random.getstate(),
325+
"online_mixing_rng_state": self._rng.getstate(),
326+
"seed": self.seed,
327+
"_current_epoch": self._current_epoch,
298328
}
299329

300330
def _reset_eval_dataloaders(self):
@@ -516,8 +546,9 @@ def update_sampling_weights(self, model, accelerator, state):
516546
if accelerator:
517547
rewards = accelerator.reduce(rewards, reduction="sum")
518548
count = accelerator.reduce(count, reduction="sum")
549+
550+
self._update_weights(count, rewards)
519551
if accelerator and accelerator.is_main_process:
520-
self._update_weights(count, rewards)
521552
self.log_to_file(
522553
{
523554
"current_sampling_weights": self.sampling_weights.tolist(),

0 commit comments

Comments (0)