Skip to content

Commit 7621887

Browse files
committed
mv dapo workflow to boxed workflow
1 parent e1c7622 commit 7621887

File tree

8 files changed

+28
-92
lines changed

8 files changed

+28
-92
lines changed

examples/dapo_math/dapo.yaml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ buffer:
3232
logprobs: 0
3333
workflow_args:
3434
use_base: true
35+
reward_fn_args:
3536
enable_overlong_penalty: true
3637
penalty_factor: 1.0
3738
max_response_length: 20480
@@ -49,7 +50,8 @@ buffer:
4950
n: 32
5051
temperature: 1.0
5152
top_p: 0.7
52-
default_workflow_type: 'math_dapo_workflow'
53+
default_workflow_type: 'math_boxed_workflow'
54+
default_reward_fn_type: 'math_dapo_reward'
5355
trainer_input:
5456
experience_buffer:
5557
name: math_buffer

trinity/common/rewards/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from .accuracy_reward import AccuracyReward
88
from .countdown_reward import CountDownRewardFn
9+
from .dapo_reward import MathDAPORewardFn
910
from .format_reward import FormatReward
1011
from .math_reward import MathBoxedRewardFn, MathRewardFn
1112

trinity/common/rewards/dapo_reward.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def __call__( # type: ignore
3232
response: str,
3333
response_token: torch.Tensor,
3434
truth: Optional[str] = None,
35+
**kwargs,
3536
) -> Union[float, dict]:
3637
accuracy_score = compute_score(response, truth)
3738

trinity/common/rewards/math_reward.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,16 +49,17 @@ class MathBoxedRewardFn(RewardFn):
4949

5050
def __init__(
5151
self,
52+
**kwargs,
5253
) -> None:
5354
pass
5455

5556
def __call__( # type: ignore
5657
self,
5758
response: str,
58-
prompt: Optional[str] = None,
5959
truth: Optional[str] = None,
6060
with_think: Optional[bool] = False,
6161
format_score_coef: Optional[float] = 0.1,
62+
**kwargs,
6263
) -> dict[str, float]:
6364
accuracy_score = compute_score(response, truth)
6465

trinity/common/workflows/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# -*- coding: utf-8 -*-
22
"""Workflow module"""
33
from .customized_math_workflows import MathBoxedWorkflow
4-
from .dapo_workflow import MathDAPOWorkflow
54
from .envs.alfworld.alfworld_workflow import AlfworldWorkflow
65
from .envs.sciworld.sciworld_workflow import SciWorldWorkflow
76
from .envs.webshop.webshop_workflow import WebShopWorkflow
@@ -19,5 +18,4 @@
1918
"SciWorldWorkflow",
2019
"MathBoxedWorkflow",
2120
"MathRMWorkflow",
22-
"MathDAPOWorkflow",
2321
]

trinity/common/workflows/customized_math_workflows.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def reset(self, task: Task):
3131
self.is_eval = task.is_eval
3232

3333
self.workflow_args = task.workflow_args
34+
self.reward_fn_args = task.reward_fn_args
3435

3536
self.use_base = self.workflow_args.get("use_base", False)
3637
self.with_think = self.workflow_args.get("with_think", False)
@@ -49,9 +50,18 @@ def reset(self, task: Task):
4950
self.system_prompt = default_prompt
5051

5152
if task.reward_fn is None:
52-
self.reward_fn = MathBoxedRewardFn()
53+
self.reward_fn = MathBoxedRewardFn(**self.reward_fn_args)
5354
else:
54-
self.reward_fn = task.reward_fn
55+
self.reward_fn = task.reward_fn(**self.reward_fn_args)
56+
57+
def format_prompt(self):
58+
prompt_text = ""
59+
if self.system_prompt:
60+
prompt_text += "System:" + self.system_prompt
61+
prompt_text += "\nUser:\n" + self.task_desc + "\nAssistant:\n"
62+
else:
63+
prompt_text += "User:\n" + self.task_desc + "\nAssistant:\n"
64+
return prompt_text
5565

5666
def run(self) -> List[Experience]:
5767
if not self.use_base:
@@ -71,6 +81,7 @@ def run(self) -> List[Experience]:
7181
truth=self.truth,
7282
with_think=self.with_think,
7383
format_score_coef=self.format_score_coef,
84+
response_token=response.tokens[response.prompt_length :],
7485
)
7586

7687
if response.metrics is None:
@@ -79,7 +90,12 @@ def run(self) -> List[Experience]:
7990
reward = sum(reward_dict.values())
8091
response.reward = reward
8192

82-
logger.debug(
83-
f"self.task_desc: {self.task_desc}, messages: {messages}, response: {response.response_text}, reward: {reward}"
84-
)
93+
if not self.use_base:
94+
logger.debug(
95+
f"self.task_desc: {self.task_desc}, messages: {messages}, response: {response.response_text}, reward: {reward}"
96+
)
97+
else:
98+
logger.debug(
99+
f"self.task_desc: {self.task_desc}, prompt_text: {prompt_text}, response: {response.response_text}, reward: {reward}"
100+
)
85101
return responses

trinity/common/workflows/dapo_workflow.py

Lines changed: 0 additions & 73 deletions
This file was deleted.

trinity/common/workflows/workflow.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -206,16 +206,6 @@ def format_messages(self):
206206
messages.append({"role": "assistant", "content": self.reply_prefix})
207207
return messages
208208

209-
def format_prompt(self):
210-
"""Format prompt for the base model."""
211-
prompt_text = ""
212-
if self.system_prompt:
213-
prompt_text += "System:\n" + self.system_prompt
214-
prompt_text += "\nUser:\n" + self.task_desc + "\nAssistant:\n"
215-
else:
216-
prompt_text += "User:\n" + self.task_desc + "\nAssistant:\n"
217-
return prompt_text
218-
219209
def run(self) -> List[Experience]:
220210
# TODO: Optimize the generate function
221211
messages = self.format_messages()

0 commit comments

Comments
 (0)