@@ -245,6 +245,7 @@ class DatasetActor(ForgeActor):
245245 @endpoint
246246 def setup (self ):
247247 self ._tokenizer = get_tokenizer (self .model )
248+ self ._epoch = 0
248249
249250 def gsm8k_transform (sample ):
250251 system_prompt = """
@@ -265,12 +266,12 @@ def gsm8k_transform(sample):
265266 formatted_target = target .split ("#### " )[1 ]
266267 return {"request" : formatted_request , "target" : formatted_target }
267268
268- ds = load_dataset (
269+ self . _base_dataset = load_dataset (
269270 self .path , self .revision , split = self .data_split , streaming = self .streaming
270271 )
271- ds = ds .map (gsm8k_transform )
272- ds = ds .shuffle ()
273- self ._iterator = iter (ds )
272+ self . _base_dataset = self . _base_dataset .map (gsm8k_transform )
273+ self . _base_dataset = self . _base_dataset .shuffle ()
274+ self ._iterator = iter (self . _base_dataset )
274275
275276 @endpoint
276277 async def sample (self ) -> dict [str , str ] | None :
@@ -283,10 +284,18 @@ async def sample(self) -> dict[str, str] | None:
283284 len (sample ["request" ]),
284285 Reduce .MEAN ,
285286 )
287+ record_metric ("dataset/sample/current_epoch" , self ._epoch , Reduce .MAX )
286288
287289 return sample
288290 except StopIteration :
289- return None
291+ # Restart iterator for next epoch with reshuffling
292+ self ._epoch += 1
293+ print (
294+ f"Dataset epoch { self ._epoch - 1 } completed. Starting epoch { self ._epoch } "
295+ )
296+ self ._base_dataset .set_epoch (self ._epoch )
297+ self ._iterator = iter (self ._base_dataset )
298+ return next (self ._iterator )
290299
291300 @endpoint
292301 async def pad_token (self ):
0 commit comments