@@ -1046,7 +1046,10 @@ def _create_advantage_estimator(master_config: MasterConfig):
10461046 """
10471047 grpo_config = master_config ["grpo" ]
10481048 loss_config = master_config ["loss_fn" ]
1049+
10491050 # Provide backward-compatible defaults when adv_estimator is not in config.
1051+ # Fall back to top-level grpo.normalize_rewards / grpo.use_leave_one_out_baseline
1052+ # which older configs still use.
10501053 adv_estimator_config = grpo_config .get (
10511054 "adv_estimator" ,
10521055 {
@@ -1061,6 +1064,10 @@ def _create_advantage_estimator(master_config: MasterConfig):
10611064
10621065 adv_estimator_name = adv_estimator_config ["name" ]
10631066 if adv_estimator_name == "gdpo" :
1067+ assert not _should_use_async_rollouts (master_config ), (
1068+ "GDPO is not supported for async rollouts, "
1069+ "please set policy.generation.vllm_cfg.async_engine to false in your config."
1070+ )
10641071 adv_estimator = GDPOAdvantageEstimator (adv_estimator_config , loss_config )
10651072 print (" ✓ Using GDPO advantage estimator (multi-reward)" )
10661073 elif adv_estimator_name == "grpo" :
@@ -1372,10 +1379,9 @@ def grpo_train(
13721379 val_period = master_config ["grpo" ]["val_period" ]
13731380 colocated_inference = master_config ["policy" ]["generation" ]["colocated" ]["enabled" ]
13741381
1375- # Create advantage estimator
1382+ # Initialize advantage estimator
13761383 adv_estimator = _create_advantage_estimator (master_config )
13771384
1378-
13791385 # Run validation at the start if configured
13801386 # TODO: Add validation with kv scales if needed
13811387 if val_at_start and current_step == 0 :
@@ -1596,8 +1602,8 @@ def grpo_train(
15961602 # Calculate rewards & advantages
15971603 memory_tracker .snapshot_start_of_stage ("Processing rewards" , dir ())
15981604 print ("▶ Processing rewards...," , flush = True )
1599- # GDPO
16001605 with timer .time ("reward_calculation" ):
1606+ # Extract rewards from final_batch
16011607 rewards = repeated_batch ["total_reward" ]
16021608 # Store input_ids in batch so that after dynamic_sampling it stays aligned with
16031609 # the (possibly filtered) batch: select_indices / from_batches / slice all
@@ -1788,8 +1794,6 @@ def grpo_train(
17881794 sample_mask = train_data ["sample_mask" ]
17891795 mask = token_mask * sample_mask .unsqueeze (- 1 )
17901796
1791-
1792-
17931797 train_data ["advantages" ] = adv_estimator .compute_advantage (
17941798 repeated_batch = repeated_batch ,
17951799 mask = mask ,
@@ -2445,7 +2449,7 @@ def async_grpo_train(
24452449 val_at_end = master_config ["grpo" ]["val_at_end" ]
24462450 colocated_inference = master_config ["policy" ]["generation" ]["colocated" ]["enabled" ]
24472451
2448- # Create advantage estimator
2452+ # Initialize advantage estimator
24492453 adv_estimator = _create_advantage_estimator (master_config )
24502454
24512455 assert not colocated_inference , (
0 commit comments