Commit a7c142f

Add dapo reward (#114)
1 parent b5ce7fa commit a7c142f

9 files changed: +223 −7 lines


examples/dapo_math/README.md

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@

# DAPO on DAPO-MATH-17k dataset [WIP]

This example demonstrates DAPO on the [DAPO-MATH-17k](https://huggingface.co/datasets/open-r1/DAPO-Math-17k-Processed) dataset.

The config files are [`dapo.yaml`](dapo.yaml) and [`train_dapo.yaml`](train_dapo.yaml).

examples/dapo_math/dapo.yaml

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@

project: Trinity-RFT-example
name: dapo
checkpoint_root_dir: /PATH/TO/CHECKPOINT/
model:
  model_path: /PATH/TO/MODEL/
algorithm:
  algorithm_type: grpo
  repeat_times: 16
  policy_loss_fn_args:
    clip_range_low: 0.2
    clip_range_high: 0.28
cluster:
  node_num: 1
  gpu_per_node: 8
buffer:
  total_epochs: 1
  batch_size: 32
  max_retry_times: 3
  max_retry_interval: 1
  explorer_input:
    taskset:
      name: dapo-math
      storage_type: file
      path: open-r1/DAPO-Math-17k-Processed
      subset_name: all
      format:
        prompt_key: 'prompt'
        response_key: 'solution'
        system_prompt: 'Solve the following math problem step by step. The last line of your response should be of the form Answer: $Answer (without quotes) where $Answer is the answer to the problem.'
      rollout_args:
        temperature: 1.0
        logprobs: 0
      workflow_args:
        use_base: true
      reward_fn_args:
        enable_overlong_penalty: true
        penalty_factor: 1.0
        max_response_length: 20480
        cache_length: 4096
    eval_tasksets:
      - name: AIME2024
        storage_type: file
        path: /PATH/TO/AIME2024/
        split: 'test'
        format:
          prompt_key: 'question'
          response_key: 'answer'
          system_prompt: 'Solve the following math problem step by step. The last line of your response should be of the form Answer: $Answer (without quotes) where $Answer is the answer to the problem.'
        rollout_args:
          n: 32
          temperature: 1.0
          top_p: 0.7
    default_workflow_type: 'math_boxed_workflow'
    default_reward_fn_type: 'math_dapo_reward'
  trainer_input:
    experience_buffer:
      name: math_buffer
      storage_type: queue
explorer:
  eval_interval: 10
  runner_num: 32
  rollout_model:
    engine_type: vllm_async
    engine_num: 4
    tensor_parallel_size: 1
    enable_prefix_caching: false
    enforce_eager: true
    dtype: bfloat16
    max_prompt_tokens: 1024
    max_response_tokens: 20480
    seed: 42
synchronizer:
  sync_method: 'nccl'
  sync_interval: 16
  sync_timeout: 1200
trainer:
  trainer_type: 'verl'
  trainer_config_path: 'examples/dapo_math/train_dapo.yaml'
  save_interval: 100
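
The `policy_loss_fn_args` above set an asymmetric clipping range, DAPO's "Clip-Higher" trick. As a rough illustration only (not Trinity-RFT's actual loss code; the function name is made up), these arguments would enter a PPO-style token loss roughly like this:

```python
import torch

def clip_higher_policy_loss(
    logprob: torch.Tensor,       # per-token log-probs under the current policy
    old_logprob: torch.Tensor,   # per-token log-probs recorded at rollout time
    advantages: torch.Tensor,    # per-token (group-relative) advantages
    clip_range_low: float = 0.2,
    clip_range_high: float = 0.28,
) -> torch.Tensor:
    ratio = torch.exp(logprob - old_logprob)
    # Asymmetric clipping: the upper bound (1 + 0.28) is looser than the lower
    # bound (1 - 0.2), leaving more room to up-weight low-probability tokens.
    clipped = torch.clamp(ratio, 1.0 - clip_range_low, 1.0 + clip_range_high)
    return -torch.min(ratio * advantages, clipped * advantages).mean()
```

Raising only the upper clip bound is what the DAPO paper argues preserves exploration on positive-advantage tokens while keeping the usual pessimistic bound elsewhere.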

examples/dapo_math/train_dapo.yaml

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@

actor_rollout_ref:
  hybrid_engine: True
  model:
    external_lib: null
    override_config: { }
    enable_gradient_checkpointing: True
    use_remove_padding: True # False
  actor:
    strategy: fsdp # This is for backward-compatibility
    ppo_micro_batch_size_per_gpu: 4
    use_dynamic_bsz: True # False
    ppo_max_token_len_per_gpu: 22000 # n * ${data.max_prompt_length} + ${data.max_response_length}
    grad_clip: 1.0
    ppo_epochs: 1
    shuffle: False
    ulysses_sequence_parallel_size: 1 # sp size
    optim:
      lr: 1e-6
      lr_warmup_steps: 20 # the total steps will be injected during runtime
      # min_lr_ratio: null # only useful for warmup with cosine
      warmup_style: constant # select from constant/cosine
      total_training_steps: -1 # must be override by program
    fsdp_config:
      wrap_policy:
        # transformer_layer_cls_to_wrap: None
        min_num_params: 0
      param_offload: False
      optimizer_offload: False
      fsdp_size: -1
  ref:
    fsdp_config:
      param_offload: False
      wrap_policy:
        # transformer_layer_cls_to_wrap: None
        min_num_params: 0
    log_prob_micro_batch_size_per_gpu: 16
    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
    ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size

trainer:
  balance_batch: True
  # auto: find the last ckpt to resume. If can't find, start from scratch
  resume_mode: auto # or auto or resume_path if
  default_hdfs_dir: null
  remove_previous_ckpt_in_save: False
  del_local_ckpt_after_load: False
  val_before_train: False
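
The comment on `ppo_max_token_len_per_gpu` gives the sizing rule. A quick sanity check with the values from `dapo.yaml` above, assuming one prompt plus one full-length response per packed sequence (n = 1):

```python
# Values taken from examples/dapo_math/dapo.yaml and train_dapo.yaml above.
max_prompt_tokens = 1024
max_response_tokens = 20480
ppo_max_token_len_per_gpu = 22000

# 1 * 1024 + 20480 = 21504, which fits under the 22000-token dynamic-batching budget.
assert 1 * max_prompt_tokens + max_response_tokens <= ppo_max_token_len_per_gpu
```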

tests/explorer/workflow_test.py

Lines changed: 4 additions & 0 deletions
@@ -5,6 +5,8 @@
 from typing import Dict, Optional
 from unittest.mock import MagicMock
 
+from torch import Tensor
+
 from tests.tools import get_unittest_dataset_config
 from trinity.common.rewards import RMGalleryFn
 from trinity.common.workflows import (
@@ -23,6 +25,8 @@ class MockResponse:
     metrics: Optional[Dict[str, float]] = None
     info: Optional[Dict] = None
     unique_id: Optional[str] = "0"
+    tokens: Optional[Tensor] = Tensor([0, 0])
+    prompt_length: int = 1
 
 
 class DummyWorkflow(Workflow):

trinity/common/rewards/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -6,6 +6,7 @@
 
 from .accuracy_reward import AccuracyReward
 from .countdown_reward import CountDownRewardFn
+from .dapo_reward import MathDAPORewardFn
 from .format_reward import FormatReward
 from .math_reward import MathBoxedRewardFn, MathRewardFn
 
@@ -20,4 +21,5 @@
     "FormatReward",
     "MathRewardFn",
     "MathBoxedRewardFn",
+    "MathDAPORewardFn",
 ]

trinity/common/rewards/dapo_reward.py

Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@

# -*- coding: utf-8 -*-
"""Reward Function with Overlong Reward Shaping described in DAPO (https://arxiv.org/pdf/2503.14476)"""
from typing import Optional

import torch

from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS, RewardFn
from trinity.utils.eval_utils import compute_score
from trinity.utils.log import get_logger

logger = get_logger(__name__)


@REWARD_FUNCTIONS.register_module("math_dapo_reward")
class MathDAPORewardFn(RewardFn):
    """A reward function that follows the definition in DAPO for math tasks."""

    def __init__(
        self,
        enable_overlong_penalty: Optional[bool] = None,
        penalty_factor: Optional[float] = None,
        max_response_length: Optional[int] = None,
        cache_length: Optional[int] = None,
    ) -> None:
        self.enable_overlong_penalty = enable_overlong_penalty
        self.penalty_factor = penalty_factor
        self.max_response_length = max_response_length
        self.cache_length = cache_length

    def __call__(  # type: ignore
        self,
        response: str,
        response_token: torch.Tensor,
        truth: Optional[str] = None,
        **kwargs,
    ) -> dict[str, float]:
        accuracy_score = compute_score(response, truth)

        format_score = 0.0

        if self.enable_overlong_penalty:
            format_score = self.compute_overlong_penalty(response_token)

        return {
            "accuracy": accuracy_score,
            "format_score": format_score,
        }

    def compute_overlong_penalty(self, response_token):
        assert (
            self.max_response_length is not None
            and self.cache_length is not None
            and self.penalty_factor is not None
        ), "When enable_overlong_penalty is true, max_response_length, penalty_factor, and cache_length must be set"
        assert (
            self.max_response_length > self.cache_length
        ), "max_response_length must be greater than cache_length"

        response_len = len(response_token)
        expected_len = self.max_response_length - self.cache_length

        if response_len < expected_len:
            return 0.0
        elif response_len > self.max_response_length:
            return -self.penalty_factor
        else:
            return (expected_len - response_len) / self.cache_length * self.penalty_factor
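
A worked example of the overlong penalty above, using the values configured in `dapo.yaml` (`max_response_length: 20480`, `cache_length: 4096`, `penalty_factor: 1.0`). The standalone function below simply mirrors `compute_overlong_penalty`; the response lengths are illustrative:

```python
def overlong_penalty(response_len: int,
                     max_response_length: int = 20480,
                     cache_length: int = 4096,
                     penalty_factor: float = 1.0) -> float:
    expected_len = max_response_length - cache_length  # 20480 - 4096 = 16384
    if response_len < expected_len:
        return 0.0                 # within the soft budget: no penalty
    if response_len > max_response_length:
        return -penalty_factor     # past the hard limit: full penalty
    # linear ramp from 0 to -penalty_factor across the cache_length window
    return (expected_len - response_len) / cache_length * penalty_factor

print(overlong_penalty(10_000))  # 0.0
print(overlong_penalty(18_432))  # -0.5 (halfway through the penalty window)
print(overlong_penalty(21_000))  # -1.0
```

The penalty is added to the accuracy term via `sum(reward_dict.values())` in the workflow, so an overly long response loses up to one full reward point.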

trinity/common/rewards/math_reward.py

Lines changed: 2 additions & 1 deletion
@@ -49,16 +49,17 @@ class MathBoxedRewardFn(RewardFn):
 
     def __init__(
         self,
+        **kwargs,
     ) -> None:
         pass
 
     def __call__(  # type: ignore
         self,
         response: str,
-        prompt: Optional[str] = None,
         truth: Optional[str] = None,
         with_think: Optional[bool] = False,
         format_score_coef: Optional[float] = 0.1,
+        **kwargs,
     ) -> dict[str, float]:
         accuracy_score = compute_score(response, truth)
 

trinity/common/workflows/customized_math_workflows.py

Lines changed: 15 additions & 6 deletions
@@ -31,6 +31,7 @@ def reset(self, task: Task):
         self.is_eval = task.is_eval
 
         self.workflow_args = task.workflow_args
+        self.reward_fn_args = task.reward_fn_args
 
         self.use_base = self.workflow_args.get("use_base", False)
         self.with_think = self.workflow_args.get("with_think", False)
@@ -48,7 +49,10 @@ def reset(self, task: Task):
         else:
             self.system_prompt = default_prompt
 
-        self.reward_fn = MathBoxedRewardFn()
+        if task.reward_fn is None:
+            self.reward_fn = MathBoxedRewardFn(**self.reward_fn_args)
+        else:
+            self.reward_fn = task.reward_fn(**self.reward_fn_args)
 
     def format_prompt(self):
         prompt_text = ""
@@ -60,7 +64,6 @@ def format_prompt(self):
         return prompt_text
 
     def run(self) -> List[Experience]:
-        # TODO: Optimize the generate function
         if not self.use_base:
             messages = self.format_messages()
         else:
@@ -73,11 +76,12 @@ def run(self) -> List[Experience]:
             responses = self.model.generate([prompt_text], **self.rollout_args)
 
         for response in responses:
-            reward_dict = MathBoxedRewardFn()(  # type: ignore [misc]
+            reward_dict = self.reward_fn(  # type: ignore [misc]
                 response=response.response_text,  # type: ignore [arg-type]
                 truth=self.truth,
                 with_think=self.with_think,
                 format_score_coef=self.format_score_coef,
+                response_token=response.tokens[response.prompt_length :],
             )
 
             if response.metrics is None:
@@ -86,7 +90,12 @@ def run(self) -> List[Experience]:
             reward = sum(reward_dict.values())
             response.reward = reward
 
-            logger.debug(
-                f"self.task_desc: {self.task_desc}, messages: {messages}, response: {response.response_text}, reward: {reward}"
-            )
+            if not self.use_base:
+                logger.debug(
+                    f"self.task_desc: {self.task_desc}, messages: {messages}, response: {response.response_text}, reward: {reward}"
+                )
+            else:
+                logger.debug(
+                    f"self.task_desc: {self.task_desc}, prompt_text: {prompt_text}, response: {response.response_text}, reward: {reward}"
+                )
         return responses
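
The new `response_token=response.tokens[response.prompt_length :]` argument assumes the rollout response exposes a single token tensor with prompt and response tokens concatenated, which is also why the test's `MockResponse` gains `tokens` and `prompt_length`. A tiny illustration of the slicing, with made-up values:

```python
import torch

tokens = torch.tensor([11, 12, 13, 21, 22, 23, 24])  # 3 prompt tokens + 4 response tokens
prompt_length = 3
response_token = tokens[prompt_length:]  # tensor([21, 22, 23, 24])
assert len(response_token) == 4  # only response tokens count toward the overlong penalty
```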

trinity/common/workflows/workflow.py

Lines changed: 1 addition & 0 deletions
@@ -197,6 +197,7 @@ def reset(self, task: Task):
         self.is_eval = task.is_eval
 
     def format_messages(self):
+        """Format messages for the instruct model."""
         messages = []
         if self.system_prompt:
             messages.append({"role": "system", "content": self.system_prompt})
