Skip to content

Commit 4af5ccb

Browse files
casteryh and DNXie
authored and committed
Fix: Enable multi-epoch training by restarting dataset iterator (meta-pytorch#519)
1 parent 6087b09 commit 4af5ccb

File tree

1 file changed

+14
-5
lines changed

1 file changed

+14
-5
lines changed

apps/grpo/main.py

Lines changed: 14 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -245,6 +245,7 @@ class DatasetActor(ForgeActor):
245245
@endpoint
246246
def setup(self):
247247
self._tokenizer = get_tokenizer(self.model)
248+
self._epoch = 0
248249

249250
def gsm8k_transform(sample):
250251
system_prompt = """
@@ -265,12 +266,12 @@ def gsm8k_transform(sample):
265266
formatted_target = target.split("#### ")[1]
266267
return {"request": formatted_request, "target": formatted_target}
267268

268-
ds = load_dataset(
269+
self._base_dataset = load_dataset(
269270
self.path, self.revision, split=self.data_split, streaming=self.streaming
270271
)
271-
ds = ds.map(gsm8k_transform)
272-
ds = ds.shuffle()
273-
self._iterator = iter(ds)
272+
self._base_dataset = self._base_dataset.map(gsm8k_transform)
273+
self._base_dataset = self._base_dataset.shuffle()
274+
self._iterator = iter(self._base_dataset)
274275

275276
@endpoint
276277
async def sample(self) -> dict[str, str] | None:
@@ -283,10 +284,18 @@ async def sample(self) -> dict[str, str] | None:
283284
len(sample["request"]),
284285
Reduce.MEAN,
285286
)
287+
record_metric("dataset/sample/current_epoch", self._epoch, Reduce.MAX)
286288

287289
return sample
288290
except StopIteration:
289-
return None
291+
# Restart iterator for next epoch with reshuffling
292+
self._epoch += 1
293+
print(
294+
f"Dataset epoch {self._epoch - 1} completed. Starting epoch {self._epoch}"
295+
)
296+
self._base_dataset.set_epoch(self._epoch)
297+
self._iterator = iter(self._base_dataset)
298+
return next(self._iterator)
290299

291300
@endpoint
292301
async def pad_token(self):

0 commit comments

Comments
 (0)