Skip to content

Commit 665219d

Browse files
committed
add assert and restore some missing comments
Signed-off-by: Yuki Huang <yukih@nvidia.com>
1 parent 0428ef1 commit 665219d

File tree

2 files changed

+17
-6
lines changed

2 files changed

+17
-6
lines changed

examples/run_grpo.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,13 @@ def main() -> None:
139139
"use_multiple_dataloader is not supported with async GRPO"
140140
)
141141

142+
# Async GDPO is not supported
143+
if config["grpo"]["adv_estimator"]["name"] == "gdpo":
144+
raise NotImplementedError(
145+
"GDPO is not supported for async training, "
146+
"please set grpo.async_grpo.enabled to false in your config."
147+
)
148+
142149
from nemo_rl.algorithms.grpo import async_grpo_train
143150

144151
print("🚀 Running async GRPO training")

nemo_rl/algorithms/grpo.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1046,7 +1046,10 @@ def _create_advantage_estimator(master_config: MasterConfig):
10461046
"""
10471047
grpo_config = master_config["grpo"]
10481048
loss_config = master_config["loss_fn"]
1049+
10491050
# Provide backward-compatible defaults when adv_estimator is not in config.
1051+
# Fall back to top-level grpo.normalize_rewards / grpo.use_leave_one_out_baseline
1052+
# which older configs still use.
10501053
adv_estimator_config = grpo_config.get(
10511054
"adv_estimator",
10521055
{
@@ -1061,6 +1064,10 @@ def _create_advantage_estimator(master_config: MasterConfig):
10611064

10621065
adv_estimator_name = adv_estimator_config["name"]
10631066
if adv_estimator_name == "gdpo":
1067+
assert not _should_use_async_rollouts(master_config), (
1068+
"GDPO is not supported for async rollouts, "
1069+
"please set policy.generation.vllm_cfg.async_engine to false in your config."
1070+
)
10641071
adv_estimator = GDPOAdvantageEstimator(adv_estimator_config, loss_config)
10651072
print(" ✓ Using GDPO advantage estimator (multi-reward)")
10661073
elif adv_estimator_name == "grpo":
@@ -1372,10 +1379,9 @@ def grpo_train(
13721379
val_period = master_config["grpo"]["val_period"]
13731380
colocated_inference = master_config["policy"]["generation"]["colocated"]["enabled"]
13741381

1375-
# Create advantage estimator
1382+
# Initialize advantage estimator
13761383
adv_estimator = _create_advantage_estimator(master_config)
13771384

1378-
13791385
# Run validation at the start if configured
13801386
# TODO: Add validation with kv scales if needed
13811387
if val_at_start and current_step == 0:
@@ -1596,8 +1602,8 @@ def grpo_train(
15961602
# Calculate rewards & advantages
15971603
memory_tracker.snapshot_start_of_stage("Processing rewards", dir())
15981604
print("▶ Processing rewards...,", flush=True)
1599-
# GDPO
16001605
with timer.time("reward_calculation"):
1606+
# Extract rewards from final_batch
16011607
rewards = repeated_batch["total_reward"]
16021608
# Store input_ids in batch so that after dynamic_sampling it stays aligned with
16031609
# the (possibly filtered) batch: select_indices / from_batches / slice all
@@ -1788,8 +1794,6 @@ def grpo_train(
17881794
sample_mask = train_data["sample_mask"]
17891795
mask = token_mask * sample_mask.unsqueeze(-1)
17901796

1791-
1792-
17931797
train_data["advantages"] = adv_estimator.compute_advantage(
17941798
repeated_batch=repeated_batch,
17951799
mask=mask,
@@ -2445,7 +2449,7 @@ def async_grpo_train(
24452449
val_at_end = master_config["grpo"]["val_at_end"]
24462450
colocated_inference = master_config["policy"]["generation"]["colocated"]["enabled"]
24472451

2448-
# Create advantage estimator
2452+
# Initialize advantage estimator
24492453
adv_estimator = _create_advantage_estimator(master_config)
24502454

24512455
assert not colocated_inference, (

0 commit comments

Comments (0)