Skip to content

Commit 905a224

Browse files
authored
fix: fix mcore train_iters in grpo (#1383)
Signed-off-by: Yuki Huang <yukih@nvidia.com>
1 parent 7bd853a commit 905a224

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

nemo_rl/algorithms/grpo.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -467,7 +467,10 @@ def setup(
467467

468468
if policy_config.get("megatron_cfg", {}).get("enabled", False):
469469
## NOTE: this is equal to the total number of scheduler steps
470-
total_train_iters = min(grpo_config["max_num_steps"], len(dataloader))
470+
total_train_iters = min(
471+
grpo_config["max_num_steps"],
472+
grpo_config["max_num_epochs"] * len(dataloader),
473+
)
471474
policy_config["megatron_cfg"]["train_iters"] = total_train_iters
472475

473476
policy = Policy(

0 commit comments

Comments
 (0)