Skip to content

Commit 30e80b7

Browse files
author
Shih-Yang Liu
committed
Move from my personal branch to here
Signed-off-by: Shih-Yang Liu <nbasyl>
1 parent c4f8e1c commit 30e80b7

File tree

17 files changed

+774
-30
lines changed

17 files changed

+774
-30
lines changed

examples/configs/gdpo_math_1B.yaml

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# GDPO: inherits from grpo_math_1B.yaml and overrides only what differs.
defaults: grpo_math_1B.yaml

grpo:
  adv_estimator:
    name: "gdpo"
    normalize_rewards: true
    use_leave_one_out_baseline: false

checkpointing:
  checkpoint_dir: "results/gdpo"

policy:
  model_name: "Qwen/Qwen2.5-1.5B-Instruct"
  logprob_batch_size: 4
  max_total_sequence_length: 1024
  megatron_cfg:
    optimizer:
      weight_decay: 0.0
    scheduler:
      lr_decay_style: "cosine"
      lr_warmup_iters: 10

# GDPO uses a single flat data config (GSM8K + math_gdpo_data_processor); replace parent's train/validation/default.
data:
  _override_: true

  max_input_seq_length: ${policy.max_total_sequence_length}
  shuffle: true
  num_workers: 1

  use_multiple_dataloader: false

  train:
    dataset_name: "gsm8k"
    split: train
  validation:
    dataset_name: "gsm8k"
    split: test

  default:
    prompt_file: null
    system_prompt_file: "examples/prompts/gsm8k.txt"
    processor: "math_gdpo_data_processor"
    env_name: "math_multi_reward"

env:
  math_multi_reward:
    num_workers: 8
    math_verify_impl: "hf_math_verify"

logger:
  wandb_enabled: true
  wandb:
    project: "gdpo-dev"
    name: "gdpo-dev-logger"
  swanlab:
    project: "gdpo-dev"
    name: "gdpo-dev-logger"
  mlflow:
    experiment_name: "gdpo-dev"
    run_name: "gdpo-dev-logger"

examples/prompts/gsm8k.txt

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
You are a helpful AI assistant.
2+
3+
For every request, you should carefully think through the math problem step by step, then provide the final answer in integer format.
4+
5+
Steps for Each Request:
6+
1. Think: Provide detailed, step-by-step reasoning, calculations, or derivations.
7+
2. Produce Final Answer: After step-by-step reasoning, output the final answer in integer format.
8+
9+
Output Format:
10+
<think>Your thoughts and reasoning</think>
11+
<answer>Final answer in integer format</answer>
12+
13+
Important Notes:
14+
1. You must include your reasoning steps inside <think>.
15+
2. You must always output the Final Answer within <answer> after the reasoning steps are done.
16+
3. You should consistently work through the solution step by step before giving the final answer.
17+
4. The final answer can only be an integer.

examples/run_grpo.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,13 @@ def main() -> None:
139139
"use_multiple_dataloader is not supported with async GRPO"
140140
)
141141

142+
# Async GDPO is not supported
143+
if config["grpo"]["adv_estimator"]["name"] == "gdpo":
144+
raise NotImplementedError(
145+
"GDPO is not supported for async training, "
146+
"please set grpo.async_grpo.enabled to false in your config."
147+
)
148+
142149
from nemo_rl.algorithms.grpo import async_grpo_train
143150

144151
print("🚀 Running async GRPO training")

nemo_rl/algorithms/advantage_estimator.py

Lines changed: 93 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
1717
This module provides different advantage estimation strategies:
1818
- GRPOAdvantageEstimator: Standard GRPO advantage with leave-one-out baseline
19+
- GDPOAdvantageEstimator: Multi-reward GDPO (per-component baselines, sum then normalize)
1920
- ReinforcePlusPlusAdvantageEstimator: Reinforce++ with optional baseline subtraction (minus_baseline) and KL penalty in reward
2021
Reference papers:
2122
- ProRLv2: https://developer.nvidia.com/blog/scaling-llm-reinforcement-learning-with-prolonged-training-using-prorl-v2/
@@ -24,8 +25,7 @@
2425

2526
import torch
2627

27-
from nemo_rl.algorithms.utils import calculate_baseline_and_std_per_prompt, calculate_kl
28-
28+
from nemo_rl.algorithms.utils import calculate_baseline_and_std_per_prompt, calculate_kl, get_gdpo_reward_component_keys
2929

