diff --git a/apps/grpo/main.py b/apps/grpo/main.py
index 1dbef0b76..d5208c4f8 100644
--- a/apps/grpo/main.py
+++ b/apps/grpo/main.py
@@ -322,7 +322,7 @@ async def main(cfg: DictConfig):
         DatasetActor.options(**cfg.actors.dataset).as_actor(**cfg.dataset),
         Policy.options(**cfg.services.policy).as_service(**cfg.policy),
         RLTrainer.options(**cfg.actors.trainer).as_actor(
-            **cfg.trainer, loss=simple_grpo_loss
+            **cfg.trainer, loss=simple_grpo_loss, step=cfg.trainer.checkpoint.load_step
         ),
         ReplayBuffer.options(**cfg.actors.replay_buffer).as_actor(
             **cfg.replay_buffer, collate=collate
@@ -353,6 +353,14 @@ async def main(cfg: DictConfig):
     )
     print("Torchstore successfully initialized with local rank strategy")

+    start_version = max(cfg.trainer.checkpoint.load_step, 0)
+    if start_version > 0:
+        # Ensure the trainer's loaded checkpoint is pushed to torchstore at `start_version`
+        await trainer.push_weights.call(start_version)
+
+        # Warm the policy to that exact version so new rollouts carry generator_version == start_version
+        await policy.update_weights.fanout(start_version)
+
     # ---- Core RL loops ---- #
     async def continuous_rollouts():
         rollout_count = 0
@@ -420,7 +428,7 @@ async def continuous_rollouts():
         t.stop()

     async def continuous_training():
-        training_step = 0
+        training_step = start_version
         restart_tracer = True  # Flag to control when to restart tracer

         while max_steps == -1 or training_step < max_steps:
diff --git a/apps/grpo/qwen3_1_7b.yaml b/apps/grpo/qwen3_1_7b.yaml
index 14e4871cf..748225510 100644
--- a/apps/grpo/qwen3_1_7b.yaml
+++ b/apps/grpo/qwen3_1_7b.yaml
@@ -74,8 +74,10 @@ trainer:
     disable_loss_parallel: true
   checkpoint:
     enable: true
-    initial_load_path: hf://${model}
-    initial_load_in_hf: true
+    folder: ./checkpoint  # Directory to save or resume checkpoints (default: ./checkpoints)
+    load_step: -1  # Step to load from; cannot be an HF checkpoint; -1 means load from initial_load_path (default: -1)
+    initial_load_path: hf://${model}  # Optional: path or HF identifier to load model weights initially; ignored if `folder` exists
+    initial_load_in_hf: true  # If true, interpret initial_load_path as a HuggingFace model repo
     last_save_in_hf: true
     interval: 500
     async_mode: "disabled"
diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer.py
index dd85b3c82..d730b5f3e 100644
--- a/src/forge/actors/trainer.py
+++ b/src/forge/actors/trainer.py
@@ -136,6 +136,7 @@ class RLTrainer(ForgeActor):
         TORCHSTORE_USE_RDMA.get_value() == 0
     )  # torchstore currently only accepts 0 or 1
     dcp_path: str = "forge_dcp_tmp"
+    step: int = 1

     def __post_init__(self):
         """Initializes config types and env variables.
@@ -159,8 +160,9 @@ def __post_init__(self):
                 raise TypeError(
                     f"{f.name} should be a {f.type} type or a dict like object"
                 )
-
-        self.step = 1  # fragile contract.
+        self.step = max(
+            self.step, 1
+        )  # start from 1 if not loading from a saved checkpoint
         self.num_training_steps = self.training.steps
         self.gradient_accumulation_steps = 1
         self.rank = current_rank().rank
@@ -186,12 +188,7 @@ def __post_init__(self):
     async def setup(self):
         # TODO: update ForgeEngine to not use ForgeJobConfig
         engine_config = {f.name: getattr(self, f.name) for f in fields(self)}
-        for key in {
-            "loss",
-            "state_dict_key",
-            "use_dcp",
-            "dcp_path",
-        }:
+        for key in {"loss", "state_dict_key", "use_dcp", "dcp_path", "step"}:
             engine_config.pop(key)  # Not part of job config
         self.engine = ForgeEngine(ForgeJobConfig(**engine_config))
         self.engine.checkpointer.load(step=self.step)
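
Usage sketch (not part of the patch itself): resuming a run from a locally saved checkpoint might be configured as below. The step value 500 is illustrative and assumes a checkpoint for that step was previously written to `folder` at the configured save `interval`.

trainer:
  checkpoint:
    enable: true
    folder: ./checkpoint              # must already contain a step-500 checkpoint
    load_step: 500                    # resume from this saved step instead of initial_load_path
    initial_load_path: hf://${model}  # ignored here because `folder` exists
    initial_load_in_hf: true
    last_save_in_hf: true
    interval: 500
    async_mode: "disabled"

With such a configuration, the trainer loads step 500 via `checkpointer.load(step=self.step)`, pushes those weights to torchstore as `start_version` 500, the policy is warmed to that version, and `continuous_training` resumes counting from `training_step = 500`.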