62 changes: 62 additions & 0 deletions examples/configs/gdpo_math_1B.yaml
@@ -0,0 +1,62 @@
# GDPO: inherits from grpo_math_1B.yaml and overrides only what differs.
defaults: grpo_math_1B.yaml

grpo:
  adv_estimator:
    name: "gdpo"
    normalize_rewards: true
    use_leave_one_out_baseline: false

checkpointing:
  checkpoint_dir: "results/gdpo"

policy:
  model_name: "Qwen/Qwen2.5-1.5B-Instruct"
  logprob_batch_size: 4
  max_total_sequence_length: 1024
  megatron_cfg:
    optimizer:
      weight_decay: 0.0
    scheduler:
      lr_decay_style: "cosine"
      lr_warmup_iters: 10

# GDPO uses a single flat data config (GSM8K + math_gdpo_data_processor); replace parent's train/validation/default.
data:
  _override_: true

  max_input_seq_length: ${policy.max_total_sequence_length}
  shuffle: true
  num_workers: 1

  use_multiple_dataloader: false

  train:
    dataset_name: "gsm8k"
    split: train
  validation:
    dataset_name: "gsm8k"
    split: test

  default:
    prompt_file: null
    system_prompt_file: "examples/prompts/gsm8k.txt"
    processor: "math_gdpo_data_processor"
    env_name: "math_multi_reward"

env:
  math_multi_reward:
    num_workers: 8
    math_verify_impl: "hf_math_verify"

logger:
  wandb_enabled: true
  wandb:
    project: "gdpo-dev"
    name: "gdpo-dev-logger"
  swanlab:
    project: "gdpo-dev"
    name: "gdpo-dev-logger"
  mlflow:
    experiment_name: "gdpo-dev"
    run_name: "gdpo-dev-logger"
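
The `data` section above carries `_override_: true`, which (per the comment) tells the config loader to replace the parent's `data` block outright rather than merge into it. The sketch below illustrates that inheritance behavior only; `load_config`, `deep_merge`, and the exact `_override_` handling are assumptions for illustration, not NeMo RL's actual loader.

```python
# Illustration only: one plausible way `defaults:` inheritance with `_override_: true`
# could be resolved. Function names and merge rules are assumptions, not the repo's code.
import os
from typing import Any

import yaml


def deep_merge(parent: dict, child: dict) -> dict:
    """Merge child over parent; a child mapping with `_override_: true` replaces the parent value."""
    merged = dict(parent)
    for key, value in child.items():
        if isinstance(value, dict) and value.get("_override_"):
            merged[key] = {k: v for k, v in value.items() if k != "_override_"}
        elif isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged


def load_config(path: str) -> dict[str, Any]:
    with open(path) as f:
        cfg = yaml.safe_load(f)
    parent = cfg.pop("defaults", None)
    if parent:
        return deep_merge(load_config(os.path.join(os.path.dirname(path), parent)), cfg)
    return cfg
```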
17 changes: 17 additions & 0 deletions examples/prompts/gsm8k.txt
@@ -0,0 +1,17 @@
You are a helpful AI assistant.

For every request, you should carefully think through the math problem step by step, then provide the final answer in integer format.

Steps for Each Request:
1. Think: Provide detailed, step-by-step reasoning, calculations, or derivations.
2. Produce Final Answer: After step-by-step reasoning, output the final answer in integer format.

Output Format:
<think>Your thoughts and reasoning</think>
<answer>Final answer in integer format</answer>

Important Notes:
1. You must include your reasoning steps inside <think>.
2. You must always output the Final Answer within <answer> after the reasoning steps are done.
3. You should consistently work through the solution step by step before giving the final answer.
4. The final answer can only be an integer.
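
The prompt pins a strict `<think>…</think><answer>integer</answer>` output format. The snippet below shows one way a completion in that format could be parsed; it is illustrative only, since scoring in this PR actually runs through the `math_multi_reward` environment with `hf_math_verify`.

```python
# Illustrative parser for the output format requested above; not the PR's actual reward code.
import re


def extract_integer_answer(completion: str) -> int | None:
    """Return the integer inside the last <answer>...</answer> block, or None if missing."""
    matches = re.findall(r"<answer>\s*(-?\d+)\s*</answer>", completion)
    return int(matches[-1]) if matches else None


sample = "<think>18 - 3 - 4 = 11, then 11 * 2 = 22</think>\n<answer>22</answer>"
assert extract_integer_answer(sample) == 22
```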
7 changes: 7 additions & 0 deletions examples/run_grpo.py
@@ -139,6 +139,13 @@ def main() -> None:
"use_multiple_dataloader is not supported with async GRPO"
)

# Async GDPO is not supported
if config["grpo"]["adv_estimator"]["name"] == "gdpo":
raise NotImplementedError(
"GDPO is not supported for async training, "
"please set grpo.async_grpo.enabled to false in your config."
)

from nemo_rl.algorithms.grpo import async_grpo_train

print("🚀 Running async GRPO training")
101 changes: 93 additions & 8 deletions nemo_rl/algorithms/advantage_estimator.py
@@ -16,6 +16,7 @@

This module provides different advantage estimation strategies:
- GRPOAdvantageEstimator: Standard GRPO advantage with leave-one-out baseline
- GDPOAdvantageEstimator: Multi-reward GDPO (per-component baselines, sum then normalize)
- ReinforcePlusPlusAdvantageEstimator: Reinforce++ with optional baseline subtraction (minus_baseline) and KL penalty in reward
Reference papers:
- ProRLv2: https://developer.nvidia.com/blog/scaling-llm-reinforcement-learning-with-prolonged-training-using-prorl-v2/
@@ -24,7 +25,11 @@

import torch

from nemo_rl.algorithms.utils import calculate_baseline_and_std_per_prompt, calculate_kl
from nemo_rl.algorithms.utils import (
    calculate_baseline_and_std_per_prompt,
    calculate_kl,
    get_gdpo_reward_component_keys,
)


class GRPOAdvantageEstimator:
@@ -37,14 +42,21 @@ def __init__(self, estimator_config: dict, loss_config: dict):
        self.use_leave_one_out_baseline = estimator_config["use_leave_one_out_baseline"]
        self.normalize_rewards = estimator_config["normalize_rewards"]

    def compute_advantage(self, prompt_ids, rewards, mask, **kwargs):
    def compute_advantage(
        self,
        prompt_ids,
        rewards,
        repeated_batch,
        mask,
        **kwargs,
    ):
        """Compute GRPO advantages.

        Args:
            prompt_ids: Tensor of shape [batch_size] identifying which prompt each sample belongs to.
            rewards: Tensor of shape [batch_size] containing reward for each sample.
            repeated_batch: Batch (unused; for interface consistency).
            mask: Response token mask of shape [batch_size, seq_len], 1 for valid response tokens, 0 for padding.
                Used only for expanding advantages to token-level shape.
            **kwargs: Additional arguments (unused).

        Returns:
@@ -69,6 +81,79 @@ def compute_advantage(self, prompt_ids, rewards, mask, **kwargs):
        return advantages.expand(mask.shape)


class GDPOAdvantageEstimator:
    """GDPO-style advantage estimator with leave-one-out baseline.

    Note: GDPO computes advantages for each reward separately over all responses for each prompt.
    """

    def __init__(self, estimator_config: dict, loss_config: dict):
        self.use_leave_one_out_baseline = estimator_config["use_leave_one_out_baseline"]
        self.normalize_rewards = estimator_config["normalize_rewards"]

    def compute_advantage(
        self,
        prompt_ids,
        rewards,
        repeated_batch,
        mask,
        **kwargs,
    ):
        """Compute GDPO advantages.

        Args:
            prompt_ids: Tensor identifying which prompt each sample belongs to (for per-prompt baselines).
            rewards: Unused; for interface consistency.
            repeated_batch: Batch containing reward1, reward2, ... keys.
            mask: Response token mask of shape [batch_size, seq_len], 1 for valid response tokens, 0 for padding.
            **kwargs: Additional arguments (unused).

        Returns:
            Advantages tensor of shape [batch_size, seq_len].
        """
        reward_component_keys = get_gdpo_reward_component_keys(repeated_batch)
        if len(reward_component_keys) < 2:
            raise ValueError(
                f"GDPO requires multiple reward components (reward1, reward2, ...). "
                f"This batch has {len(reward_component_keys)} component(s). "
                "Switch to GRPO by setting grpo.adv_estimator.name to 'grpo' in your config."
            )
        valid = torch.ones_like(repeated_batch[reward_component_keys[0]])
        leave_one_out = self.use_leave_one_out_baseline
        assert prompt_ids.shape[0] == valid.shape[0], (
            "prompt_ids must match reward batch size; "
            f"got {prompt_ids.shape[0]} vs {valid.shape[0]}"
        )
        advantage_parts = []
        for key in reward_component_keys:
            r = repeated_batch[key]
            base, std_k = calculate_baseline_and_std_per_prompt(
                prompt_ids,
                r,
                valid,
                leave_one_out_baseline=leave_one_out,
            )
            adv_k = (r - base).unsqueeze(-1)
            if self.normalize_rewards:
                epsilon = 1e-6
                non_zero_std_mask = std_k > 0
                adv_k[non_zero_std_mask] = adv_k[non_zero_std_mask] / (
                    std_k.unsqueeze(-1)[non_zero_std_mask] + epsilon
                )

            advantage_parts.append(adv_k)

        advantages = sum(advantage_parts)
        # Normalize combined advantage to zero mean and unit std
        adv_std = advantages.std()
        if adv_std > 0:
            advantages = (advantages - advantages.mean()) / adv_std
        else:
            advantages = advantages - advantages.mean()

        return advantages.expand(mask.shape)


class ReinforcePlusPlusAdvantageEstimator:
"""Reinforce++ advantage estimator with optional baseline subtraction and KL penalty in reward.

Expand All @@ -87,6 +172,7 @@ def compute_advantage(
self,
prompt_ids,
rewards,
repeated_batch,
mask,
logprobs_policy=None,
logprobs_reference=None,
Expand All @@ -95,13 +181,12 @@ def compute_advantage(
"""Compute Reinforce++ advantages with optional KL penalty.

Args:
prompt_ids: Tensor of shape [batch_size] identifying which prompt each sample belongs to.
prompt_ids: Tensor identifying which prompt each sample belongs to (for baseline).
rewards: Tensor of shape [batch_size] containing reward for each sample.
repeated_batch: Unused; for interface consistency.
mask: Response token mask of shape [batch_size, seq_len], 1 for valid response tokens, 0 for padding.
Used for: (1) expanding advantages to token-level shape, (2) global normalization
that only considers valid tokens.
logprobs_policy: Policy log probabilities of shape [batch_size, seq_len], required if use_kl_in_reward.
logprobs_reference: Reference policy log probabilities of shape [batch_size, seq_len], required if use_kl_in_reward.
logprobs_policy: Policy log probabilities, required if use_kl_in_reward.
logprobs_reference: Reference policy log probabilities, required if use_kl_in_reward.
**kwargs: Additional arguments (unused).

Returns:
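
To make the estimator's math concrete, here is a self-contained toy run of the same steps: per-prompt baseline and std for each reward component, per-component normalization, summation across components, then global normalization. It uses a simplified mean/std helper in place of `calculate_baseline_and_std_per_prompt` and ignores the leave-one-out option, so treat it as a sketch of the computation rather than a drop-in test.

```python
# Toy sketch of the GDPO advantage computation above, with a simplified per-prompt
# mean/std helper and no leave-one-out baseline. Names and shapes are illustrative.
import torch


def _baseline_and_std(prompt_ids: torch.Tensor, r: torch.Tensor):
    """Per-sample baseline (mean of its prompt group) and std of that group."""
    base = torch.zeros_like(r)
    std = torch.zeros_like(r)
    for pid in prompt_ids.unique():
        idx = prompt_ids == pid
        base[idx] = r[idx].mean()
        std[idx] = r[idx].std(unbiased=False)
    return base, std


# 2 prompts x 2 samples each, two reward components (e.g. correctness + format).
prompt_ids = torch.tensor([0, 0, 1, 1])
batch = {
    "reward1": torch.tensor([1.0, 0.0, 1.0, 1.0]),
    "reward2": torch.tensor([0.5, 0.5, 0.0, 1.0]),
}

parts = []
for key in ("reward1", "reward2"):
    base, std = _baseline_and_std(prompt_ids, batch[key])
    adv_k = (batch[key] - base).unsqueeze(-1)
    nz = std > 0
    adv_k[nz] = adv_k[nz] / (std.unsqueeze(-1)[nz] + 1e-6)  # normalize each component
    parts.append(adv_k)

advantages = sum(parts)
advantages = (advantages - advantages.mean()) / advantages.std()  # global normalization
print(advantages.squeeze(-1))
```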
34 changes: 25 additions & 9 deletions nemo_rl/algorithms/grpo.py
@@ -28,6 +28,7 @@
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

from nemo_rl.algorithms.advantage_estimator import (
    GDPOAdvantageEstimator,
    GRPOAdvantageEstimator,
    ReinforcePlusPlusAdvantageEstimator,
)
@@ -43,6 +44,7 @@
)
from nemo_rl.algorithms.utils import (
    calculate_baseline_and_std_per_prompt,
    get_gdpo_reward_component_keys,
    log_generation_metrics_to_wandb,
    print_performance_metrics,
    set_seed,
@@ -121,9 +123,9 @@ class AsyncGRPOConfig(TypedDict):


class AdvEstimatorConfig(TypedDict):
    """Configuration for advantage estimator (GRPO or Reinforce++)."""
    """Configuration for advantage estimator (GRPO, GDPO, or Reinforce++)."""

    name: str  # "grpo" or "reinforce_plus_plus"
    name: str  # "grpo", "gdpo", or "reinforce_plus_plus"
    # GRPO specific
    normalize_rewards: NotRequired[bool]
    use_leave_one_out_baseline: NotRequired[bool]
@@ -966,11 +968,16 @@ def scale_rewards(
    )

    # Clamp and scale
    rewards = torch.clamp(rewards, min=source_min, max=source_max)
    scaled_rewards = target_min + (rewards - source_min) / (
        source_max - source_min
    ) * (target_max - target_min)
    def _scale(reward_tensor: torch.Tensor) -> torch.Tensor:
        r = torch.clamp(reward_tensor, min=source_min, max=source_max)
        return target_min + (r - source_min) / (source_max - source_min) * (
            target_max - target_min
        )

    scaled_rewards = _scale(rewards)
    repeated_batch["total_reward"] = scaled_rewards
    for key in get_gdpo_reward_component_keys(repeated_batch):
        repeated_batch[key] = _scale(repeated_batch[key])

    return repeated_batch

@@ -1031,7 +1038,7 @@ def _create_advantage_estimator(master_config: MasterConfig):
        master_config: The master configuration dictionary.

    Returns:
        An advantage estimator instance (GRPOAdvantageEstimator or ReinforcePlusPlusAdvantageEstimator).
        An advantage estimator instance (GRPO, GDPO, or ReinforcePlusPlus).

    Raises:
        ValueError: If the advantage estimator name is not recognized.
@@ -1055,7 +1062,14 @@
    )

    adv_estimator_name = adv_estimator_config["name"]
    if adv_estimator_name == "grpo":
    if adv_estimator_name == "gdpo":
        assert not _should_use_async_rollouts(master_config), (
            "GDPO is not supported for async rollouts, "
            "please set policy.generation.vllm_cfg.async_engine to false in your config."
        )
        adv_estimator = GDPOAdvantageEstimator(adv_estimator_config, loss_config)
        print(" ✓ Using GDPO advantage estimator (multi-reward)")
    elif adv_estimator_name == "grpo":
        adv_estimator = GRPOAdvantageEstimator(adv_estimator_config, loss_config)
        print(" ✓ Using GRPO advantage estimator")
    elif adv_estimator_name == "reinforce_plus_plus":
@@ -1644,10 +1658,10 @@ def grpo_train(
        # If the current batch is not enough to fill the buffer during dynamic sampling, we update the cache and process the next batch.
        if not is_batch_complete:
            continue

        gen_step_metrics = {}
        if hasattr(policy_generation, "get_step_metrics"):
            gen_step_metrics = policy_generation.get_step_metrics()
        advantages = (rewards - baseline).unsqueeze(-1)

        # Save baseline for logging (before deletion)
        baseline_for_log = baseline.clone()
@@ -1778,6 +1792,7 @@ def grpo_train(
train_data["advantages"] = adv_estimator.compute_advantage(
prompt_ids=prompt_ids_for_adv,
rewards=rewards,
repeated_batch=repeated_batch,
mask=mask,
logprobs_policy=train_data["prev_logprobs"],
logprobs_reference=train_data.get("reference_policy_logprobs"),
@@ -2809,6 +2824,7 @@ def async_grpo_train(
train_data["advantages"] = adv_estimator.compute_advantage(
prompt_ids=prompt_ids_for_adv,
rewards=rewards,
repeated_batch=repeated_batch,
mask=mask,
logprobs_policy=train_data["prev_logprobs"],
logprobs_reference=train_data.get("reference_policy_logprobs"),
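
For reference, the `_scale` helper added to `scale_rewards` is a plain clamp-then-linear-rescale, now applied to `total_reward` and to every `rewardN` component so GDPO's per-component advantages see rewards on the same target range. A tiny numeric sketch of that transform, with made-up bounds:

```python
# Standalone sketch of the clamp-and-rescale applied above; the bound values are made up.
import torch

source_min, source_max = 0.0, 1.0
target_min, target_max = -1.0, 1.0


def _scale(reward_tensor: torch.Tensor) -> torch.Tensor:
    r = torch.clamp(reward_tensor, min=source_min, max=source_max)
    return target_min + (r - source_min) / (source_max - source_min) * (target_max - target_min)


print(_scale(torch.tensor([0.0, 0.25, 1.0, 1.5])))  # tensor([-1.0000, -0.5000, 1.0000, 1.0000])
```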
7 changes: 7 additions & 0 deletions nemo_rl/algorithms/utils.py
@@ -14,6 +14,7 @@

import math
import random
import re
import warnings
from functools import partial, wraps
from typing import Any, Optional
@@ -31,6 +32,12 @@
from nemo_rl.utils.logger import Logger


def get_gdpo_reward_component_keys(batch) -> list:
"""Return batch keys that are reward components (reward1, reward2, ...) in sorted order."""
    keys = [k for k in batch.keys() if re.match(r"reward\d+$", str(k))]
    return sorted(keys, key=lambda k: int(re.search(r"\d+", str(k)).group()))


def calculate_kl(
logprobs: torch.Tensor,
logprobs_reference: torch.Tensor,
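
A quick usage example of the new helper (assuming it is imported from `nemo_rl.algorithms.utils`): only keys matching `reward<N>` exactly are picked up, and they come back in numeric order.

```python
# Usage sketch for get_gdpo_reward_component_keys; the toy batch dict is illustrative.
from nemo_rl.algorithms.utils import get_gdpo_reward_component_keys

batch = {
    "total_reward": 1.0,
    "reward2": 0.3,
    "reward1": 0.7,
    "reward10": 0.1,
    "rewards": None,  # no match: the pattern is reward\d+$
}
print(get_gdpo_reward_component_keys(batch))  # ['reward1', 'reward2', 'reward10']
```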
3 changes: 3 additions & 0 deletions nemo_rl/data/datasets/response_datasets/__init__.py
@@ -25,6 +25,7 @@
    GeneralConversationsJsonlDataset,
)
from nemo_rl.data.datasets.response_datasets.geometry3k import Geometry3KDataset
from nemo_rl.data.datasets.response_datasets.gsm8k import GSM8KDataset
from nemo_rl.data.datasets.response_datasets.helpsteer3 import HelpSteer3Dataset
from nemo_rl.data.datasets.response_datasets.nemogym_dataset import NemoGymDataset
from nemo_rl.data.datasets.response_datasets.oai_format_dataset import (
@@ -55,6 +56,7 @@
"refcoco": RefCOCODataset,
"squad": SquadDataset,
"tulu3_sft_mixture": Tulu3SftMixtureDataset,
"gsm8k": GSM8KDataset,
# load from local JSONL file or HuggingFace
"openai_format": OpenAIFormatDataset,
"NemoGymDataset": NemoGymDataset,
@@ -94,6 +96,7 @@ def load_response_dataset(data_config: ResponseDatasetConfig):
"GeneralConversationsJsonlDataset",
"DAPOMath17KDataset",
"DAPOMathAIME2024Dataset",
"GSM8KDataset",
"DeepScalerDataset",
"Geometry3KDataset",
"HelpSteer3Dataset",
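
With `"gsm8k"` registered, a data config that names it should resolve to `GSM8KDataset` via `load_response_dataset`. The sketch below shows the lookup pattern only; the registry dict's real name and `GSM8KDataset`'s constructor arguments are not visible in this hunk, so both are assumptions.

```python
# Hypothetical sketch of the registration-to-lookup flow; the registry name and the
# GSM8KDataset constructor signature are assumptions, not taken from this diff.
RESPONSE_DATASETS = {
    "gsm8k": GSM8KDataset,
    # ... other registered datasets ...
}


def load_response_dataset_sketch(data_config: dict):
    dataset_cls = RESPONSE_DATASETS[data_config["dataset_name"]]
    return dataset_cls(split=data_config.get("split", "train"))
```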