3030
class GRPOAdvantageEstimator:
3131
"""GRPO-style advantage estimator with leave-one-out baseline.
@@ -37,14 +37,21 @@ def __init__(self, estimator_config: dict, loss_config: dict):
3737
self.use_leave_one_out_baseline = estimator_config["use_leave_one_out_baseline"]
3838
self.normalize_rewards = estimator_config["normalize_rewards"]
3939

40-
def compute_advantage(self, prompt_ids, rewards, mask, **kwargs):
40+
def compute_advantage(
41+
self,
42+
prompt_ids,
43+
rewards,
44+
repeated_batch,
45+
mask,
46+
**kwargs,
47+
):
4148
"""Compute GRPO advantages.
4249
4350
Args:
4451
prompt_ids: Tensor of shape [batch_size] identifying which prompt each sample belongs to.
4552
rewards: Tensor of shape [batch_size] containing reward for each sample.
53+
repeated_batch: Batch (unused; for interface consistency).
4654
mask: Response token mask of shape [batch_size, seq_len], 1 for valid response tokens, 0 for padding.
47-
Used only for expanding advantages to token-level shape.
4855
**kwargs: Additional arguments (unused).
4956
5057
Returns:
@@ -69,6 +76,83 @@ def compute_advantage(self, prompt_ids, rewards, mask, **kwargs):
6976
return advantages.expand(mask.shape)
7077

7178

79+
class GDPOAdvantageEstimator:
    """GDPO-style advantage estimator (multi-reward).

    For every reward component key in the batch (``reward1``, ``reward2``, ...),
    a per-prompt baseline (optionally leave-one-out) is computed independently;
    the mean-subtracted component advantages are summed and the combined
    advantage is then normalized to zero mean and, when its std is non-zero,
    unit standard deviation.
    """

    def __init__(self, estimator_config: dict, loss_config: dict):
        # Same config knobs as GRPO; loss_config is accepted for interface parity.
        self.use_leave_one_out_baseline = estimator_config["use_leave_one_out_baseline"]
        self.normalize_rewards = estimator_config["normalize_rewards"]

    def compute_advantage(
        self,
        prompt_ids,
        rewards,
        repeated_batch,
        mask,
        **kwargs,
    ):
        """Compute GDPO advantages.

        Args:
            prompt_ids: Unused; for interface consistency.
            rewards: Unused; for interface consistency.
            repeated_batch: Batch containing ``_input_ids_for_baseline`` and
                ``reward1``, ``reward2``, ... keys.
            mask: Response token mask of shape [batch_size, seq_len], 1 for
                valid response tokens, 0 for padding. Used only for expanding
                advantages to token-level shape.
            **kwargs: Additional arguments (unused).

        Returns:
            Advantages tensor of shape [batch_size, seq_len].

        Raises:
            ValueError: If the batch carries fewer than two reward components.
        """
        component_keys = get_gdpo_reward_component_keys(repeated_batch)
        if len(component_keys) < 2:
            raise ValueError(
                f"GDPO requires multiple reward components (reward1, reward2, ...). "
                f"This batch has {len(component_keys)} component(s). "
                "Switch to GRPO by setting grpo.adv_estimator.name to 'grpo' in your config."
            )

        baseline_input_ids = repeated_batch["_input_ids_for_baseline"]
        # Every sample participates in the per-prompt baseline.
        valid = torch.ones_like(repeated_batch[component_keys[0]])
        assert baseline_input_ids.shape[0] == valid.shape[0], (
            "_input_ids_for_baseline must match reward batch size after dynamic_sampling; "
            f"got {baseline_input_ids.shape[0]} vs {valid.shape[0]}"
        )

        epsilon = 1e-6
        combined = None
        for key in component_keys:
            component_rewards = repeated_batch[key]
            baseline, std = calculate_baseline_and_std_per_prompt(
                baseline_input_ids,
                component_rewards,
                valid,
                leave_one_out_baseline=self.use_leave_one_out_baseline,
            )
            component_adv = (component_rewards - baseline).unsqueeze(-1)
            if self.normalize_rewards:
                # Scale only where the per-prompt std is non-zero; zero-std
                # groups keep their raw mean-subtracted advantage.
                std_col = std.unsqueeze(-1)
                component_adv = torch.where(
                    std_col > 0,
                    component_adv / (std_col + epsilon),
                    component_adv,
                )
            combined = component_adv if combined is None else combined + component_adv

        # Normalize the summed advantage to zero mean and unit std (std permitting).
        centered = combined - combined.mean()
        combined_std = combined.std()
        advantages = centered / combined_std if combined_std > 0 else centered

        return advantages.expand(mask.shape)
