Commit 7bd853a

peri044, ashors1, terrykong, and jgerh authored
feat: Support DAPO dynamic sampling and reward shaping (#602)
Signed-off-by: Dheeraj Peri <[email protected]>
Signed-off-by: ashors1 <[email protected]>
Co-authored-by: ashors1 <[email protected]>
Co-authored-by: Terry Kong <[email protected]>
Co-authored-by: jgerh <[email protected]>
1 parent dee3fd9 commit 7bd853a

35 files changed: +1796 −39 lines changed

README.md

Lines changed: 3 additions & 0 deletions
@@ -3,6 +3,9 @@
[![CICD NeMo RL](https://github.com/NVIDIA-NeMo/RL/actions/workflows/cicd-main.yml/badge.svg?branch=main&event=schedule)](https://github.com/NVIDIA-NeMo/RL/actions/workflows/cicd-main.yml)

## 📣 News
* [10/10/2025] **DAPO Algorithm Support**
  NeMo RL now supports the [Decoupled Clip and Dynamic Sampling Policy Optimization (DAPO)](https://arxiv.org/pdf/2503.14476) algorithm.
  DAPO extends GRPO with **Clip-Higher**, **Dynamic Sampling**, **Token-Level Policy Gradient Loss**, and **Overlong Reward Shaping** for more stable and efficient RL training. See the [DAPO guide](docs/guides/dapo.md) for more details.
* [9/30/2025] [Accelerated RL on GCP with NeMo RL!](https://discuss.google.dev/t/accelerating-reinforcement-learning-on-google-cloud-using-nvidia-nemo-rl/269579/4)
* [9/27/2025] [FP8 Quantization in NeMo RL](https://github.com/NVIDIA-NeMo/RL/discussions/1216)
* [9/25/2025] On-policy Distillation (Qwen3-style)

docs/assets/dapo_train_reward.png

30.7 KB

docs/assets/dapo_val_acc.png

23.7 KB

docs/guides/dapo.md

Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
# An In-Depth Walkthrough of DAPO in NeMo RL

This guide covers the [Decoupled Clip and Dynamic Sampling Policy Optimization (DAPO)](https://arxiv.org/pdf/2503.14476) implementation in NeMo RL.

DAPO introduces four key improvements over Group Relative Policy Optimization (GRPO):
1. **Clip-Higher**, which promotes the diversity of the system and avoids entropy collapse
2. **Dynamic Sampling**, which improves training efficiency and stability
3. **Token-Level Policy Gradient Loss**, which is critical in long-CoT RL scenarios
4. **Overlong Reward Shaping**, which reduces reward noise and stabilizes training

This document focuses on the DAPO-specific features: Dynamic Sampling and Overlong Reward Shaping. For foundational concepts of GRPO, including data handling, policy training, generation, and loss functions, see the [NeMo RL GRPO Guide](grpo.md).

## Quickstart: Launch a DAPO Run

To get started quickly, use the example configuration [examples/configs/recipes/llm/dapo-qwen2.5-7b.yaml](../../examples/configs/recipes/llm/dapo-qwen2.5-7b.yaml). You can launch it using the same script as GRPO:

```bash
uv run examples/run_grpo_math.py --config examples/configs/recipes/llm/dapo-qwen2.5-7b.yaml {overrides}
```

**Reminder**: Don't forget to set `HF_HOME`, `WANDB_API_KEY`, and `HF_DATASETS_CACHE` (if needed). You'll also need to run `huggingface-cli login` to access Llama models.

## Dynamic Sampling

Standard GRPO trains on all generated responses, even when every response in a prompt group receives an identical reward and therefore provides zero gradient signal. Dynamic sampling keeps only the groups with diverse rewards (`std > 0`) and accumulates them across generation batches until the target training batch size is reached. Enable it by setting `use_dynamic_sampling: true` in your configuration. For implementation details, see the [`dynamic_sampling`](../../nemo_rl/algorithms/grpo.py) function.

**Algorithm**: For each training step:

1. Sample `batch_multiplier × num_prompts_per_step` prompts from the dataset. The default value of `batch_multiplier` is 1.
2. Generate `num_generations_per_prompt` responses per prompt and compute rewards.
3. Compute the baseline and standard deviation for each prompt group.
4. Keep only the prompt groups where `std > 0`.
5. Store these prompt groups in a cache until the target training batch size of `num_prompts_per_step × num_generations_per_prompt` samples is reached.
6. Repeat until the cache is full or the maximum number of generation batches (`dynamic_sampling_max_gen_batches`) is reached. If the cache still does not meet the target rollout batch size at that point, an error is raised. To resolve this, consider adjusting parameters such as `num_prompts_per_step` or `num_generations_per_prompt` to increase sample diversity, or revisit the complexity of your data.
7. Train on the collected samples with nonzero standard deviation, as shown in the sketch below.
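
The sketch below illustrates this accumulation loop. It is a simplified illustration, not the NeMo RL code: the helpers `sample_prompts` and `generate_and_score` are hypothetical stand-ins for the dataloader and the rollout/reward steps, and the real logic lives in the [`dynamic_sampling`](../../nemo_rl/algorithms/grpo.py) function.

```python
import numpy as np


def dynamic_sampling_step(
    sample_prompts,            # hypothetical helper: n -> list of prompts
    generate_and_score,        # hypothetical helper: (prompt, n) -> (responses, rewards)
    num_prompts_per_step: int,
    num_generations_per_prompt: int,
    batch_multiplier: float = 1.0,
    dynamic_sampling_max_gen_batches: int = 10,
):
    """Illustrative sketch of one DAPO dynamic-sampling step (not the NeMo RL code)."""
    cache = []          # prompt groups with nonzero reward std
    num_gen_batches = 0

    # num_prompts_per_step groups correspond to the target of
    # num_prompts_per_step × num_generations_per_prompt training samples.
    while len(cache) < num_prompts_per_step:
        if num_gen_batches >= dynamic_sampling_max_gen_batches:
            raise RuntimeError(
                "Reached dynamic_sampling_max_gen_batches before filling the batch; "
                "adjust num_prompts_per_step / num_generations_per_prompt or your data."
            )
        num_gen_batches += 1

        for prompt in sample_prompts(int(batch_multiplier * num_prompts_per_step)):
            responses, rewards = generate_and_score(prompt, num_generations_per_prompt)
            # Groups whose rewards are all identical (std == 0) carry no gradient
            # signal under a group-relative baseline, so they are dropped.
            if np.std(rewards) > 0:
                cache.append((prompt, responses, rewards))

    # Extra valid groups beyond the target are discarded; this surplus is what the
    # dynamic_sampling_num_discarded_valid_samples metric tracks.
    return cache[:num_prompts_per_step], num_gen_batches
```

In this sketch, `num_gen_batches` corresponds to the `dynamic_sampling_num_gen_batches` metric and the truncated surplus corresponds to `dynamic_sampling_num_discarded_valid_samples`, both described in the next section.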

### About `batch_multiplier`

`batch_multiplier` (a float ≥ 1.0) controls the initial prompt pool size: `batch_multiplier × num_prompts_per_step` prompts are sampled before dynamic sampling is applied. Higher values increase memory and compute requirements, while values close to 1.0 may slow the accumulation of prompt groups with nonzero standard deviation. The optimal value depends on the dataset, model capacity, and overall training setup. When **dynamic sampling** is enabled, we also log two additional metrics:

* `dynamic_sampling_num_gen_batches`: The number of generation rounds required to produce `num_prompts_per_step × num_generations_per_prompt` samples with nonzero standard deviation. If this number remains consistently high across iterations, try increasing `batch_multiplier`. The maximum number of generation rounds is set by `dynamic_sampling_max_gen_batches`.
* `dynamic_sampling_num_discarded_valid_samples`: The number of samples with nonzero standard deviation that are discarded because the total exceeds `num_prompts_per_step × num_generations_per_prompt`. If this value is frequently high (e.g., above `0.5 × num_prompts_per_step × num_generations_per_prompt`) and `dynamic_sampling_num_gen_batches` is consistently 1, a large fraction of the dataset is being discarded unnecessarily. To improve data efficiency, consider decreasing `batch_multiplier`.

## Reward Shaping

DAPO introduces an overlong reward shaping mechanism to reduce reward noise and stabilize training. This approach penalizes responses that exceed a specified length threshold, helping to prevent the model from generating excessively long outputs while maintaining solution quality.

For a detailed explanation of the overlong reward shaping mechanism, please refer to Section 3.4 of the [DAPO paper](https://arxiv.org/pdf/2503.14476). For implementation details, see the [`apply_reward_shaping`](../../nemo_rl/algorithms/reward_functions.py) function.
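
To make the mechanism concrete, here is a minimal sketch of a soft overlong penalty in the spirit of the paper, written in terms of the `overlong_buffer_length`, `overlong_buffer_penalty`, and `max_response_length` parameters from the configuration below. It illustrates the idea only; the exact formula used by `apply_reward_shaping` may differ.

```python
def overlong_penalty(
    response_length: int,
    max_response_length: int = 20480,
    overlong_buffer_length: int = 4096,
    overlong_buffer_penalty: float = 1.0,
) -> float:
    """Soft overlong penalty in the spirit of DAPO Section 3.4 (illustrative only)."""
    # Responses shorter than (max_response_length - overlong_buffer_length) are untouched.
    expected_length = max_response_length - overlong_buffer_length
    if response_length <= expected_length:
        return 0.0
    # Inside the buffer, the penalty grows linearly with the excess length and is
    # capped at -overlong_buffer_penalty once the hard maximum is reached.
    excess = response_length - expected_length
    return -min(excess / overlong_buffer_length, 1.0) * overlong_buffer_penalty


# Example: a 19,000-token response with the defaults above is penalized by
# -(19000 - 16384) / 4096 * 1.0 ≈ -0.64, which is added to the task reward.
```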

## Configuration

```yaml
grpo:
  use_dynamic_sampling: true # Enable DAPO dynamic sampling
  num_prompts_per_step: 512 # Target number of prompts per training step
  num_generations_per_prompt: 16 # Generations per prompt
  batch_multiplier: 3 # Dataloader batch size = batch_multiplier × num_prompts_per_step
  dynamic_sampling_max_gen_batches: 10 # Maximum number of batches to be used for accumulating nonzero-std prompts
  reward_scaling:
    enabled: true
    source_min: 0.0
    source_max: 1.0
    target_min: -1.0
    target_max: 1.0

  reward_shaping:
    enabled: true
    overlong_buffer_length: 4096 # Threshold before penalties apply (paper uses 4096)
    overlong_buffer_penalty: 1.0 # Penalty per excess token
    max_response_length: 20480 # Hard maximum generation length
```

**Key Parameters:**

- **`use_dynamic_sampling`**: When enabled, activates DAPO's dynamic sampling algorithm to filter and accumulate prompt groups with nonzero standard deviation.
- **`batch_multiplier`**: Factor that scales the initial prompt pool size for sampling.
- **`dynamic_sampling_max_gen_batches`**: Maximum number of generation batches used to accumulate prompt groups with nonzero standard deviation.
- **`reward_scaling`**: When enabled, clamps each reward in the batch to `[source_min, source_max]` and linearly rescales it to `[target_min, target_max]` (a minimal sketch is shown after this list). Defaults: `source_min=0.0`, `source_max=1.0`, `target_min=0.0`, `target_max=1.0`.
- **`reward_shaping`**: When enabled, applies the overlong penalty mechanism described in the Reward Shaping section above. Responses exceeding `max_response_length - overlong_buffer_length` receive penalties proportional to their excess length, helping to reduce reward noise and stabilize training.
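
For reference, the `reward_scaling` transform amounts to a clamp followed by a linear map. The sketch below mirrors the configuration keys and is illustrative only, not the exact NeMo RL implementation.

```python
def scale_reward(
    reward: float,
    source_min: float = 0.0,
    source_max: float = 1.0,
    target_min: float = -1.0,
    target_max: float = 1.0,
) -> float:
    """Clamp to [source_min, source_max], then rescale linearly to [target_min, target_max]."""
    clamped = min(max(reward, source_min), source_max)
    fraction = (clamped - source_min) / (source_max - source_min)
    return target_min + fraction * (target_max - target_min)


# With the DAPO recipe values (source 0..1, target -1..1), a binary 0/1
# correctness reward is mapped to -1/+1.
```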

> [!NOTE]
> When dynamic sampling is enabled, monitor the `filtered_reward` metric to track the average reward of the prompt groups with `std > 0`.

> [!NOTE]
> **Clip-Higher** and **Token-Level Policy Gradient Loss** are already supported in NeMo RL and can be configured through the `loss_fn` section of your experiment config:
> - Set `ratio_clip_max` to enable Clip-Higher (e.g., `ratio_clip_max: 0.28`)
> - Set `token_level_loss: true` to enable Token-Level Policy Gradient Loss
>
> See the full [DAPO example config](../../examples/configs/recipes/llm/dapo-qwen2.5-7b.yaml) for reference.

## Example Training Results

Using the [DAPO example config](../../examples/configs/recipes/llm/dapo-qwen2.5-7b.yaml), you can expect training curves similar to the training reward and AIME24 validation accuracy shown below for Qwen/Qwen2.5-Math-7B. These plots serve as reference outputs to help verify reproducibility; they are not intended to reflect the best accuracy achievable with DAPO for this model.

![DAPO Qwen2.5-7B Training Reward](../assets/dapo_train_reward.png)
![DAPO Qwen2.5-7B Validation Accuracy](../assets/dapo_val_acc.png)

## References

- **DAPO Paper**: [Decoupled Clip and Dynamic Sampling Policy Optimization](https://arxiv.org/pdf/2503.14476)
- **GRPO Paper**: [Group Relative Policy Optimization](https://arxiv.org/abs/2402.03300)
- **[NeMo RL GRPO Guide](grpo.md)**

docs/index.md

Lines changed: 1 addition & 0 deletions
@@ -26,6 +26,7 @@ guides/sft-openmathinstruct2.md
adding-new-models.md
guides/sft.md
guides/dpo.md
guides/dapo.md
guides/grpo.md
guides/grpo-deepscaler.md
guides/grpo-sliding-puzzle.md

examples/configs/grpo_math_1B.yaml

Lines changed: 21 additions & 0 deletions
@@ -13,6 +13,21 @@ grpo:
  max_val_samples: 256
  val_batch_size: 256
  seed: 42
  use_dynamic_sampling: false
  dynamic_sampling_max_gen_batches: 10
  batch_multiplier: 1
  reward_shaping:
    enabled: false
    overlong_buffer_length: 128
    overlong_buffer_penalty: 1
    max_response_length: ${policy.max_total_sequence_length}
  reward_scaling:
    enabled: false
    source_min: 0.0
    source_max: 1.0
    target_min: 0.0
    target_max: 1.0

async_grpo:
  enabled: false # Set to true to enable async training mode
  # Max age (in training steps) for trajectories used in training
@@ -47,6 +62,7 @@ policy:
  tokenizer:
    name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default
    chat_template_kwargs: null # can be used to pass kwargs to the chat template, e.g., enable_thinking=true
  hf_config_overrides: null
  train_global_batch_size: 512
  train_micro_batch_size: 4
  generation_batch_size: 32 # Only used when generating using HF backend
@@ -237,6 +253,11 @@ data:
env:
  math:
    num_workers: 8
    math_verify_impl: "hf_math_verify"
  ## unused in this config but needed for DAPO recipe
  dapo:
    num_workers: 8
    math_verify_impl: "dapo_math_verify"

logger:
  log_dir: "logs" # Base directory for all logs
examples/configs/recipes/llm/dapo-qwen2.5-7b.yaml

Lines changed: 104 additions & 0 deletions
@@ -0,0 +1,104 @@
defaults: ../../grpo_math_1B.yaml
grpo:
  num_prompts_per_step: 512
  num_generations_per_prompt: 16
  batch_multiplier: 3 # Multiplier for dataloader batch size calculation (batch_multiplier × num_prompts_per_step). Following DAPO dynamic sampling, the actual training batch size equals num_prompts_per_step × num_generations_per_prompt.
  max_rollout_turns: 1 # for multi-turn rollouts. Math Environments just have 1 turn (answering the question)
  max_num_steps: 10000
  use_leave_one_out_baseline: false
  val_period: 20
  max_val_samples: 960
  val_batch_size: 960
  use_dynamic_sampling: true
  dynamic_sampling_max_gen_batches: 10
  reward_scaling:
    enabled: true
    source_min: 0.0
    source_max: 1.0
    target_min: -1.0
    target_max: 1.0
  reward_shaping:
    enabled: true
    overlong_buffer_length: 2048
    max_response_length: 14336
loss_fn:
  reference_policy_kl_penalty: 0.0
  ratio_clip_max: 0.28
  ratio_clip_c: 10
checkpointing:
  checkpoint_dir: results/dapo-qwen2.5-7b
  keep_top_k: 5
  save_period: 5
  model_save_format: "dcp"
policy:
  model_name: Qwen/Qwen2.5-Math-7B
  hf_config_overrides:
    max_position_embeddings: 16384
  train_micro_batch_size: 1
  logprob_batch_size: 1
  max_total_sequence_length: 16384
  dtensor_cfg:
    _v2: false
    context_parallel_size: 4
  megatron_cfg:
    empty_unused_memory_level: 1
    tensor_model_parallel_size: 4
    pipeline_model_parallel_size: 2
    context_parallel_size: 2
    sequence_parallel: true
    optimizer:
      lr: 1.0e-06
      min_lr: 1.0e-06
      weight_decay: 0.1
    scheduler:
      lr_decay_iters: null
      lr_warmup_iters: 10
      lr_warmup_init: 1.0e-07
  sequence_packing:
    enabled: false
  make_sequence_length_divisible_by: ${mul:${policy.dtensor_cfg.tensor_parallel_size}, ${mul:2, ${policy.dtensor_cfg.context_parallel_size}}}
  optimizer:
    kwargs:
      lr: 1.0e-06
      weight_decay: 0.1
  scheduler:
    - name: torch.optim.lr_scheduler.LinearLR
      kwargs:
        start_factor: 1.0e-08
        end_factor: 1.0
        total_iters: 10
    - name: torch.optim.lr_scheduler.ConstantLR
      kwargs:
        factor: 1.0
        total_iters: 10000000000
    - milestones:
        - 10
  generation:
    max_new_tokens: 16384
    vllm_cfg:
      tensor_parallel_size: 2
      gpu_memory_utilization: 0.7
      enforce_eager: true
data:
  max_input_seq_length: 2048
  prompt_file: null
  dataset_name: DAPOMath17K
env:
  dapo:
    num_workers: 16
  math:
    num_workers: 16
    math_verify_impl: "dapo_math_verify"

logger:
  monitor_gpus: false
  wandb:
    project: dapo-dev
    name: dapo-dev-logger
  mlflow:
    experiment_name: dapo-dev
    run_name: dapo-dev-logger
cluster:
  gpus_per_node: 8
  num_nodes: 16

examples/configs/vlm_grpo_3B.yaml

Lines changed: 13 additions & 0 deletions
@@ -14,6 +14,19 @@ grpo:
  max_val_samples: 256
  val_batch_size: 256
  seed: 42
  use_dynamic_sampling: false
  batch_multiplier: 1
  reward_shaping:
    enabled: false
    overlong_buffer_length: 512
    overlong_buffer_penalty: 1
    max_response_length: ${policy.max_total_sequence_length}
  reward_scaling:
    enabled: false
    source_min: 0.0
    source_max: 1.0
    target_min: 0.0
    target_max: 1.0
async_grpo:
  enabled: false
  max_trajectory_age_steps: 1

examples/configs/vlm_grpo_3B_megatron.yaml

Lines changed: 13 additions & 0 deletions
@@ -12,6 +12,19 @@ grpo:
  max_val_samples: 256
  val_batch_size: 256
  seed: 42
  use_dynamic_sampling: false
  batch_multiplier: 1
  reward_shaping:
    enabled: false
    overlong_buffer_length: 512
    overlong_buffer_penalty: 1
    max_response_length: ${policy.max_total_sequence_length}
  reward_scaling:
    enabled: false
    source_min: 0.0
    source_max: 1.0
    target_min: 0.0
    target_max: 1.0
async_grpo:
  enabled: false
  max_trajectory_age_steps: 1

examples/converters/convert_dcp_to_hf.py

Lines changed: 2 additions & 0 deletions
@@ -57,12 +57,14 @@ def main():
    # This is more stable than relying on the current NeMo-RL get_tokenizer() which can
    # change release to release.
    tokenizer_name_or_path = config["policy"]["model_name"]
    hf_overrides = config["policy"].get("hf_overrides", {}) or {}

    hf_ckpt = convert_dcp_to_hf(
        dcp_ckpt_path=args.dcp_ckpt_path,
        hf_ckpt_path=args.hf_ckpt_path,
        model_name_or_path=model_name_or_path,
        tokenizer_name_or_path=tokenizer_name_or_path,
        hf_overrides=hf_overrides,
    )
    print(f"Saved HF checkpoint to: {hf_ckpt}")