5 changes: 5 additions & 0 deletions examples/dapo_math/README.md
@@ -0,0 +1,5 @@
# DAPO on the DAPO-Math-17k dataset [WIP]

This example demonstrates how to run DAPO on the [DAPO-Math-17k](https://huggingface.co/datasets/open-r1/DAPO-Math-17k-Processed) dataset.

The config files are located in [`dapo.yaml`](dapo.yaml) and [`train_dapo.yaml`](train_dapo.yaml).
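
After filling in the `/PATH/TO/...` placeholders in [`dapo.yaml`](dapo.yaml), the run can presumably be launched through the standard Trinity-RFT entry point:

```bash
trinity run --config examples/dapo_math/dapo.yaml
```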
79 changes: 79 additions & 0 deletions examples/dapo_math/dapo.yaml
@@ -0,0 +1,79 @@
project: Trinity-RFT-example
name: dapo
checkpoint_root_dir: /PATH/TO/CHECKPOINT/
model:
model_path: /PATH/TO/MODEL/
algorithm:
algorithm_type: grpo
repeat_times: 16
policy_loss_fn_args:
clip_range_low: 0.2
clip_range_high: 0.28
cluster:
node_num: 1
gpu_per_node: 8
buffer:
total_epochs: 1
batch_size: 32
max_retry_times: 3
max_retry_interval: 1
explorer_input:
taskset:
name: dapo-math
storage_type: file
path: open-r1/DAPO-Math-17k-Processed
subset_name: all
format:
prompt_key: 'prompt'
response_key: 'solution'
system_prompt: 'Solve the following math problem step by step. The last line of your response should be of the form Answer: $Answer (without quotes) where $Answer is the answer to the problem.'
rollout_args:
temperature: 1.0
logprobs: 0
workflow_args:
use_base: true
reward_fn_args:
enable_overlong_penalty: true
penalty_factor: 1.0
max_response_length: 20480
cache_length: 4096
eval_tasksets:
- name: AIME2024
storage_type: file
path: /PATH/TO/AIME2024/
split: 'test'
format:
prompt_key: 'question'
response_key: 'answer'
system_prompt: 'Solve the following math problem step by step. The last line of your response should be of the form Answer: $Answer (without quotes) where $Answer is the answer to the problem.'
rollout_args:
n: 32
temperature: 1.0
top_p: 0.7
default_workflow_type: 'math_boxed_workflow'
default_reward_fn_type: 'math_dapo_reward'
trainer_input:
experience_buffer:
name: math_buffer
storage_type: queue
explorer:
eval_interval: 10
runner_num: 32
rollout_model:
engine_type: vllm_async
engine_num: 4
tensor_parallel_size: 1
enable_prefix_caching: false
enforce_eager: true
dtype: bfloat16
max_prompt_tokens: 1024
max_response_tokens: 20480
seed: 42
synchronizer:
sync_method: 'nccl'
sync_interval: 16
sync_timeout: 1200
trainer:
trainer_type: 'verl'
trainer_config_path: 'examples/dapo_math/train_dapo.yaml'
save_interval: 100
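
With the `reward_fn_args` above, responses up to `max_response_length - cache_length = 16384` tokens incur no length penalty, the penalty then grows linearly across the 4096-token cache window, and anything past 20480 tokens receives the full `-penalty_factor`. A minimal sketch of that schedule, mirroring the soft overlong punishment from the DAPO paper as implemented in `trinity/common/rewards/dapo_reward.py` below:

```python
# Overlong-penalty schedule at the values configured above.
max_response_length, cache_length, penalty_factor = 20480, 4096, 1.0
expected_len = max_response_length - cache_length  # 16384

def overlong_penalty(response_len: int) -> float:
    if response_len < expected_len:
        return 0.0  # within budget: no penalty
    if response_len > max_response_length:
        return -penalty_factor  # fully overlong: maximum penalty
    # linear ramp inside the cache window
    return (expected_len - response_len) / cache_length * penalty_factor

print(overlong_penalty(10000))  # 0.0
print(overlong_penalty(18432))  # -0.5, halfway into the cache window
print(overlong_penalty(25000))  # -1.0
```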
48 changes: 48 additions & 0 deletions examples/dapo_math/train_dapo.yaml
@@ -0,0 +1,48 @@
actor_rollout_ref:
hybrid_engine: True
model:
external_lib: null
override_config: { }
enable_gradient_checkpointing: True
use_remove_padding: True # False
actor:
strategy: fsdp # This is for backward-compatibility
ppo_micro_batch_size_per_gpu: 4
use_dynamic_bsz: True # False
ppo_max_token_len_per_gpu: 22000 # n * ${data.max_prompt_length} + ${data.max_response_length}
grad_clip: 1.0
ppo_epochs: 1
shuffle: False
ulysses_sequence_parallel_size: 1 # sp size
optim:
lr: 1e-6
lr_warmup_steps: 20 # the total steps will be injected during runtime
# min_lr_ratio: null # only useful for warmup with cosine
warmup_style: constant # select from constant/cosine
      total_training_steps: -1 # must be overridden by the program
fsdp_config:
wrap_policy:
# transformer_layer_cls_to_wrap: None
min_num_params: 0
param_offload: False
optimizer_offload: False
fsdp_size: -1
ref:
fsdp_config:
param_offload: False
wrap_policy:
# transformer_layer_cls_to_wrap: None
min_num_params: 0
log_prob_micro_batch_size_per_gpu: 16
log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size

trainer:
balance_batch: True
  # auto: resume from the last checkpoint if one exists; otherwise start from scratch
  resume_mode: auto # select from auto/disable/resume_path
default_hdfs_dir: null
remove_previous_ckpt_in_save: False
del_local_ckpt_after_load: False
val_before_train: False
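
The `ppo_max_token_len_per_gpu` value is consistent with the rollout limits set in `dapo.yaml`; a quick sanity check of the formula in the comment above, with n = 1:

```python
# Values copied from dapo.yaml / train_dapo.yaml.
max_prompt_tokens = 1024
max_response_tokens = 20480
ppo_max_token_len_per_gpu = 22000

# n * max_prompt_tokens + max_response_tokens with n = 1
assert 1 * max_prompt_tokens + max_response_tokens <= ppo_max_token_len_per_gpu  # 21504 <= 22000
```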
4 changes: 4 additions & 0 deletions tests/explorer/workflow_test.py
@@ -5,6 +5,8 @@
from typing import Dict, Optional
from unittest.mock import MagicMock

from torch import Tensor

from tests.tools import get_unittest_dataset_config
from trinity.common.rewards import RMGalleryFn
from trinity.common.workflows import (
@@ -23,6 +25,8 @@ class MockResponse:
metrics: Optional[Dict[str, float]] = None
info: Optional[Dict] = None
unique_id: Optional[str] = "0"
tokens: Optional[Tensor] = Tensor([0, 0])
prompt_length: int = 1


class DummyWorkflow(Workflow):
2 changes: 2 additions & 0 deletions trinity/common/rewards/__init__.py
@@ -6,6 +6,7 @@

from .accuracy_reward import AccuracyReward
from .countdown_reward import CountDownRewardFn
from .dapo_reward import MathDAPORewardFn
from .format_reward import FormatReward
from .math_reward import MathBoxedRewardFn, MathRewardFn

@@ -20,4 +21,5 @@
"FormatReward",
"MathRewardFn",
"MathBoxedRewardFn",
"MathDAPORewardFn",
]
67 changes: 67 additions & 0 deletions trinity/common/rewards/dapo_reward.py
@@ -0,0 +1,67 @@
# -*- coding: utf-8 -*-
"""Reward Function with Overlong Reward Shaping described in DAPO (https://arxiv.org/pdf/2503.14476)"""
from typing import Optional

import torch

from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS, RewardFn
from trinity.utils.eval_utils import compute_score
from trinity.utils.log import get_logger

logger = get_logger(__name__)


@REWARD_FUNCTIONS.register_module("math_dapo_reward")
class MathDAPORewardFn(RewardFn):
"""A reward function that follows the definition in DAPO for math task."""

def __init__(
self,
enable_overlong_penalty: Optional[bool] = None,
penalty_factor: Optional[float] = None,
max_response_length: Optional[int] = None,
cache_length: Optional[int] = None,
) -> None:
self.enable_overlong_penalty = enable_overlong_penalty
self.penalty_factor = penalty_factor
self.max_response_length = max_response_length
self.cache_length = cache_length

def __call__( # type: ignore
self,
response: str,
response_token: torch.Tensor,
truth: Optional[str] = None,
**kwargs,
) -> dict[str, float]:
accuracy_score = compute_score(response, truth)

format_score = 0.0

if self.enable_overlong_penalty:
format_score = self.compute_overlong_penalty(response_token)

return {
"accuracy": accuracy_score,
"format_score": format_score,
}

def compute_overlong_penalty(self, response_token):
assert (
self.max_response_length is not None
and self.cache_length is not None
and self.penalty_factor is not None
), "When enable_overlong_penalty = true, max_response_length, penalty_factor, cache_length must be set"
assert (
self.max_response_length > self.cache_length
), "max_response_length must be greater than cache_length"

        response_len = len(response_token)
        expected_len = self.max_response_length - self.cache_length

        if response_len < expected_len:
            return 0.0
        elif response_len > self.max_response_length:
            return -self.penalty_factor
        else:
            return (expected_len - response_len) / self.cache_length * self.penalty_factor
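
A minimal usage sketch of the class above, with constructor arguments mirroring `reward_fn_args` from `dapo.yaml`; the response string and token tensor are hypothetical, and the accuracy term depends entirely on what `compute_score` returns for them:

```python
import torch

from trinity.common.rewards import MathDAPORewardFn

reward_fn = MathDAPORewardFn(
    enable_overlong_penalty=True,
    penalty_factor=1.0,
    max_response_length=20480,
    cache_length=4096,
)

# Only the length of response_token matters for the overlong penalty:
# expected_len = 20480 - 4096 = 16384, and 16384 < 18000 <= 20480.
scores = reward_fn(
    response="... Answer: 42",
    response_token=torch.zeros(18000, dtype=torch.long),
    truth="42",
)
print(scores["format_score"])  # (16384 - 18000) / 4096 * 1.0 ~= -0.395
```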
3 changes: 2 additions & 1 deletion trinity/common/rewards/math_reward.py
@@ -49,16 +49,17 @@ class MathBoxedRewardFn(RewardFn):

def __init__(
self,
**kwargs,
) -> None:
pass

def __call__( # type: ignore
self,
response: str,
prompt: Optional[str] = None,
truth: Optional[str] = None,
with_think: Optional[bool] = False,
format_score_coef: Optional[float] = 0.1,
**kwargs,
) -> dict[str, float]:
accuracy_score = compute_score(response, truth)

21 changes: 15 additions & 6 deletions trinity/common/workflows/customized_math_workflows.py
@@ -31,6 +31,7 @@ def reset(self, task: Task):
self.is_eval = task.is_eval

self.workflow_args = task.workflow_args
self.reward_fn_args = task.reward_fn_args

self.use_base = self.workflow_args.get("use_base", False)
self.with_think = self.workflow_args.get("with_think", False)
@@ -48,7 +49,10 @@ def reset(self, task: Task):
else:
self.system_prompt = default_prompt

self.reward_fn = MathBoxedRewardFn()
if task.reward_fn is None:
self.reward_fn = MathBoxedRewardFn(**self.reward_fn_args)
else:
self.reward_fn = task.reward_fn(**self.reward_fn_args)

def format_prompt(self):
prompt_text = ""
@@ -60,7 +64,6 @@ def format_prompt(self):
return prompt_text

def run(self) -> List[Experience]:
# TODO: Optimize the generate function
if not self.use_base:
messages = self.format_messages()
else:
@@ -73,11 +76,12 @@ def run(self) -> List[Experience]:
responses = self.model.generate([prompt_text], **self.rollout_args)

for response in responses:
reward_dict = MathBoxedRewardFn()( # type: ignore [misc]
reward_dict = self.reward_fn( # type: ignore [misc]
response=response.response_text, # type: ignore [arg-type]
truth=self.truth,
with_think=self.with_think,
format_score_coef=self.format_score_coef,
response_token=response.tokens[response.prompt_length :],
)

if response.metrics is None:
@@ -86,7 +90,12 @@
reward = sum(reward_dict.values())
response.reward = reward

logger.debug(
f"self.task_desc: {self.task_desc}, messages: {messages}, response: {response.response_text}, reward: {reward}"
)
if not self.use_base:
logger.debug(
f"self.task_desc: {self.task_desc}, messages: {messages}, response: {response.response_text}, reward: {reward}"
)
else:
logger.debug(
f"self.task_desc: {self.task_desc}, prompt_text: {prompt_text}, response: {response.response_text}, reward: {reward}"
)
return responses
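
The `response.tokens[response.prompt_length:]` slice passed to the reward function assumes `tokens` holds the concatenated prompt and response ids, as the `MockResponse` change in the tests suggests; a minimal sketch of that convention:

```python
import torch

# Hypothetical ids: two prompt tokens followed by three response tokens.
tokens = torch.tensor([101, 102, 7, 8, 9])
prompt_length = 2

response_token = tokens[prompt_length:]  # tensor([7, 8, 9]): response ids only
assert len(response_token) == len(tokens) - prompt_length
```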
1 change: 1 addition & 0 deletions trinity/common/workflows/workflow.py
@@ -197,6 +197,7 @@ def reset(self, task: Task):
self.is_eval = task.is_eval

def format_messages(self):
"""Format messages for the instruct model."""
messages = []
if self.system_prompt:
messages.append({"role": "system", "content": self.system_prompt})