154+
155+
72156
class ReinforcePlusPlusAdvantageEstimator:
73157
"""Reinforce++ advantage estimator with optional baseline subtraction and KL penalty in reward.
74158
@@ -87,6 +171,7 @@ def compute_advantage(
87171
self,
88172
prompt_ids,
89173
rewards,
174+
repeated_batch,
90175
mask,
91176
logprobs_policy=None,
92177
logprobs_reference=None,
@@ -95,13 +180,12 @@ def compute_advantage(
95180
"""Compute Reinforce++ advantages with optional KL penalty.
96181
97182
Args:
98-
prompt_ids: Tensor of shape [batch_size] identifying which prompt each sample belongs to.
183+
prompt_ids: Tensor identifying which prompt each sample belongs to (for baseline).
99184
rewards: Tensor of shape [batch_size] containing reward for each sample.
185+
repeated_batch: Unused; for interface consistency.
100186
mask: Response token mask of shape [batch_size, seq_len], 1 for valid response tokens, 0 for padding.
101-
Used for: (1) expanding advantages to token-level shape, (2) global normalization
102-
that only considers valid tokens.
103-
logprobs_policy: Policy log probabilities of shape [batch_size, seq_len], required if use_kl_in_reward.
104-
logprobs_reference: Reference policy log probabilities of shape [batch_size, seq_len], required if use_kl_in_reward.
187+
logprobs_policy: Policy log probabilities, required if use_kl_in_reward.
188+
logprobs_reference: Reference policy log probabilities, required if use_kl_in_reward.
105189
**kwargs: Additional arguments (unused).
106190
107191
Returns:

nemo_rl/algorithms/grpo.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
import gc
1515
import os
16+
import re
1617
import time
1718
import warnings
1819
from concurrent.futures import ThreadPoolExecutor
@@ -29,6 +30,7 @@
2930

3031
from nemo_rl.algorithms.advantage_estimator import (
3132
GRPOAdvantageEstimator,
33+
GDPOAdvantageEstimator,
3234
ReinforcePlusPlusAdvantageEstimator,
3335
)
3436
from nemo_rl.algorithms.loss import (
@@ -46,6 +48,7 @@
4648
log_generation_metrics_to_wandb,
4749
print_performance_metrics,
4850
set_seed,
51+
get_gdpo_reward_component_keys
4952
)
5053
from nemo_rl.data import DataConfig
5154
from nemo_rl.data.collate_fn import rl_collate_fn
@@ -121,9 +124,9 @@ class AsyncGRPOConfig(TypedDict):
121124

122125

123126
class AdvEstimatorConfig(TypedDict):
124-
"""Configuration for advantage estimator (GRPO or Reinforce++)."""
127+
"""Configuration for advantage estimator (GRPO, GDPO, or Reinforce++)."""
125128

126-
name: str # "grpo" or "reinforce_plus_plus"
129+
name: str # "grpo", "gdpo", or "reinforce_plus_plus"
127130
# GRPO specific
128131
normalize_rewards: NotRequired[bool]
129132
use_leave_one_out_baseline: NotRequired[bool]
@@ -966,11 +969,16 @@ def scale_rewards(
966969
)
967970

968971
# Clamp and scale
969-
rewards = torch.clamp(rewards, min=source_min, max=source_max)
970-
scaled_rewards = target_min + (rewards - source_min) / (
971-
source_max - source_min
972-
) * (target_max - target_min)
972+
def _scale(reward_tensor: torch.Tensor) -> torch.Tensor:
973+
r = torch.clamp(reward_tensor, min=source_min, max=source_max)
974+
return target_min + (r - source_min) / (
975+
source_max - source_min
976+
) * (target_max - target_min)
977+
978+
scaled_rewards = _scale(rewards)
973979
repeated_batch["total_reward"] = scaled_rewards
980+
for key in get_gdpo_reward_component_keys(repeated_batch):
981+
repeated_batch[key] = _scale(repeated_batch[key])
974982

975983
return repeated_batch
976984

@@ -1031,7 +1039,7 @@ def _create_advantage_estimator(master_config: MasterConfig):
10311039
master_config: The master configuration dictionary.
10321040
10331041
Returns:
1034-
An advantage estimator instance (GRPOAdvantageEstimator or ReinforcePlusPlusAdvantageEstimator).
1042+
An advantage estimator instance (GRPO, GDPO, or ReinforcePlusPlus).
10351043
10361044
Raises:
10371045
ValueError: If the advantage estimator name is not recognized.
@@ -1055,7 +1063,14 @@ def _create_advantage_estimator(master_config: MasterConfig):
10551063
)
10561064

