@@ -36,6 +36,7 @@ def __init__(
         output_dir="odm",
         reward_type=Reward.ENTROPY,
         auto_categorize_config: Optional[dict | AutoCategorizeConfig] = None,
+        seed: Optional[int] = 42,
     ):
         """Mixes datasets with sampling ratios learnt using
         Multi Armed Bandit (MAB) EXP3 and rewards defined.
@@ -69,6 +70,8 @@ def __init__(
                 configuration overrides for the auto-categorizer such as text column,
                 embedding model, cluster count etc. This will only be used if the `dataset_dict`
                 has only one key.
+            seed (Optional[int], optional): Base seed for the dataset-level RNG so all
+                distributed ranks iterate over the exact same sample order. Defaults to 42.
         """
         self.auto_categorize = len(dataset_dict.keys()) == 1
         self._auto_categorize_config = self._build_auto_categorize_config(
@@ -190,6 +193,12 @@ def __init__(
             "action": "",  # one of sample or update
         }

+        # Local RNG so every process can deterministically sample identical streams.
+        self.seed = 42 if seed is None else seed
+        self._rng = random.Random(self.seed)
+        self._current_epoch = 0
+        self._rng_state_restored = False
+
     def log_to_file(self, data: dict):
        """helper function to log the state to the file

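The new `seed` / `_rng` fields give every process its own `random.Random` built from the same base seed, so the weighted category draws line up across distributed ranks. A minimal standalone sketch of that pattern (the `SEED` and `WEIGHTS` values here are made up, not taken from the class):

```python
import random

SEED = 42
WEIGHTS = [0.5, 0.3, 0.2]  # hypothetical sampling ratios over three categories

def rank_draws(epoch: int, n: int = 5) -> list[int]:
    # Every rank constructs its own RNG from the same base seed and reseeds it
    # per epoch, mirroring what set_epoch() does further down in this diff.
    rng = random.Random(SEED)
    rng.seed(SEED + epoch)
    return [rng.choices(range(len(WEIGHTS)), weights=WEIGHTS, k=1)[0] for _ in range(n)]

# Two "ranks" drawing for the same epoch see the exact same category sequence.
assert rank_draws(epoch=1) == rank_draws(epoch=1)
```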
@@ -203,9 +212,17 @@ def log_to_file(self, data: dict):
     def __iter__(self):
         return self

+    def set_epoch(self, epoch: int):
+        """Ensures every process observes the same RNG state per epoch."""
+        self._current_epoch = epoch
+        if self._rng_state_restored:
+            self._rng_state_restored = False
+            return
+        self._rng.seed(self.seed + epoch)
+
     def __next__(self):
         if self.produced % self.sampling_interval == 0:
-            self.arm_idx = random.choices(
+            self.arm_idx = self._rng.choices(
                 range(self.total_categories), weights=self.sampling_ratio, k=1
             )[0]
         sample = None
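Switching `__next__` from the module-level `random.choices` to `self._rng.choices` also isolates the sample order from any other code that happens to use Python's global RNG. A small standalone sketch of that isolation (the weights are illustrative only):

```python
import random

weights = [0.5, 0.3, 0.2]  # hypothetical category weights

local = random.Random(42)
undisturbed = [local.choices(range(3), weights=weights, k=1)[0] for _ in range(4)]

local = random.Random(42)
disturbed = []
for _ in range(4):
    random.random()  # unrelated draw from the global RNG between samples
    disturbed.append(local.choices(range(3), weights=weights, k=1)[0])

# The dataset-local stream is unaffected by global RNG usage elsewhere.
assert undisturbed == disturbed
```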
@@ -243,7 +260,7 @@ def __next__(self):
                     else torch.ones_like(sample["input_ids"][0])
                 ),
                 "labels": (
-                    sample["labels"][0]
+                    sample["labels"][0].tolist()
                     if "labels" in sample
                     else sample["input_ids"][0]
                 ),
@@ -264,6 +281,16 @@ def load_state_dict(self, state_dict):
         torch.set_rng_state(state_dict["rng"])
         train_dataset_dict_dl_sd = state_dict.pop("train_dataset_dict_dl_sd")
         random.setstate(state_dict.pop("random_state"))
+        dataset_rng_state = state_dict.pop("online_mixing_rng_state", None)
+        saved_seed = state_dict.pop("seed", None)
+        saved_epoch = state_dict.pop("_current_epoch", None)
+        if saved_seed is not None:
+            self.seed = saved_seed
+        if saved_epoch is not None:
+            self._current_epoch = saved_epoch
+        if dataset_rng_state is not None:
+            self._rng.setstate(dataset_rng_state)
+            self._rng_state_restored = True
         for k, v in state_dict.items():
             if hasattr(self, k):
                 setattr(self, k, v)
@@ -295,6 +322,9 @@ def state_dict(self):
             "arm_idx": self.arm_idx,
             "reward_type": str(self.reward_type),
             "random_state": random.getstate(),
+            "online_mixing_rng_state": self._rng.getstate(),
+            "seed": self.seed,
+            "_current_epoch": self._current_epoch,
         }

     def _reset_eval_dataloaders(self):
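Together, the `state_dict` / `load_state_dict` hunks round-trip the local RNG so a resumed run continues the sample stream instead of reseeding it; the `_rng_state_restored` flag makes the first `set_epoch` after a restore a no-op. A self-contained sketch of just that bookkeeping (a deliberate simplification, reproducing only the RNG fields, not the full dataset class):

```python
import random

class _SeededStream:
    """Reduced stand-in for the dataset: only the RNG checkpointing logic."""

    def __init__(self, seed: int = 42):
        self.seed = seed
        self._rng = random.Random(seed)
        self._rng_state_restored = False

    def set_epoch(self, epoch: int) -> None:
        if self._rng_state_restored:      # keep a freshly restored state intact
            self._rng_state_restored = False
            return
        self._rng.seed(self.seed + epoch)

    def state_dict(self) -> dict:
        return {"online_mixing_rng_state": self._rng.getstate(), "seed": self.seed}

    def load_state_dict(self, sd: dict) -> None:
        self.seed = sd["seed"]
        self._rng.setstate(sd["online_mixing_rng_state"])
        self._rng_state_restored = True

original = _SeededStream()
original.set_epoch(0)
original._rng.random()                    # consume part of the epoch-0 stream
ckpt = original.state_dict()              # checkpoint mid-epoch

resumed = _SeededStream()
resumed.load_state_dict(ckpt)
resumed.set_epoch(0)                      # skipped reseed: restored state survives
assert resumed._rng.random() == original._rng.random()
```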
@@ -516,8 +546,9 @@ def update_sampling_weights(self, model, accelerator, state):
         if accelerator:
             rewards = accelerator.reduce(rewards, reduction="sum")
             count = accelerator.reduce(count, reduction="sum")
+
+        self._update_weights(count, rewards)
         if accelerator and accelerator.is_main_process:
-            self._update_weights(count, rewards)
             self.log_to_file(
                 {
                     "current_sampling_weights": self.sampling_weights.tolist(),
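The last hunk moves `self._update_weights(count, rewards)` out of the `is_main_process` guard: because every rank now draws categories from its own copy of the sampling weights, each rank has to apply the same reduced rewards, otherwise the shared-seed streams would drift apart; only the file logging stays main-process only. A tiny illustration of the invariant this preserves (the weight values below are made up):

```python
import random

def epoch_draws(weights, seed=42, n=6):
    rng = random.Random(seed)
    return [rng.choices(range(len(weights)), weights=weights, k=1)[0] for _ in range(n)]

updated = [0.7, 0.2, 0.1]  # weights after every rank applies the reduced rewards
# Same seed + same weights on every rank -> identical category sequences;
# a rank left on stale weights would generally produce a different sequence.
assert epoch_draws(updated) == epoch_draws(updated)
```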