apps/grpo/main.py (12 changes: 10 additions & 2 deletions)
@@ -322,7 +322,7 @@ async def main(cfg: DictConfig):
         DatasetActor.options(**cfg.actors.dataset).as_actor(**cfg.dataset),
         Policy.options(**cfg.services.policy).as_service(**cfg.policy),
         RLTrainer.options(**cfg.actors.trainer).as_actor(
-            **cfg.trainer, loss=simple_grpo_loss
+            **cfg.trainer, loss=simple_grpo_loss, step=cfg.trainer.checkpoint.load_step
         ),
         ReplayBuffer.options(**cfg.actors.replay_buffer).as_actor(
             **cfg.replay_buffer, collate=collate
@@ -353,6 +353,14 @@ async def main(cfg: DictConfig):
     )
     print("Torchstore successfully initialized with local rank strategy")
 
+    start_version = max(cfg.trainer.checkpoint.load_step, 0)
+    if start_version > 0:
+        # Ensure the trainer’s loaded checkpoint is pushed to torchstore at `start_version`
+        await trainer.push_weights.call(start_version)
+
+        # Warm the policy to that exact version so new rollouts carry generator_version == start_version
+        await policy.update_weights.fanout(start_version)
+
     # ---- Core RL loops ---- #
     async def continuous_rollouts():
         rollout_count = 0
@@ -420,7 +428,7 @@ async def continuous_rollouts():
             t.stop()
 
     async def continuous_training():
-        training_step = 0
+        training_step = start_version
         restart_tracer = True  # Flag to control when to restart tracer
 
        while max_steps == -1 or training_step < max_steps:
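
Read together, these hunks seed a resumed run from the loaded checkpoint instead of step 0: the restored weights are published to torchstore at `start_version`, the policy is warmed to that same version, and the training loop counter continues from it. A minimal sketch of that bootstrap, assuming the `cfg`, `trainer`, and `policy` objects already set up earlier in `main` (illustrative, not additional code in this PR):

    start_version = max(cfg.trainer.checkpoint.load_step, 0)   # -1 (fresh run) clamps to 0
    if start_version > 0:
        await trainer.push_weights.call(start_version)          # publish restored weights
        await policy.update_weights.fanout(start_version)       # new rollouts report this generator_version
    training_step = start_version                               # continuous_training counts on from here
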
apps/grpo/qwen3_1_7b.yaml (6 changes: 4 additions & 2 deletions)
@@ -74,8 +74,10 @@ trainer:
   disable_loss_parallel: true
   checkpoint:
     enable: true
-    initial_load_path: hf://${model}
-    initial_load_in_hf: true
+    folder: ./checkpoint  # Directory to save or resume checkpoints (default: ./checkpoints)
+    load_step: -1  # Step to load from; cannot be an HF checkpoint; -1 means load from initial_load_path (default: -1)
+    initial_load_path: hf://${model}  # Optional: path or HF identifier to load model weights initially; ignored if `folder` exists
+    initial_load_in_hf: true  # If true, interpret initial_load_path as a HuggingFace model repo
     last_save_in_hf: true
     interval: 500
     async_mode: "disabled"
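
For reference, resuming from a previously saved step (rather than the initial HF load) means pointing `load_step` at a checkpoint that already exists under `folder`; the step value below is purely illustrative:

    trainer:
      checkpoint:
        enable: true
        folder: ./checkpoint
        load_step: 500   # illustrative: resume from the checkpoint saved at step 500 under `folder`
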
src/forge/actors/trainer.py (13 changes: 5 additions & 8 deletions)
@@ -136,6 +136,7 @@ class RLTrainer(ForgeActor):
         TORCHSTORE_USE_RDMA.get_value() == 0
     )  # torchstore currently only accepts 0 or 1
     dcp_path: str = "forge_dcp_tmp"
+    step: int = 1
 
     def __post_init__(self):
         """Initializes config types and env variables.
@@ -159,8 +160,9 @@ def __post_init__(self):
                 raise TypeError(
                     f"{f.name} should be a {f.type} type or a dict like object"
                 )
-
-        self.step = 1  # fragile contract.
+        self.step = max(
+            self.step, 1
+        )  # start from 1 if not loading from a saved checkpoint
         self.num_training_steps = self.training.steps
         self.gradient_accumulation_steps = 1
         self.rank = current_rank().rank
@@ -186,12 +188,7 @@ async def setup(self):
     async def setup(self):
         # TODO: update ForgeEngine to not use ForgeJobConfig
         engine_config = {f.name: getattr(self, f.name) for f in fields(self)}
-        for key in {
-            "loss",
-            "state_dict_key",
-            "use_dcp",
-            "dcp_path",
-        }:
+        for key in {"loss", "state_dict_key", "use_dcp", "dcp_path", "step"}:
             engine_config.pop(key)  # Not part of job config
         self.engine = ForgeEngine(ForgeJobConfig(**engine_config))
         self.engine.checkpointer.load(step=self.step)
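
To make the new plumbing concrete, here is how `load_step` propagates through the two clamps this PR introduces; the loop below is purely illustrative and only traces the arithmetic:

    # Illustrative only: trace how trainer.checkpoint.load_step flows through the new code paths.
    for load_step in (-1, 500):                # -1 is the default (fresh run); 500 is a hypothetical resume step
        trainer_step = max(load_step, 1)       # RLTrainer.__post_init__ clamp -> checkpointer.load(step=trainer_step)
        start_version = max(load_step, 0)      # main.py clamp -> 0 skips the torchstore/policy bootstrap
        print(load_step, trainer_step, start_version)
    # -1  -> trainer_step=1,   start_version=0   (fresh start; weights come from initial_load_path)
    # 500 -> trainer_step=500, start_version=500 (resume; weights pushed and policy warmed at version 500)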