10571065
adv_estimator_name = adv_estimator_config["name"]
1058-
if adv_estimator_name == "grpo":
1066+
if adv_estimator_name == "gdpo":
1067+
assert not _should_use_async_rollouts(master_config), (
1068+
"GDPO is not supported for async rollouts, "
1069+
"please set policy.generation.vllm_cfg.async_engine to false in your config."
1070+
)
1071+
adv_estimator = GDPOAdvantageEstimator(adv_estimator_config, loss_config)
1072+
print(" ✓ Using GDPO advantage estimator (multi-reward)")
1073+
elif adv_estimator_name == "grpo":
10591074
adv_estimator = GRPOAdvantageEstimator(adv_estimator_config, loss_config)
10601075
print(" ✓ Using GRPO advantage estimator")
10611076
elif adv_estimator_name == "reinforce_plus_plus":
@@ -1590,6 +1605,10 @@ def grpo_train(
15901605
with timer.time("reward_calculation"):
15911606
# Extract rewards from final_batch
15921607
rewards = repeated_batch["total_reward"]
1608+
# Store input_ids in batch so that after dynamic_sampling it stays aligned with
1609+
# the (possibly filtered) batch: select_indices / from_batches / slice all
1610+
# apply to this key, so per-reward baselines use the same prompts as reward components.
1611+
repeated_batch["_input_ids_for_baseline"] = input_ids
15931612

15941613
print("▶ Computing advantages...", flush=True)
15951614
if master_config["grpo"].get("calculate_advantages_on_gpu"):
@@ -1644,10 +1663,10 @@ def grpo_train(
16441663
# If the current batch is not enough to fill the buffer during dynamic sampling, we update the cache and process the next batch.
16451664
if not is_batch_complete:
16461665
continue
1666+
16471667
gen_step_metrics = {}
16481668
if hasattr(policy_generation, "get_step_metrics"):
16491669
gen_step_metrics = policy_generation.get_step_metrics()
1650-
advantages = (rewards - baseline).unsqueeze(-1)
16511670

16521671
# Save baseline for logging (before deletion)
16531672
baseline_for_log = baseline.clone()
@@ -1778,6 +1797,7 @@ def grpo_train(
17781797
train_data["advantages"] = adv_estimator.compute_advantage(
17791798
prompt_ids=prompt_ids_for_adv,
17801799
rewards=rewards,
1800+
repeated_batch=repeated_batch,
17811801
mask=mask,
17821802
logprobs_policy=train_data["prev_logprobs"],
17831803
logprobs_reference=train_data.get("reference_policy_logprobs"),
@@ -2724,6 +2744,8 @@ def async_grpo_train(
27242744
del prompt_batched_flat
27252745

27262746
rewards = repeated_batch["total_reward"]
2747+
# All estimators read _input_ids_for_baseline from repeated_batch
2748+
repeated_batch["_input_ids_for_baseline"] = prompt_ids_for_adv
27272749

27282750
print(
27292751
f" 📊 Rewards stats: min={rewards.min():.4f}, max={rewards.max():.4f}, mean={rewards.mean():.4f}, std={rewards.std():.4f}"
@@ -2809,6 +2831,7 @@ def async_grpo_train(
28092831
train_data["advantages"] = adv_estimator.compute_advantage(
28102832
prompt_ids=prompt_ids_for_adv,
28112833
rewards=rewards,
2834+
repeated_batch=repeated_batch,
28122835
mask=mask,
28132836
logprobs_policy=train_data["prev_logprobs"],
28142837
logprobs_reference=train_data.get("reference_policy_logprobs"),

nemo_rl/algorithms/utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,12 @@
2929
from nemo_rl.data.chat_templates import COMMON_CHAT_TEMPLATES
3030
from nemo_rl.models.policy import TokenizerConfig
3131
from nemo_rl.utils.logger import Logger
32+
import re
3233

34+
def get_gdpo_reward_component_keys(batch) -> list:
    """Return batch keys that are reward components (reward1, reward2, ...) in sorted order.

    A key qualifies when its string form is exactly ``reward`` followed by
    digits; the returned keys are ordered by that numeric suffix.
    """
    pattern = re.compile(r"reward(\d+)$")
    index_by_key = {}
    for key in batch.keys():
        matched = pattern.match(str(key))
        if matched:
            index_by_key[key] = int(matched.group(1))
    return sorted(index_by_key, key=index_by_key.__getitem__)
3338

3439
def calculate_kl(
3540
logprobs: torch.Tensor,

0 commit comments

Comments
 (0)