-
Notifications
You must be signed in to change notification settings - Fork 48
Add dapo reward #114
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Add dapo reward #114
Changes from 11 commits
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
b8c262a
add dapo workflow and reward
hiyuchang b9076b7
fix bug
hiyuchang da57574
Merge branch 'main' into feat/len_punish
hiyuchang 889b36c
add base model mode for dapo_workflow
hiyuchang a28c310
Merge branch 'main' into feat/len_punish
hiyuchang 880c5b0
fix dapo config
hiyuchang fd39fe4
Merge branch 'main' into feat/len_punish
hiyuchang c314a11
fix typo
hiyuchang e1c7622
fix comments
hiyuchang 7621887
mv dapo workflow to boxed workflo
hiyuchang ca04b7a
fix unittest
hiyuchang 7d96213
fix typo
hiyuchang 3239c73
fix typo
hiyuchang File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,5 @@ | ||
| # DAPO on DAPO-MATH-17k dataset [WIP] | ||
|
|
||
| This example shows the usage of DAPO on the [DAPO-MATH-17k](https://huggingface.co/datasets/open-r1/DAPO-Math-17k-Processed) dataset. | ||
|
|
||
| The config files are located in [`dapo.yaml`](dapo.yaml) and [`train_dapo.yaml`](train_dapo.yaml). |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,79 @@ | ||
| project: Trinity-RFT-example | ||
| name: dapo | ||
| checkpoint_root_dir: /PATH/TO/CHECKPOINT/ | ||
| model: | ||
| model_path: /PATH/TO/MODEL/ | ||
| algorithm: | ||
| algorithm_type: grpo | ||
| repeat_times: 16 | ||
| policy_loss_fn_args: | ||
| clip_range_low: 0.2 | ||
| clip_range_high: 0.28 | ||
| cluster: | ||
| node_num: 1 | ||
| gpu_per_node: 8 | ||
| buffer: | ||
| total_epochs: 1 | ||
| batch_size: 32 | ||
| max_retry_times: 3 | ||
| max_retry_interval: 1 | ||
| explorer_input: | ||
| taskset: | ||
| name: dapo-math | ||
| storage_type: file | ||
| path: open-r1/DAPO-Math-17k-Processed | ||
| subset_name: all | ||
| format: | ||
| prompt_key: 'prompt' | ||
| response_key: 'solution' | ||
| system_prompt: 'Solve the following math problem step by step. The last line of your response should be of the form Answer: $Answer (without quotes) where $Answer is the answer to the problem.' | ||
| rollout_args: | ||
| temperature: 1.0 | ||
| logprobs: 0 | ||
| workflow_args: | ||
| use_base: true | ||
| reward_fn_args: | ||
| enable_overlong_penalty: true | ||
| penalty_factor: 1.0 | ||
| max_response_length: 20480 | ||
| cache_length: 4096 | ||
| eval_tasksets: | ||
| - name: AIME2024 | ||
| storage_type: file | ||
| path: /PATH/TO/AIME2024/ | ||
| split: 'test' | ||
| format: | ||
| prompt_key: 'question' | ||
| response_key: 'answer' | ||
| system_prompt: 'Solve the following math problem step by step. The last line of your response should be of the form Answer: $Answer (without quotes) where $Answer is the answer to the problem.' | ||
| rollout_args: | ||
| n: 32 | ||
| temperature: 1.0 | ||
| top_p: 0.7 | ||
| default_workflow_type: 'math_boxed_workflow' | ||
| default_reward_fn_type: 'math_dapo_reward' | ||
| trainer_input: | ||
| experience_buffer: | ||
| name: math_buffer | ||
| storage_type: queue | ||
| explorer: | ||
| eval_interval: 10 | ||
| runner_num: 32 | ||
| rollout_model: | ||
| engine_type: vllm_async | ||
| engine_num: 4 | ||
| tensor_parallel_size: 1 | ||
| enable_prefix_caching: false | ||
| enforce_eager: true | ||
| dtype: bfloat16 | ||
| max_prompt_tokens: 1024 | ||
| max_response_tokens: 20480 | ||
| seed: 42 | ||
| synchronizer: | ||
| sync_method: 'nccl' | ||
| sync_interval: 16 | ||
| sync_timeout: 1200 | ||
| trainer: | ||
| trainer_type: 'verl' | ||
| trainer_config_path: 'examples/dapo_math/train_dapo.yaml' | ||
| save_interval: 100 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,48 @@ | ||
| actor_rollout_ref: | ||
| hybrid_engine: True | ||
| model: | ||
| external_lib: null | ||
| override_config: { } | ||
| enable_gradient_checkpointing: True | ||
| use_remove_padding: True # False | ||
| actor: | ||
| strategy: fsdp # This is for backward-compatibility | ||
| ppo_micro_batch_size_per_gpu: 4 | ||
| use_dynamic_bsz: True # False | ||
| ppo_max_token_len_per_gpu: 22000 # n * ${data.max_prompt_length} + ${data.max_response_length} | ||
| grad_clip: 1.0 | ||
| ppo_epochs: 1 | ||
| shuffle: False | ||
| ulysses_sequence_parallel_size: 1 # sp size | ||
| optim: | ||
| lr: 1e-6 | ||
| lr_warmup_steps: 20 # the total steps will be injected during runtime | ||
| # min_lr_ratio: null # only useful for warmup with cosine | ||
| warmup_style: constant # select from constant/cosine | ||
| total_training_steps: -1 # must be override by program | ||
| fsdp_config: | ||
| wrap_policy: | ||
| # transformer_layer_cls_to_wrap: None | ||
| min_num_params: 0 | ||
| param_offload: False | ||
| optimizer_offload: False | ||
| fsdp_size: -1 | ||
| ref: | ||
| fsdp_config: | ||
| param_offload: False | ||
| wrap_policy: | ||
| # transformer_layer_cls_to_wrap: None | ||
| min_num_params: 0 | ||
| log_prob_micro_batch_size_per_gpu: 16 | ||
| log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} | ||
| log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} | ||
| ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size | ||
|
|
||
| trainer: | ||
| balance_batch: True | ||
| # auto: find the last ckpt to resume. If can't find, start from scratch | ||
| resume_mode: auto # or auto or resume_path if | ||
| default_hdfs_dir: null | ||
| remove_previous_ckpt_in_save: False | ||
| del_local_ckpt_after_load: False | ||
| val_before_train: False |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,67 @@ | ||
| # -*- coding: utf-8 -*- | ||
| """Reward Function with Overlong Reward Shaping described in DAPO (https://arxiv.org/pdf/2503.14476)""" | ||
| from typing import Optional, Union | ||
|
|
||
| import torch | ||
|
|
||
| from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS, RewardFn | ||
| from trinity.utils.eval_utils import compute_score | ||
| from trinity.utils.log import get_logger | ||
|
|
||
| logger = get_logger(__name__) | ||
|
|
||
|
|
||
| @REWARD_FUNCTIONS.register_module("math_dapo_reward") | ||
| class MathDAPORewardFn(RewardFn): | ||
| """A reward function that follows the definition in DAPO for math task.""" | ||
|
|
||
| def __init__( | ||
| self, | ||
| enable_overlong_penalty: Optional[bool] = None, | ||
| penalty_factor: Optional[float] = None, | ||
| max_response_length: Optional[int] = None, | ||
| cache_length: Optional[int] = None, | ||
| ) -> None: | ||
| self.enable_overlong_penalty = enable_overlong_penalty | ||
| self.penalty_factor = penalty_factor | ||
| self.max_response_length = max_response_length | ||
| self.cache_length = cache_length | ||
|
|
||
| def __call__( # type: ignore | ||
| self, | ||
| response: str, | ||
| response_token: torch.Tensor, | ||
| truth: Optional[str] = None, | ||
| **kwargs, | ||
| ) -> Union[float, dict]: | ||
hiyuchang marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| accuracy_score = compute_score(response, truth) | ||
|
|
||
| format_score = 0.0 | ||
|
|
||
| if self.enable_overlong_penalty: | ||
| format_score = self.compute_overlong_penalty(response_token) | ||
|
|
||
| return { | ||
| "accuracy": accuracy_score, | ||
| "format_score": format_score, | ||
| } | ||
|
|
||
| def compute_overlong_penalty(self, response_token): | ||
| assert ( | ||
| self.max_response_length is not None | ||
| and self.cache_length is not None | ||
| and self.penalty_factor is not None | ||
| ), "When enable_overlong_penalty = true, max_response_length, penalty_factor, cache_length must be set" | ||
| assert ( | ||
| self.max_response_length > self.cache_length | ||
| ), "max_response_length must be greater than cache_length" | ||
|
|
||
| response_len = len(response_token) | ||
| excepted_len = self.max_response_length - self.cache_length | ||
|
|
||
| if response_len < excepted_len: | ||
| return 0.0 | ||
| elif response_len > self.max_response_length: | ||
| return -self.penalty_factor | ||
| else: | ||
| return (excepted_len - response_len) / self.cache_length * self.penalty_factor | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.