diff --git a/examples/dapo_math/README.md b/examples/dapo_math/README.md
new file mode 100644
index 0000000000..28f8f3c625
--- /dev/null
+++ b/examples/dapo_math/README.md
@@ -0,0 +1,5 @@
+# DAPO on the DAPO-Math-17k dataset [WIP]
+
+This example demonstrates how to run DAPO on the [DAPO-Math-17k](https://huggingface.co/datasets/open-r1/DAPO-Math-17k-Processed) dataset.
+
+The configuration is split between [`dapo.yaml`](dapo.yaml) (the Trinity-RFT config) and [`train_dapo.yaml`](train_dapo.yaml) (the verl trainer config).
diff --git a/examples/dapo_math/dapo.yaml b/examples/dapo_math/dapo.yaml
new file mode 100644
index 0000000000..a13375500b
--- /dev/null
+++ b/examples/dapo_math/dapo.yaml
@@ -0,0 +1,79 @@
+project: Trinity-RFT-example
+name: dapo
+checkpoint_root_dir: /PATH/TO/CHECKPOINT/
+model:
+  model_path: /PATH/TO/MODEL/
+algorithm:
+  algorithm_type: grpo
+  repeat_times: 16
+  policy_loss_fn_args:
+    clip_range_low: 0.2
+    clip_range_high: 0.28
+cluster:
+  node_num: 1
+  gpu_per_node: 8
+buffer:
+  total_epochs: 1
+  batch_size: 32
+  max_retry_times: 3
+  max_retry_interval: 1
+  explorer_input:
+    taskset:
+      name: dapo-math
+      storage_type: file
+      path: open-r1/DAPO-Math-17k-Processed
+      subset_name: all
+      format:
+        prompt_key: 'prompt'
+        response_key: 'solution'
+        system_prompt: 'Solve the following math problem step by step. The last line of your response should be of the form Answer: $Answer (without quotes) where $Answer is the answer to the problem.'
+      rollout_args:
+        temperature: 1.0
+        logprobs: 0
+      workflow_args:
+        use_base: true
+      reward_fn_args:
+        enable_overlong_penalty: true
+        penalty_factor: 1.0
+        max_response_length: 20480
+        cache_length: 4096
+    eval_tasksets:
+      - name: AIME2024
+        storage_type: file
+        path: /PATH/TO/AIME2024/
+        split: 'test'
+        format:
+          prompt_key: 'question'
+          response_key: 'answer'
+          system_prompt: 'Solve the following math problem step by step. The last line of your response should be of the form Answer: $Answer (without quotes) where $Answer is the answer to the problem.'
+        rollout_args:
+          n: 32
+          temperature: 1.0
+          top_p: 0.7
+    default_workflow_type: 'math_boxed_workflow'
+    default_reward_fn_type: 'math_dapo_reward'
+  trainer_input:
+    experience_buffer:
+      name: math_buffer
+      storage_type: queue
+explorer:
+  eval_interval: 10
+  runner_num: 32
+  rollout_model:
+    engine_type: vllm_async
+    engine_num: 4
+    tensor_parallel_size: 1
+    enable_prefix_caching: false
+    enforce_eager: true
+    dtype: bfloat16
+    max_prompt_tokens: 1024
+    max_response_tokens: 20480
+    seed: 42
+synchronizer:
+  sync_method: 'nccl'
+  sync_interval: 16
+  sync_timeout: 1200
+trainer:
+  trainer_type: 'verl'
+  trainer_config_path: 'examples/dapo_math/train_dapo.yaml'
+  save_interval: 100
diff --git a/examples/dapo_math/train_dapo.yaml b/examples/dapo_math/train_dapo.yaml
new file mode 100644
index 0000000000..7bba8612ab
--- /dev/null
+++ b/examples/dapo_math/train_dapo.yaml
@@ -0,0 +1,48 @@
+actor_rollout_ref:
+  hybrid_engine: True
+  model:
+    external_lib: null
+    override_config: { }
+    enable_gradient_checkpointing: True
+    use_remove_padding: True # False
+  actor:
+    strategy: fsdp # This is for backward-compatibility
+    ppo_micro_batch_size_per_gpu: 4
+    use_dynamic_bsz: True # False
+    ppo_max_token_len_per_gpu: 22000 # n * ${data.max_prompt_length} + ${data.max_response_length}
+    grad_clip: 1.0
+    ppo_epochs: 1
+    shuffle: False
+    ulysses_sequence_parallel_size: 1 # sp size
+    optim:
+      lr: 1e-6
+      lr_warmup_steps: 20 # the total steps will be injected during runtime
+      # min_lr_ratio: null # only useful for warmup with cosine
+      warmup_style: constant # select from constant/cosine
+      total_training_steps: -1 # must be overridden by the program at runtime
+    fsdp_config:
+      wrap_policy:
+        # transformer_layer_cls_to_wrap: None
+        min_num_params: 0
+      param_offload: False
+      optimizer_offload: False
+      fsdp_size: -1
+  ref:
+    fsdp_config:
+      param_offload: False
+      wrap_policy:
+        # transformer_layer_cls_to_wrap: None
+        min_num_params: 0
+    log_prob_micro_batch_size_per_gpu: 16
+    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
+    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
+    ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size
+
+trainer:
+  balance_batch: True
+  # auto: find the last checkpoint and resume from it; if none is found, start from scratch
+  resume_mode: auto # alternatives: disable, resume_path
+  default_hdfs_dir: null
+  remove_previous_ckpt_in_save: False
+  del_local_ckpt_after_load: False
+  val_before_train: False
diff --git a/tests/explorer/workflow_test.py b/tests/explorer/workflow_test.py
index 20a8c5064c..2132a591a0 100644
--- a/tests/explorer/workflow_test.py
+++ b/tests/explorer/workflow_test.py
@@ -5,6 +5,8 @@
 from typing import Dict, Optional
 from unittest.mock import MagicMock
 
+from torch import Tensor
+
 from tests.tools import get_unittest_dataset_config
 from trinity.common.rewards import RMGalleryFn
 from trinity.common.workflows import (
@@ -23,6 +25,8 @@ class MockResponse:
     metrics: Optional[Dict[str, float]] = None
     info: Optional[Dict] = None
     unique_id: Optional[str] = "0"
+    tokens: Optional[Tensor] = Tensor([0, 0])
+    prompt_length: int = 1
 
 
 class DummyWorkflow(Workflow):
diff --git a/trinity/common/rewards/__init__.py b/trinity/common/rewards/__init__.py
index 05b752dfc6..c723788908 100644
--- a/trinity/common/rewards/__init__.py
+++ b/trinity/common/rewards/__init__.py
@@ -6,6 +6,7 @@
 
 from .accuracy_reward import AccuracyReward
 from .countdown_reward import CountDownRewardFn
+from .dapo_reward import MathDAPORewardFn
 from .format_reward import FormatReward
 from .math_reward import MathBoxedRewardFn, MathRewardFn
 
@@ -20,4 +21,5 @@
     "FormatReward",
     "MathRewardFn",
     "MathBoxedRewardFn",
+    "MathDAPORewardFn",
 ]
diff --git a/trinity/common/rewards/dapo_reward.py b/trinity/common/rewards/dapo_reward.py
new file mode 100644
index 0000000000..a527bf613a
--- /dev/null
+++ b/trinity/common/rewards/dapo_reward.py
@@ -0,0 +1,67 @@
+# -*- coding: utf-8 -*-
+"""Reward function with the overlong reward shaping described in DAPO (https://arxiv.org/pdf/2503.14476)."""
+from typing import Optional
+
+import torch
+
+from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS, RewardFn
+from trinity.utils.eval_utils import compute_score
+from trinity.utils.log import get_logger
+
+logger = get_logger(__name__)
+
+
+@REWARD_FUNCTIONS.register_module("math_dapo_reward")
+class MathDAPORewardFn(RewardFn):
+    """A reward function that follows the DAPO definition for math tasks."""
+
+    def __init__(
+        self,
+        enable_overlong_penalty: Optional[bool] = None,
+        penalty_factor: Optional[float] = None,
+        max_response_length: Optional[int] = None,
+        cache_length: Optional[int] = None,
+    ) -> None:
+        self.enable_overlong_penalty = enable_overlong_penalty
+        self.penalty_factor = penalty_factor
+        self.max_response_length = max_response_length
+        self.cache_length = cache_length
+
+    def __call__(  # type: ignore
+        self,
+        response: str,
+        response_token: torch.Tensor,
+        truth: Optional[str] = None,
+        **kwargs,
+    ) -> dict[str, float]:
+        accuracy_score = compute_score(response, truth)
+
+        format_score = 0.0
+
+        if self.enable_overlong_penalty:
+            format_score = self.compute_overlong_penalty(response_token)
+
+        return {
+            "accuracy": accuracy_score,
+            "format_score": format_score,
+        }
+
+    def compute_overlong_penalty(self, response_token):
+        assert (
+            self.max_response_length is not None
+            and self.cache_length is not None
+            and self.penalty_factor is not None
+        ), "When enable_overlong_penalty is true, max_response_length, penalty_factor, and cache_length must all be set"
+        assert (
+            self.max_response_length > self.cache_length
+        ), "max_response_length must be greater than cache_length"
+
+        response_len = len(response_token)
+        expected_len = self.max_response_length - self.cache_length
+
+        if response_len < expected_len:
+            return 0.0
+        elif response_len > self.max_response_length:
+            return -self.penalty_factor
+        else:
+            return (expected_len - response_len) / self.cache_length * self.penalty_factor
diff --git a/trinity/common/rewards/math_reward.py b/trinity/common/rewards/math_reward.py
index 09a5cd7428..9e1c8abed6 100644
--- a/trinity/common/rewards/math_reward.py
+++ b/trinity/common/rewards/math_reward.py
@@ -49,16 +49,17 @@ class MathBoxedRewardFn(RewardFn):
 
     def __init__(
         self,
+        **kwargs,
     ) -> None:
         pass
 
     def __call__(  # type: ignore
         self,
         response: str,
-        prompt: Optional[str] = None,
         truth: Optional[str] = None,
         with_think: Optional[bool] = False,
         format_score_coef: Optional[float] = 0.1,
+        **kwargs,
     ) -> dict[str, float]:
         accuracy_score = compute_score(response, truth)
 
diff --git a/trinity/common/workflows/customized_math_workflows.py b/trinity/common/workflows/customized_math_workflows.py
index 1e825e4e54..c2762ae43c 100644
--- a/trinity/common/workflows/customized_math_workflows.py
+++ b/trinity/common/workflows/customized_math_workflows.py
@@ -31,6 +31,7 @@ def reset(self, task: Task):
         self.is_eval = task.is_eval
 
         self.workflow_args = task.workflow_args
+        self.reward_fn_args = task.reward_fn_args
 
         self.use_base = self.workflow_args.get("use_base", False)
         self.with_think = self.workflow_args.get("with_think", False)
@@ -48,7 +49,10 @@ def reset(self, task: Task):
         else:
             self.system_prompt = default_prompt
 
-        self.reward_fn = MathBoxedRewardFn()
+        if task.reward_fn is None:
+            self.reward_fn = MathBoxedRewardFn(**self.reward_fn_args)
+        else:
+            self.reward_fn = task.reward_fn(**self.reward_fn_args)
 
     def format_prompt(self):
         prompt_text = ""
@@ -60,7 +64,6 @@ def format_prompt(self):
         return prompt_text
 
     def run(self) -> List[Experience]:
-        # TODO: Optimize the generate function
         if not self.use_base:
             messages = self.format_messages()
         else:
@@ -73,11 +76,12 @@ def run(self) -> List[Experience]:
             responses = self.model.generate([prompt_text], **self.rollout_args)
 
         for response in responses:
-            reward_dict = MathBoxedRewardFn()(  # type: ignore [misc]
+            reward_dict = self.reward_fn(  # type: ignore [misc]
                 response=response.response_text,  # type: ignore [arg-type]
                 truth=self.truth,
                 with_think=self.with_think,
                 format_score_coef=self.format_score_coef,
+                response_token=response.tokens[response.prompt_length :],
             )
 
             if response.metrics is None:
@@ -86,7 +90,12 @@ def run(self) -> List[Experience]:
             reward = sum(reward_dict.values())
             response.reward = reward
-            logger.debug(
-                f"self.task_desc: {self.task_desc}, messages: {messages}, response: {response.response_text}, reward: {reward}"
-            )
+            if not self.use_base:
+                logger.debug(
+                    f"self.task_desc: {self.task_desc}, messages: {messages}, response: {response.response_text}, reward: {reward}"
+                )
+            else:
+                logger.debug(
+                    f"self.task_desc: {self.task_desc}, prompt_text: {prompt_text}, response: {response.response_text}, reward: {reward}"
+                )
 
         return responses
diff --git a/trinity/common/workflows/workflow.py b/trinity/common/workflows/workflow.py
index 0a2483788b..e9549d9e2e 100644
--- a/trinity/common/workflows/workflow.py
+++ b/trinity/common/workflows/workflow.py
@@ -197,6 +197,7 @@ def reset(self, task: Task):
         self.is_eval = task.is_eval
 
     def format_messages(self):
+        """Format messages for the instruct model."""
        messages = []
         if self.system_prompt:
             messages.append({"role": "system", "content": self.system_prompt})
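
Note (not part of the patch): the core technical piece of this change is the soft overlong punishment from the DAPO paper, implemented above in MathDAPORewardFn.compute_overlong_penalty. The standalone Python sketch below mirrors that piecewise penalty using the values configured in examples/dapo_math/dapo.yaml (max_response_length: 20480, cache_length: 4096, penalty_factor: 1.0). The helper name overlong_penalty and the demo loop are illustrative only and are not Trinity-RFT API.

# Illustration only: mirrors MathDAPORewardFn.compute_overlong_penalty from the patch above.
def overlong_penalty(
    response_len: int,
    max_response_length: int = 20480,  # defaults taken from examples/dapo_math/dapo.yaml
    cache_length: int = 4096,
    penalty_factor: float = 1.0,
) -> float:
    """DAPO-style soft overlong punishment for a response of response_len tokens."""
    expected_len = max_response_length - cache_length
    if response_len < expected_len:
        return 0.0  # within budget: no penalty
    if response_len > max_response_length:
        return -penalty_factor  # past the hard limit: full penalty
    # inside the cache window: the penalty ramps linearly from 0 down to -penalty_factor
    return (expected_len - response_len) / cache_length * penalty_factor


if __name__ == "__main__":
    # With the defaults above, the penalties are 0, 0, -0.5, -1, -1 respectively.
    for n in (8000, 16384, 18432, 20480, 24000):
        print(f"{n:>6} tokens -> penalty {overlong_penalty(n):+.2f}")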