feat: Support DAPO dynamic sampling and reward shaping #602
Merged
116 commits
751d995
feat: Support DAPO dynamic sampling and reward shaping
peri044 e5dd193
Merge branch 'main' into dapo
peri044 d3ec46f
-s
peri044 7b9226b
feat: Add DAPO implementation
peri044 d6cd81d
chore: revert DAPO changes in GRPO files
peri044 108e1e0
chore: move DAPO impl back to GRPO
peri044 f249641
chore: remove dapo scripts
peri044 17b519c
Merge pull request #1 from peri044/dapo_merge
peri044 0d7f9bf
fix: Fix config bug
peri044 15fb6dc
chore: add timer for dynamic sampling and log filtered and total rewards
peri044 edda8d4
chore: add dapo-math dataset, address review comments
peri044 53c004f
chore: add DAPO-MATH-17k config
peri044 65df6b6
chore: update dapo-math config
peri044 25c9d8b
chore: update config
peri044 e4431ac
chore: remove the dapo deepscalar config
peri044 d99e460
chore: add reward scaling
peri044 77d08ff
chore: change start_factor to 0.1
peri044 a33fad2
chore: update dapo config
peri044 f55c9f7
chore: update gpus_per_node
peri044 57b0544
chore: rebase with main
peri044 cff451f
chore: fix prompt_file issue and disable leave_one_out baseline
peri044 f2a9d67
chore: address review comments
peri044 4f5ce08
chore: delete approach 1 of DS
peri044 dddfb03
chore: use verl's math verifier
peri044 4000965
chore: update math env to use dapo_math_verifier
peri044 f5de167
chore: disable reward scaling when dapo_math_verifier is used
peri044 ac228d2
chore: update reward calculation
peri044 14204e4
Merge branch 'main' into dapo
peri044 6981357
chore: additional updates
peri044 cd1cbdf
chore: minor updates and add megatron config
peri044 3717c9d
chore: address review comments
peri044 88716fd
chore: remove dapo-math.yaml
peri044 d4fb90f
chore: address review comments and add reward shaping fix
peri044 b59d705
chore: rebase with main
peri044 fabe873
chore: restore to original state
peri044 8a4e12d
chore: update docstring
peri044 4b1f9a5
chore: rebase with main and resolve conflicts
peri044 2b885b8
chore: fix filtered reward metric
peri044 afec799
add unit tests
ashors1 ae6b96a
add recipe and short convergence test
ashors1 21cce62
clean up config
ashors1 c8bdae8
chore: minor updates to logger data
peri044 54ec885
chore: rebase
peri044 6d0c521
chore: updates
peri044 8abeabf
chore: update doc
peri044 06544f3
chore: update docs with intermediate results
peri044 32c636c
chore: address review comments
peri044 a73c22f
address comment, remove some redundant tests
ashors1 5d6a46f
Merge pull request #2 from peri044/ashors/dapo-tests
peri044 ed06fe0
fix configs
ashors1 72ebf0b
fix test name
ashors1 b4ee6f7
add dapo to mapping
ashors1 5e6a251
chore: Fix a test case
peri044 1d156db
Merge branch 'dapo' of https://github.com/peri044/RL into dapo
peri044 703932b
Merge branch 'main' into dapo
peri044 6b8c945
chore: use kwargs in math verifier to be consistent
peri044 564ef19
Merge branch 'dapo' of https://github.com/peri044/RL into dapo
peri044 2efd601
Merge pull request #3 from peri044/ashors/dapo-fix-unit-tests
peri044 0a09fac
Merge branch 'main' into dapo
peri044 7832e93
chore: update license for dapo math verifier
peri044 591e3e8
Merge branch 'dapo' of https://github.com/peri044/RL into dapo
peri044 22e78e7
dapo convergence test fixes
ashors1 fdc0b8c
Merge branch 'main' into dapo
peri044 1676108
Merge branch 'main' into dapo
peri044 65f33bf
make hf_overrides configurable
ashors1 a43177e
propagate hf overrides to vllm
ashors1 1008fda
default to None
ashors1 5256c81
minor fix
ashors1 1d8de6c
add copyright
ashors1 3967835
lint
ashors1 ef78bfb
Merge pull request #4 from peri044/ashors/dapo-fix-conv-test
ashors1 2668f4b
Merge branch 'main' into dapo
peri044 3371bbd
chore: fix linter check
peri044 d529897
Merge branch 'dapo' of https://github.com/peri044/RL into dapo
peri044 b7bc4d7
fix hf_overrides default
ashors1 44a3e3c
fix hf_overrides during model conversion
ashors1 1ab98c1
document configs, add missing defaults
ashors1 fe7c46e
address some of Terry's comments
ashors1 234a5fb
chore: address some review comments
peri044 329fe1c
Merge pull request #5 from peri044/ashors/dapo-comments
peri044 939189b
chore: rebase
peri044 8fe6c20
chore: fix conflicts in README.md
peri044 298aad3
chore: add dapo guide link
peri044 f39a3df
chore: fix incorrect algo name
peri044 a5f3eea
chore: address review comments
peri044 36b72e0
chore: update logging of % of prompts that are being discarded
peri044 5399de8
chore: update non_zero_std_fraction metric in wandb
peri044 3effac0
chore: update doc
peri044 bf3a635
chore: update doc
peri044 521949b
chore: fix eqn format
peri044 d5063b3
chore: update scale_rewards function
peri044 c1ffc71
update logging, update docs, rename dapo_batch_multiplier, fix some f…
ashors1 cd48dc9
fix checkpointing config
ashors1 e920e09
chore: update docstring and add test cases for reward scaling
peri044 df553f1
Merge branch 'dapo' of https://github.com/peri044/RL into dapo
peri044 9d769e9
Merge branch 'main' into dapo
peri044 8d34524
chore: directly access reward_scaling_cfg
peri044 21e5493
Merge branch 'main' into dapo
peri044 6107c62
Merge branch 'main' into dapo
peri044 00fae2e
chore: rebase
peri044 49f4ec4
tighter bounds for test, minor doc improvement
ashors1 2573f9d
Update docs/guides/dapo.md
peri044 84dc622
Update docs/guides/dapo.md
peri044 b2c29ec
Update docs/guides/dapo.md
peri044 d1eb2c4
chore: address review comments and rename max_num_gen_batches and num…
peri044 e1a3ed3
chore: Add dapo guide to index.md
peri044 1229358
fix: Fix config issue
peri044 02e0546
Merge branch 'main' into dapo
peri044 a3f0c3a
Merge branch 'main' into dapo
peri044 f297185
add missing key
ashors1 c9dcdb0
fix unit test
ashors1 302dd31
fix config issue
ashors1 f25b48b
fix distillation unit test
ashors1 769f11f
fix grpo unit test
ashors1 9bb2a9c
fix default hf_config_overrides
ashors1 02a8e1e
fix remaining hf_overrides default
ashors1
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,79 @@ | ||
| # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| from typing import TypedDict, TypeVar | ||
|
|
||
| import torch | ||
|
|
||
| from nemo_rl.distributed.batched_data_dict import BatchedDataDict | ||
|
|
||
| Tensor = TypeVar("Tensor", bound=torch.Tensor) | ||
|
|
||
|
|
||
| class RewardConfig(TypedDict): | ||
| """Configuration for reward function processing. | ||
|
|
||
| This configuration enables custom reward shaping, currently supporting DAPO-style | ||
| penalties for responses that exceed the maximum response length threshold. | ||
| """ | ||
|
|
||
| enabled: bool | ||
| overlong_buffer_length: int | ||
| overlong_buffer_penalty: float | ||
| max_response_length: int | ||
|
|
||
|
|
||
| def process_rewards( | ||
|
||
| batch: BatchedDataDict, rewards: torch.Tensor, cfg: RewardConfig | ||
| ) -> torch.Tensor: | ||
| """Process rewards by penalizing responses whose length exceeds max_response_length - overlong_buffer_length. | ||
|
|
||
| Currently, only DAPO reward shaping is supported, as described in the DAPO paper: https://arxiv.org/pdf/2503.14476. | ||
| The function could be extended to support other custom reward logic. | ||
| """ | ||
| if not cfg["enabled"]: | ||
| return rewards | ||
|
|
||
| # DAPO reward shaping requires overlong_buffer_length, overlong_buffer_penalty, and max_response_length to be set. | ||
| if ( | ||
| cfg["overlong_buffer_length"] is None | ||
| or cfg["overlong_buffer_penalty"] is None | ||
| or cfg["max_response_length"] is None | ||
| ): | ||
| raise ValueError( | ||
| "Reward function is enabled but only DAPO reward shaping is currently supported. Please ensure overlong_buffer_length, overlong_buffer_penalty, and max_response_length are properly configured." | ||
| ) | ||
|
|
||
| # Get the overlong_buffer_length, overlong_buffer_penalty and max_response_length | ||
| overlong_buffer_length = cfg["overlong_buffer_length"] | ||
| overlong_buffer_penalty = cfg["overlong_buffer_penalty"] | ||
|
||
| max_response_length = cfg["max_response_length"] | ||
|
|
||
| # Calculate the expected response length | ||
| expected_response_length = max_response_length - overlong_buffer_length | ||
|
|
||
|
||
| assert len(batch["message_log"]) == len(rewards), ( | ||
| "The number of messages in the batch must match the number of rewards" | ||
| ) | ||
|
|
||
| updated_rewards = torch.zeros_like(rewards) | ||
| for i, message_log in enumerate(batch["message_log"]): | ||
| # Get the assistant response length (index 1 is the assistant response) | ||
| message_response_length = message_log[1]["token_ids"].shape[0] | ||
|
||
| # Calculate the exceed length and the corresponding reward penalty | ||
| exceed_length = message_response_length - expected_response_length | ||
| overlong_reward = min( | ||
| -exceed_length / overlong_buffer_length * overlong_buffer_penalty, 0 | ||
| ) | ||
| updated_rewards[i] = rewards[i] + overlong_reward | ||
|
|
||
| return updated_rewards | ||
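The penalty arithmetic in `process_rewards` can be checked in isolation. The sketch below is not part of the PR: `overlong_penalty` and `shape_rewards` are hypothetical helper names, and plain Python lists stand in for the `BatchedDataDict` and `torch.Tensor` inputs. It assumes `overlong_buffer_length` is nonzero, as the original expression does.

```python
from typing import List


def overlong_penalty(
    response_length: int,
    max_response_length: int,
    overlong_buffer_length: int,
    overlong_buffer_penalty: float,
) -> float:
    """Return the (non-positive) length penalty for a single response.

    Mirrors the expression in process_rewards: responses up to
    max_response_length - overlong_buffer_length are not penalized, and the
    penalty then grows linearly with the number of excess tokens.
    """
    expected_response_length = max_response_length - overlong_buffer_length
    exceed_length = response_length - expected_response_length
    return min(-exceed_length / overlong_buffer_length * overlong_buffer_penalty, 0.0)


def shape_rewards(
    rewards: List[float],
    response_lengths: List[int],
    max_response_length: int,
    overlong_buffer_length: int,
    overlong_buffer_penalty: float,
) -> List[float]:
    """Add the length penalty to each reward (list-based stand-in for the tensor loop)."""
    assert len(rewards) == len(response_lengths), (
        "The number of responses must match the number of rewards"
    )
    return [
        reward
        + overlong_penalty(
            length,
            max_response_length,
            overlong_buffer_length,
            overlong_buffer_penalty,
        )
        for reward, length in zip(rewards, response_lengths)
    ]
```

With `max_response_length=100`, `overlong_buffer_length=20`, and `overlong_buffer_penalty=1.0`, the penalty starts after 80 tokens. A 50-token response keeps its reward, a 90-token response loses 0.5, and a 100-token response loses the full 1.0.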