Merged
116 commits
751d995
feat: Support DAPO dynamic sampling and reward shaping
peri044 Jul 3, 2025
e5dd193
Merge branch 'main' into dapo
peri044 Jul 3, 2025
d3ec46f
-s
peri044 Aug 10, 2025
7b9226b
feat: Add DAPO implementation
peri044 Aug 10, 2025
d6cd81d
chore: revert DAPO changes in GRPO files
peri044 Aug 10, 2025
108e1e0
chore: move DAPO impl back to GRPO
peri044 Aug 12, 2025
f249641
chore: remove dapo scripts
peri044 Aug 12, 2025
17b519c
Merge pull request #1 from peri044/dapo_merge
peri044 Aug 12, 2025
0d7f9bf
fix: Fix config bug
peri044 Aug 12, 2025
15fb6dc
chore: add timer for dynamic sampling and log filtered and total rewards
peri044 Aug 13, 2025
edda8d4
chore: add dapo-math dataset, address review comments
peri044 Aug 15, 2025
53c004f
chore: add DAPO-MATH-17k config
peri044 Aug 15, 2025
65df6b6
chore: update dapo-math config
peri044 Aug 15, 2025
25c9d8b
chore: update config
peri044 Aug 15, 2025
e4431ac
chore: remove the dapo deepscalar config
peri044 Aug 18, 2025
d99e460
chore: add reward scaling
peri044 Aug 20, 2025
77d08ff
chore: change start_factor to 0.1
peri044 Aug 20, 2025
a33fad2
chore: update dapo config
peri044 Aug 21, 2025
f55c9f7
chore: update gpus_per_node
peri044 Aug 21, 2025
57b0544
chore: rebase with main
peri044 Aug 21, 2025
cff451f
chore: fix prompt_file issue and disable leave_one_out baseline
peri044 Aug 21, 2025
f2a9d67
chore: address review comments
peri044 Aug 26, 2025
4f5ce08
chore: delete approach 1 of DS
peri044 Aug 26, 2025
dddfb03
chore: use verl's math verifier
peri044 Aug 27, 2025
4000965
chore: update math env to use dapo_math_verifier
peri044 Aug 27, 2025
f5de167
chore: disable reward scaling when dapo_math_verifier is used
peri044 Aug 27, 2025
ac228d2
chore: update reward calculation
peri044 Aug 27, 2025
14204e4
Merge branch 'main' into dapo
peri044 Aug 28, 2025
6981357
chore: additional updates
peri044 Aug 29, 2025
cd1cbdf
chore: minor updates and add megatron config
peri044 Aug 29, 2025
3717c9d
chore: address review comments
peri044 Aug 29, 2025
88716fd
chore: remove dapo-math.yaml
peri044 Aug 29, 2025
d4fb90f
chore: address review comments and add reward shaping fix
peri044 Sep 17, 2025
b59d705
chore: rebase with main
peri044 Sep 17, 2025
fabe873
chore: restore to original state
peri044 Sep 17, 2025
8a4e12d
chore: update docstring
peri044 Oct 2, 2025
4b1f9a5
chore: rebase with main and resolve conflicts
peri044 Oct 2, 2025
2b885b8
chore: fix filtered reward metric
peri044 Oct 4, 2025
afec799
add unit tests
ashors1 Oct 7, 2025
ae6b96a
add recipe and short convergence test
ashors1 Oct 7, 2025
21cce62
clean up config
ashors1 Oct 7, 2025
c8bdae8
chore: minor updates to logger data
peri044 Oct 7, 2025
54ec885
chore: rebase
peri044 Oct 7, 2025
6d0c521
chore: updates
peri044 Oct 7, 2025
8abeabf
chore: update doc
peri044 Oct 7, 2025
06544f3
chore: update docs with intermediate results
peri044 Oct 7, 2025
32c636c
chore: address review comments
peri044 Oct 7, 2025
a73c22f
address comment, remove some redundant tests
ashors1 Oct 7, 2025
5d6a46f
Merge pull request #2 from peri044/ashors/dapo-tests
peri044 Oct 7, 2025
ed06fe0
fix configs
ashors1 Oct 8, 2025
72ebf0b
fix test name
ashors1 Oct 8, 2025
b4ee6f7
add dapo to mapping
ashors1 Oct 8, 2025
5e6a251
chore: Fix a test case
peri044 Oct 8, 2025
1d156db
Merge branch 'dapo' of https://github.com/peri044/RL into dapo
peri044 Oct 8, 2025
703932b
Merge branch 'main' into dapo
peri044 Oct 8, 2025
6b8c945
chore: use kwargs in math verifier to be consistent
peri044 Oct 8, 2025
564ef19
Merge branch 'dapo' of https://github.com/peri044/RL into dapo
peri044 Oct 8, 2025
2efd601
Merge pull request #3 from peri044/ashors/dapo-fix-unit-tests
peri044 Oct 8, 2025
0a09fac
Merge branch 'main' into dapo
peri044 Oct 8, 2025
7832e93
chore: update license for dapo math verifier
peri044 Oct 8, 2025
591e3e8
Merge branch 'dapo' of https://github.com/peri044/RL into dapo
peri044 Oct 8, 2025
22e78e7
dapo convergence test fixes
ashors1 Oct 9, 2025
fdc0b8c
Merge branch 'main' into dapo
peri044 Oct 9, 2025
1676108
Merge branch 'main' into dapo
peri044 Oct 9, 2025
65f33bf
make hf_overrides configurable
ashors1 Oct 9, 2025
a43177e
propagate hf overrides to vllm
ashors1 Oct 10, 2025
1008fda
default to None
ashors1 Oct 10, 2025
5256c81
minor fix
ashors1 Oct 10, 2025
1d8de6c
add copyright
ashors1 Oct 10, 2025
3967835
lint
ashors1 Oct 10, 2025
ef78bfb
Merge pull request #4 from peri044/ashors/dapo-fix-conv-test
ashors1 Oct 10, 2025
2668f4b
Merge branch 'main' into dapo
peri044 Oct 10, 2025
3371bbd
chore: fix linter check
peri044 Oct 10, 2025
d529897
Merge branch 'dapo' of https://github.com/peri044/RL into dapo
peri044 Oct 10, 2025
b7bc4d7
fix hf_overrides default
ashors1 Oct 10, 2025
44a3e3c
fix hf_overrides during model conversion
ashors1 Oct 10, 2025
1ab98c1
document configs, add missing defaults
ashors1 Oct 10, 2025
fe7c46e
address some of Terry's comments
ashors1 Oct 10, 2025
234a5fb
chore: address some review comments
peri044 Oct 11, 2025
329fe1c
Merge pull request #5 from peri044/ashors/dapo-comments
peri044 Oct 11, 2025
939189b
chore: rebase
peri044 Oct 11, 2025
8fe6c20
chore: fix conflicts in README.md
peri044 Oct 11, 2025
298aad3
chore: add dapo guide link
peri044 Oct 11, 2025
f39a3df
chore: fix incorrect algo name
peri044 Oct 11, 2025
a5f3eea
chore: address review comments
peri044 Oct 11, 2025
36b72e0
chore: update logging of % of prompts that are being discarded
peri044 Oct 11, 2025
5399de8
chore: update non_zero_std_fraction metric in wandb
peri044 Oct 11, 2025
3effac0
chore: update doc
peri044 Oct 11, 2025
bf3a635
chore: update doc
peri044 Oct 11, 2025
521949b
chore: fix eqn format
peri044 Oct 11, 2025
d5063b3
chore: update scale_rewards function
peri044 Oct 12, 2025
c1ffc71
update logging, update docs, rename dapo_batch_multiplier, fix some f…
ashors1 Oct 12, 2025
cd48dc9
fix checkpointing config
ashors1 Oct 12, 2025
e920e09
chore: update docstring and add test cases for reward scaling
peri044 Oct 13, 2025
df553f1
Merge branch 'dapo' of https://github.com/peri044/RL into dapo
peri044 Oct 13, 2025
9d769e9
Merge branch 'main' into dapo
peri044 Oct 13, 2025
8d34524
chore: directly access reward_scaling_cfg
peri044 Oct 13, 2025
21e5493
Merge branch 'main' into dapo
peri044 Oct 13, 2025
6107c62
Merge branch 'main' into dapo
peri044 Oct 14, 2025
00fae2e
chore: rebase
peri044 Oct 15, 2025
49f4ec4
tighter bounds for test, minor doc improvement
ashors1 Oct 15, 2025
2573f9d
Update docs/guides/dapo.md
peri044 Oct 15, 2025
84dc622
Update docs/guides/dapo.md
peri044 Oct 15, 2025
b2c29ec
Update docs/guides/dapo.md
peri044 Oct 15, 2025
d1eb2c4
chore: address review comments and rename max_num_gen_batches and num…
peri044 Oct 15, 2025
e1a3ed3
chore: Add dapo guide to index.md
peri044 Oct 15, 2025
1229358
fix: Fix config issue
peri044 Oct 16, 2025
02e0546
Merge branch 'main' into dapo
peri044 Oct 16, 2025
a3f0c3a
Merge branch 'main' into dapo
peri044 Oct 16, 2025
f297185
add missing key
ashors1 Oct 16, 2025
c9dcdb0
fix unit test
ashors1 Oct 16, 2025
302dd31
fix config issue
ashors1 Oct 16, 2025
f25b48b
fix distillation unit test
ashors1 Oct 16, 2025
769f11f
fix grpo unit test
ashors1 Oct 16, 2025
9bb2a9c
fix default hf_config_overrides
ashors1 Oct 17, 2025
02a8e1e
fix remaining hf_overrides default
ashors1 Oct 17, 2025
76 changes: 75 additions & 1 deletion nemo_rl/algorithms/grpo.py
@@ -13,7 +13,7 @@
# limitations under the License.
import os
from pathlib import Path
from typing import Any, Optional, Tuple, TypedDict, TypeVar, cast
from typing import Any, NotRequired, Optional, Tuple, TypedDict, TypeVar, cast

import numpy as np
import ray
@@ -27,6 +27,7 @@
ClippedPGLossDataDict,
ClippedPGLossFn,
)
from nemo_rl.algorithms.reward_functions import RewardConfig, process_rewards
from nemo_rl.algorithms.utils import calculate_baseline_and_std_per_prompt
from nemo_rl.data import DataConfig
from nemo_rl.data.datasets import AllTaskProcessedDataset, rl_collate_fn
@@ -83,6 +84,8 @@ class GRPOConfig(TypedDict):
val_at_start: bool
max_val_samples: int
checkpoint_dir: str
use_dynamic_sampling: NotRequired[bool]
max_num_gen_batches: NotRequired[int]


class GRPOSaveState(TypedDict):
@@ -106,6 +109,7 @@ class GRPOLoggerConfig(LoggerConfig):
class MasterConfig(TypedDict):
policy: PolicyConfig
loss_fn: ClippedPGLossConfig
reward_fn: RewardConfig
env: dict[str, Any]
data: DataConfig
grpo: GRPOConfig
@@ -518,6 +522,7 @@ def grpo_train(
logger.log_metrics(val_metrics, step, prefix="validation")
logger.log_metrics(validation_timings, step, prefix="timing/validation")

num_gen_batches = 0
# Run grpo training (single-turn)
batch: BatchedDataDict[DatumSpec]
for batch in dataloader:
@@ -555,6 +560,7 @@
else:
policy_generation.prepare_for_generation()

num_gen_batches += 1
with timer.time("generation"):
# Use async rollouts if vLLM async engine is enabled
if _should_use_async_rollouts(master_config):
@@ -601,8 +607,76 @@
"use_leave_one_out_baseline"
],
)

# Dynamic sampling (used in the DAPO algorithm)
# This block implements dynamic sampling by keeping only prompt groups with non-zero reward std.
# If the retained samples are fewer than train_global_batch_size * num_generations_per_prompt, continue sampling until max_num_gen_batches is reached.
if master_config["grpo"]["use_dynamic_sampling"]:
std_chunks_per_prompt = std.split(
master_config["grpo"]["num_generations_per_prompt"]
)
keep_prompt_indices = []
selected_std_chunks = []
for chunk_idx, chunk in enumerate(std_chunks_per_prompt):
chunk_length = chunk.shape[0]
if torch.nonzero(chunk).shape[0] == chunk_length:
chunk_prompt_indices = [
chunk_idx * chunk_length + idx
for idx in range(chunk_length)
]
keep_prompt_indices.extend(chunk_prompt_indices)
selected_std_chunks.append(chunk)
std = torch.cat(selected_std_chunks)

generation_sample_buffer_size = len(keep_prompt_indices)
train_prompts_buffer_size = (
master_config["policy"]["train_global_batch_size"]
* master_config["grpo"]["num_generations_per_prompt"]
)

# If the number of retained samples is smaller than the required threshold (train_prompts_buffer_size), keep generating by processing the next batch
if generation_sample_buffer_size < train_prompts_buffer_size:
max_num_gen_batches = master_config["grpo"].get(
"max_num_gen_batches", 0
)
if (
max_num_gen_batches <= 0
or num_gen_batches <= max_num_gen_batches
):
continue
else:
raise ValueError(
f"Dynamic sampling has reached the maximum allowable number of batches ({max_num_gen_batches}). Consider evaluating the complexity of your data or adjusting the num_prompts_per_step or num_generations_per_prompt parameters to enhance the diversity of the samples."
)
else:
# Select the inputs that have non-zero std
repeated_batch = repeated_batch.select_indices(
keep_prompt_indices
)

# Gather the corresponding rewards
rewards = rewards[keep_prompt_indices]

# Gather the corresponding baselines (means)
baseline = baseline[keep_prompt_indices]

# Slice the batch, rewards, baselines and std to ensure batch size is train_prompts_buffer_size
repeated_batch = repeated_batch.slice(
0, train_prompts_buffer_size
)
rewards = rewards[:train_prompts_buffer_size]
baseline = baseline[:train_prompts_buffer_size]
std = std[:train_prompts_buffer_size]

# Process rewards with custom reward function
if master_config["reward_fn"]["enabled"]:
rewards = process_rewards(
repeated_batch, rewards, master_config["reward_fn"]
)

advantages = (rewards - baseline).unsqueeze(-1)

# Normalize rewards
if master_config["grpo"]["normalize_rewards"]:
# don't sharpen the ones with no variation
zero_std_mask = std > 0
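To illustrate the dynamic sampling step in the diff above, here is a minimal, self-contained sketch of the std-based group filtering. The helper name `filter_nonzero_std_groups` is hypothetical (not part of the PR); as in the diff, it assumes per-sample stds are laid out in contiguous groups of `num_generations_per_prompt` entries.

```python
import torch

def filter_nonzero_std_groups(
    std: torch.Tensor, num_generations_per_prompt: int
) -> list[int]:
    # Split per-sample stds into per-prompt groups and keep only the groups
    # in which every entry is non-zero, i.e. the prompt's rewards vary.
    keep_indices = []
    for chunk_idx, chunk in enumerate(std.split(num_generations_per_prompt)):
        if torch.count_nonzero(chunk) == chunk.shape[0]:
            start = chunk_idx * num_generations_per_prompt
            keep_indices.extend(range(start, start + num_generations_per_prompt))
    return keep_indices

# Two generations per prompt: prompt 0 has zero std (all of its generations
# got the same reward), so its samples are dropped; prompts 1 and 2 are kept.
std = torch.tensor([0.0, 0.0, 0.5, 0.5, 0.2, 0.2])
print(filter_nonzero_std_groups(std, 2))  # -> [2, 3, 4, 5]
```

The retained indices are then used to sub-select the batch, rewards, and baselines before slicing everything down to `train_prompts_buffer_size`, mirroring the `select_indices`/`slice` calls in the diff.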
79 changes: 79 additions & 0 deletions nemo_rl/algorithms/reward_functions.py
@@ -0,0 +1,79 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TypedDict, TypeVar

import torch

from nemo_rl.distributed.batched_data_dict import BatchedDataDict

Tensor = TypeVar("Tensor", bound=torch.Tensor)


class RewardConfig(TypedDict):
"""Configuration for reward function processing.

This configuration enables custom reward shaping, currently supporting DAPO-style
penalties for responses that exceed the maximum response length threshold.
"""

enabled: bool
overlong_buffer_length: int
overlong_buffer_penalty: float
max_response_length: int


def process_rewards(
batch: BatchedDataDict, rewards: torch.Tensor, cfg: RewardConfig
) -> torch.Tensor:
"""Process rewards by applying penalties for responses exceeding max_response_length. Currently, this function only supports DAPO reward shaping as described in the DAPO paper: https://arxiv.org/pdf/2503.14476.

Nonetheless, it can be potentially extended to support any custom reward logic.
"""
if not cfg["enabled"]:
return rewards

# DAPO reward shaping requires overlong_buffer_length, overlong_buffer_penalty, and max_response_length to be set.
if (
cfg["overlong_buffer_length"] is None
or cfg["overlong_buffer_penalty"] is None
or cfg["max_response_length"] is None
):
raise ValueError(
"Reward function is enabled but only DAPO reward shaping is currently supported. Please ensure overlong_buffer_length, overlong_buffer_penalty, and max_response_length are properly configured."
)

# Get the overlong_buffer_length, overlong_buffer_penalty and max_response_length
overlong_buffer_length = cfg["overlong_buffer_length"]
overlong_buffer_penalty = cfg["overlong_buffer_penalty"]
max_response_length = cfg["max_response_length"]

# Calculate the expected response length
expected_response_length = max_response_length - overlong_buffer_length

assert len(batch["message_log"]) == len(rewards), (
"The number of messages in the batch must match the number of rewards"
)

updated_rewards = torch.zeros_like(rewards)
for i, message_log in enumerate(batch["message_log"]):
# Get the assistant response length (index 1 is the assistant response)
message_response_length = message_log[1]["token_ids"].shape[0]
# Calculate the exceed length and the corresponding reward penalty
exceed_length = message_response_length - expected_response_length
overlong_reward = min(
-exceed_length / overlong_buffer_length * overlong_buffer_penalty, 0
)
updated_rewards[i] = rewards[i] + overlong_reward

return updated_rewards
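The per-response penalty computed in the loop above can be isolated into a small sketch (the function name `overlong_penalty` is illustrative, not from the PR): once a response enters the final `overlong_buffer_length` tokens before `max_response_length`, the penalty grows linearly from 0 down to `-overlong_buffer_penalty`.

```python
def overlong_penalty(
    response_length: int,
    max_response_length: int,
    overlong_buffer_length: int,
    overlong_buffer_penalty: float,
) -> float:
    # Responses shorter than (max_response_length - overlong_buffer_length)
    # are not penalized; inside the buffer the penalty scales linearly with
    # how far the response has exceeded the expected length.
    expected_length = max_response_length - overlong_buffer_length
    exceed_length = response_length - expected_length
    return min(-exceed_length / overlong_buffer_length * overlong_buffer_penalty, 0.0)

# max_response_length=100, overlong_buffer_length=20, penalty scale=1.0
print(overlong_penalty(70, 100, 20, 1.0))   # within expected length -> 0.0
print(overlong_penalty(90, 100, 20, 1.0))   # 10 tokens into the buffer -> -0.5
print(overlong_penalty(100, 100, 20, 1.0))  # at the length cap -> -1.0
```

This penalty is the `overlong_reward` term that `process_rewards` adds to each raw reward before returning `updated_rewards`.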