From 287dcd323745d247dbc8fdc3b21383f57980f51e Mon Sep 17 00:00:00 2001 From: Danielle Pintz Date: Mon, 1 Dec 2025 07:51:29 -0800 Subject: [PATCH 1/2] upd --- apps/grpo/main.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index f0e06ebb8..fc08d9b5a 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -267,6 +267,7 @@ class DatasetActor(ForgeActor): data_split: str = "train" streaming: bool = True model: str = "Qwen/Qwen3-1.7B" + seed: int = 36 @endpoint async def setup(self): @@ -301,7 +302,8 @@ def gsm8k_transform(sample): self.path, self.revision, split=self.data_split, streaming=self.streaming ) self._base_dataset = self._base_dataset.map(gsm8k_transform) - self._base_dataset = self._base_dataset.shuffle() + self._base_dataset = self._base_dataset.shuffle(seed=self.seed) + self._base_dataset.set_epoch(self._epoch) # Set initial epoch for determinism self._iterator = iter(self._base_dataset) @endpoint @@ -329,7 +331,9 @@ async def sample(self) -> dict[str, str] | None: print( f"Dataset epoch {self._epoch - 1} completed. Starting epoch {self._epoch}" ) - self._base_dataset.set_epoch(self._epoch) + self._base_dataset.set_epoch( + self._epoch + ) # Set epoch for deterministic iteration self._iterator = iter(self._base_dataset) return next(self._iterator) From cfb0ebdf8f242e766534baca956d77426f9d77fa Mon Sep 17 00:00:00 2001 From: Danielle Pintz Date: Mon, 1 Dec 2025 07:54:35 -0800 Subject: [PATCH 2/2] upd --- apps/grpo/main.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index fc08d9b5a..3e336add0 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -303,7 +303,7 @@ def gsm8k_transform(sample): ) self._base_dataset = self._base_dataset.map(gsm8k_transform) self._base_dataset = self._base_dataset.shuffle(seed=self.seed) - self._base_dataset.set_epoch(self._epoch) # Set initial epoch for determinism + self._base_dataset.set_epoch(self._epoch) # for determinism self._iterator = iter(self._base_dataset) @endpoint @@ -331,9 +331,7 @@ async def sample(self) -> dict[str, str] | None: print( f"Dataset epoch {self._epoch - 1} completed. Starting epoch {self._epoch}" ) - self._base_dataset.set_epoch( - self._epoch - ) # Set epoch for deterministic iteration + self._base_dataset.set_epoch(self._epoch) # for determinism self._iterator = iter(self._base_dataset) return next(self._iterator)