1313# limitations under the License.
1414import gc
1515import os
16+ import re
1617import time
1718import warnings
1819from concurrent .futures import ThreadPoolExecutor
2930
3031from nemo_rl .algorithms .advantage_estimator import (
3132 GRPOAdvantageEstimator ,
33+ GDPOAdvantageEstimator ,
3234 ReinforcePlusPlusAdvantageEstimator ,
3335)
3436from nemo_rl .algorithms .loss import (
4648 log_generation_metrics_to_wandb ,
4749 print_performance_metrics ,
4850 set_seed ,
51+ get_gdpo_reward_component_keys
4952)
5053from nemo_rl .data import DataConfig
5154from nemo_rl .data .collate_fn import rl_collate_fn
@@ -121,9 +124,9 @@ class AsyncGRPOConfig(TypedDict):
121124
122125
123126class AdvEstimatorConfig (TypedDict ):
124- """Configuration for advantage estimator (GRPO or Reinforce++)."""
127+ """Configuration for advantage estimator (GRPO, GDPO, or Reinforce++)."""
125128
126- name : str # "grpo" or "reinforce_plus_plus"
129+ name : str # "grpo", "gdpo", or "reinforce_plus_plus"
127130 # GRPO specific
128131 normalize_rewards : NotRequired [bool ]
129132 use_leave_one_out_baseline : NotRequired [bool ]
@@ -966,11 +969,16 @@ def scale_rewards(
966969 )
967970
968971 # Clamp and scale
969- rewards = torch .clamp (rewards , min = source_min , max = source_max )
970- scaled_rewards = target_min + (rewards - source_min ) / (
971- source_max - source_min
972- ) * (target_max - target_min )
972+ def _scale (reward_tensor : torch .Tensor ) -> torch .Tensor :
973+ r = torch .clamp (reward_tensor , min = source_min , max = source_max )
974+ return target_min + (r - source_min ) / (
975+ source_max - source_min
976+ ) * (target_max - target_min )
977+
978+ scaled_rewards = _scale (rewards )
973979 repeated_batch ["total_reward" ] = scaled_rewards
980+ for key in get_gdpo_reward_component_keys (repeated_batch ):
981+ repeated_batch [key ] = _scale (repeated_batch [key ])
974982
975983 return repeated_batch
976984
@@ -1031,7 +1039,7 @@ def _create_advantage_estimator(master_config: MasterConfig):
10311039 master_config: The master configuration dictionary.
10321040
10331041 Returns:
1034- An advantage estimator instance (GRPOAdvantageEstimator or ReinforcePlusPlusAdvantageEstimator ).
1042+ An advantage estimator instance (GRPO, GDPO, or ReinforcePlusPlus ).
10351043
10361044 Raises:
10371045 ValueError: If the advantage estimator name is not recognized.
@@ -1055,7 +1063,14 @@ def _create_advantage_estimator(master_config: MasterConfig):
10551063 )
10561064
10571065 adv_estimator_name = adv_estimator_config ["name" ]
1058- if adv_estimator_name == "grpo" :
1066+ if adv_estimator_name == "gdpo" :
1067+ assert not _should_use_async_rollouts (master_config ), (
1068+ "GDPO is not supported for async rollouts, "
1069+ "please set policy.generation.vllm_cfg.async_engine to false in your config."
1070+ )
1071+ adv_estimator = GDPOAdvantageEstimator (adv_estimator_config , loss_config )
1072+ print (" ✓ Using GDPO advantage estimator (multi-reward)" )
1073+ elif adv_estimator_name == "grpo" :
10591074 adv_estimator = GRPOAdvantageEstimator (adv_estimator_config , loss_config )
10601075 print (" ✓ Using GRPO advantage estimator" )
10611076 elif adv_estimator_name == "reinforce_plus_plus" :
@@ -1590,6 +1605,10 @@ def grpo_train(
15901605 with timer .time ("reward_calculation" ):
15911606 # Extract rewards from final_batch
15921607 rewards = repeated_batch ["total_reward" ]
1608+ # Store input_ids in batch so that after dynamic_sampling it stays aligned with
1609+ # the (possibly filtered) batch: select_indices / from_batches / slice all
1610+ # apply to this key, so per-reward baselines use the same prompts as reward components.
1611+ repeated_batch ["_input_ids_for_baseline" ] = input_ids
15931612
15941613 print ("▶ Computing advantages..." , flush = True )
15951614 if master_config ["grpo" ].get ("calculate_advantages_on_gpu" ):
@@ -1644,10 +1663,10 @@ def grpo_train(
16441663 # If the current batch is not enough to fill the buffer during dynamic sampling, we update the cache and process the next batch.
16451664 if not is_batch_complete :
16461665 continue
1666+
16471667 gen_step_metrics = {}
16481668 if hasattr (policy_generation , "get_step_metrics" ):
16491669 gen_step_metrics = policy_generation .get_step_metrics ()
1650- advantages = (rewards - baseline ).unsqueeze (- 1 )
16511670
16521671 # Save baseline for logging (before deletion)
16531672 baseline_for_log = baseline .clone ()
@@ -1778,6 +1797,7 @@ def grpo_train(
17781797 train_data ["advantages" ] = adv_estimator .compute_advantage (
17791798 prompt_ids = prompt_ids_for_adv ,
17801799 rewards = rewards ,
1800+ repeated_batch = repeated_batch ,
17811801 mask = mask ,
17821802 logprobs_policy = train_data ["prev_logprobs" ],
17831803 logprobs_reference = train_data .get ("reference_policy_logprobs" ),
@@ -2724,6 +2744,8 @@ def async_grpo_train(
27242744 del prompt_batched_flat
27252745
27262746 rewards = repeated_batch ["total_reward" ]
2747+ # All estimators read _input_ids_for_baseline from repeated_batch
2748+ repeated_batch ["_input_ids_for_baseline" ] = prompt_ids_for_adv
27272749
27282750 print (
27292751 f" 📊 Rewards stats: min={ rewards .min ():.4f} , max={ rewards .max ():.4f} , mean={ rewards .mean ():.4f} , std={ rewards .std ():.4f} "
@@ -2809,6 +2831,7 @@ def async_grpo_train(
28092831 train_data ["advantages" ] = adv_estimator .compute_advantage (
28102832 prompt_ids = prompt_ids_for_adv ,
28112833 rewards = rewards ,
2834+ repeated_batch = repeated_batch ,
28122835 mask = mask ,
28132836 logprobs_policy = train_data ["prev_logprobs" ],
28142837 logprobs_reference = train_data .get ("reference_policy_logprobs" ),
0 commit comments