
Commit ffd246e

Customized math workflows (#88)
1 parent 079201b commit ffd246e

5 files changed: +411 −1 lines changed


tests/explorer/workflow_test.py

Lines changed: 52 additions & 1 deletion

@@ -5,7 +5,7 @@
 from unittest.mock import MagicMock

 from tests.tools import get_unittest_dataset_config
-from trinity.common.workflows import MathWorkflow, Workflow
+from trinity.common.workflows import MathBoxedWorkflow, MathWorkflow, Workflow
 from trinity.common.workflows.workflow import Task


@@ -134,6 +134,57 @@ def test_math_complex_workflow(self) -> None:
         self.assertEqual(len(experiences), 1)
         self.assertEqual(experiences[0].reward, 0.9)

+    def test_math_boxed_workflow(self) -> None:
+        model = MagicMock()
+        model.chat.return_value = [
+            MockResponse("<think> balabalabala 99 </think>\n \\boxed{36}"),
+            MockResponse("answer is \\boxed{36 }"),
+            MockResponse("Kim's total points are 6 + 30 =\\boxed{36}"),
+            MockResponse("<think> balalaba </think> \\boxed{35.00}"),
+        ]
+        taskset_config = get_unittest_dataset_config("countdown")
+        task = Task(
+            workflow=MathBoxedWorkflow,
+            format_args=taskset_config.format,
+            rollout_args=taskset_config.rollout_args,
+            workflow_args={
+                "with_think": False,
+                "format_score_coef": 0.2,
+            },
+            is_eval=False,
+            raw_task={
+                taskset_config.format.prompt_key: "",
+                taskset_config.format.response_key: r"36",
+            },
+        )
+        workflow = task.to_workflow(model=model)
+        experiences = workflow.run()
+        self.assertEqual(experiences[0].reward, 1.0)
+        self.assertEqual(experiences[1].reward, 1.0)
+        self.assertEqual(experiences[2].reward, 1.0)
+        self.assertEqual(experiences[3].reward, 0.0)
+        task_new = Task(
+            workflow=MathBoxedWorkflow,
+            format_args=taskset_config.format,
+            rollout_args=taskset_config.rollout_args,
+            workflow_args={
+                "with_think": True,
+                "format_score_coef": 0.2,
+            },
+            is_eval=False,
+            raw_task={
+                taskset_config.format.prompt_key: "",
+                taskset_config.format.response_key: r"36",
+            },
+        )
+        workflow.reset(task_new)
+        workflow_new = task_new.to_workflow(model=model)
+        experiences = workflow_new.run()
+        self.assertEqual(experiences[0].reward, 1.0)
+        self.assertEqual(experiences[1].reward, 0.8)
+        self.assertEqual(experiences[2].reward, 0.8)
+        self.assertEqual(experiences[3].reward, 0.0)
+
     def test_gsm8k_workflow(self) -> None:
         model = MagicMock()
         model.chat.return_value = [
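A note on the expected values in the second Task of this test: with with_think=True and format_score_coef=0.2, a response whose \boxed{} answer matches the truth but omits the <think>...</think> block (the second and third mocked responses) scores accuracy 1.0 plus a format penalty of -0.2, i.e. 0.8; the last response keeps the think block but boxes 35.00 instead of 36, so it scores 0.0. The scoring itself lives in MathBoxedRewardFn, added in trinity/common/rewards/reward_fn.py below.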

trinity/common/rewards/reward_fn.py

Lines changed: 32 additions & 0 deletions

@@ -9,10 +9,12 @@
 from math_verify import LatexExtractionConfig, parse, verify

 from trinity.utils.eval_utils import (
+    compute_score,
     evaluate_equation,
     extract_solution,
     simple_answer_parser,
     validate_equation,
+    validate_think_pattern,
 )
 from trinity.utils.log import get_logger
 from trinity.utils.registry import Registry

@@ -195,3 +197,33 @@ def __call__(
             return format_score
         except Exception as e:  # noqa: F841
             return format_score
+
+
+@REWARD_FUNCTIONS.register_module("math_boxed_reward")
+class MathBoxedRewardFn(RewardFn):
+    """A reward function that rewards for math task."""
+
+    def __init__(
+        self,
+    ) -> None:
+        pass
+
+    def __call__(  # type: ignore
+        self,
+        response: str,
+        prompt: Optional[str] = None,
+        truth: Optional[str] = None,
+        return_dict: Optional[bool] = False,
+        with_think: Optional[bool] = False,
+        format_score_coef: Optional[float] = 0.1,
+    ) -> Union[float, dict]:
+        accuracy_score = compute_score(response, truth)
+
+        format_score = 0.0
+        if with_think and not validate_think_pattern(response):
+            format_score = (format_score_coef or 0.1) * -1.0
+
+        if return_dict:
+            return {"accuracy": accuracy_score, "format_score": format_score}
+
+        return accuracy_score + format_score
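
As a quick reference, here is a minimal usage sketch of the new reward function. It assumes that compute_score (imported from trinity.utils.eval_utils, whose implementation is not part of this diff) returns 1.0 when the \boxed{} answer in the response matches truth and 0.0 otherwise, which is what the unit test above implies; the expected values in the comments rest on that assumption.

    from trinity.common.rewards.reward_fn import MathBoxedRewardFn

    reward_fn = MathBoxedRewardFn()

    # Scalar form: accuracy plus the (possibly negative) format score.
    reward = reward_fn(
        response=r"<think> reasoning here </think> \boxed{36}",
        truth="36",
        with_think=True,
        format_score_coef=0.2,
    )  # expected 1.0 under the assumption above

    # Dict form, as used by MathBoxedWorkflow when is_eval=True.
    metrics = reward_fn(
        response=r"the answer is \boxed{36}",
        truth="36",
        return_dict=True,
        with_think=True,
        format_score_coef=0.2,
    )  # expected {"accuracy": 1.0, "format_score": -0.2}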

trinity/common/workflows/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 """Workflow module"""
+from .customized_math_workflows import MathBoxedWorkflow
 from .envs.alfworld.alfworld_workflow import AlfworldWorkflow
 from .envs.sciworld.sciworld_workflow import SciWorldWorkflow
 from .envs.webshop.webshop_workflow import WebShopWorkflow

@@ -14,4 +15,5 @@
     "WebShopWorkflow",
     "AlfworldWorkflow",
     "SciWorldWorkflow",
+    "MathBoxedWorkflow",
 ]

trinity/common/workflows/customized_math_workflows.py (new file)

Lines changed: 92 additions & 0 deletions

@@ -0,0 +1,92 @@
+# -*- coding: utf-8 -*-
+"""We include the customized math workflows in this file."""
+
+from dataclasses import asdict
+from typing import List
+
+from trinity.common.experience import Experience
+from trinity.common.rewards.reward_fn import MathBoxedRewardFn
+from trinity.common.workflows.workflow import WORKFLOWS, SimpleWorkflow, Task
+from trinity.utils.log import get_logger
+
+logger = get_logger(__name__)
+
+
+@WORKFLOWS.register_module("math_boxed_workflow")
+class MathBoxedWorkflow(SimpleWorkflow):
+    """A workflow for math tasks that give answers in boxed format."""
+
+    def reset(self, task: Task):
+        self.format_args = task.format_args
+        self.system_prompt = task.format_args.system_prompt
+        self.reply_prefix = task.format_args.reply_prefix
+
+        self.raw_task = task.raw_task
+        self.task_desc = task.task_desc
+        self.truth = task.truth
+
+        # Rollout args
+        rollout_args = asdict(task.rollout_args)
+        self.rollout_args = rollout_args
+        self.is_eval = task.is_eval
+
+        self.workflow_args = task.workflow_args
+
+        self.use_base = self.workflow_args.get("use_base", False)
+        self.with_think = self.workflow_args.get("with_think", False)
+        self.format_score_coef = self.workflow_args.get("format_score_coef", 0.1)
+
+        default_prompt = (
+            """Please reason step by step, and put your final answer within \\boxed{}."""
+        )
+
+        default_prompt_with_think = """You are a helpful assistant that solves MATH problems. You should first thinks about the reasoning process in mind and then provides the user with the answer. You should present your reasoning process using the format: <think>\n ...your reasoning process here... </think>\n first. You should always include your final answer in \\boxed{} as closed-form results."""
+
+        if self.system_prompt is None:
+            if self.with_think:
+                self.system_prompt = default_prompt_with_think
+            else:
+                self.system_prompt = default_prompt
+
+        self.reward_fn = MathBoxedRewardFn()
+
+    def format_prompt(self):
+        prompt_text = ""
+        if self.system_prompt:
+            prompt_text += "System:" + self.system_prompt
+            prompt_text += "\nUser:\n" + self.task_desc + "\nAssistant:\n"
+        else:
+            prompt_text += "User:\n" + self.task_desc + "\nAssistant:\n"
+        return prompt_text
+
+    def run(self) -> List[Experience]:
+        # TODO: Optimize the generate function
+        if not self.use_base:
+            messages = self.format_messages()
+        else:
+            prompt_text = self.format_prompt()
+
+        logger.debug("start chat")
+        if not self.use_base:
+            responses = self.model.chat(messages, **self.rollout_args)
+        else:
+            responses = self.model.generate([prompt_text], **self.rollout_args)
+
+        for response in responses:
+            reward = MathBoxedRewardFn()(  # type: ignore [misc]
+                response=response.response_text,  # type: ignore [arg-type]
+                truth=self.truth,
+                return_dict=self.is_eval,
+                with_think=self.with_think,
+                format_score_coef=self.format_score_coef,
+            )
+            logger.debug(
+                f"self.task_desc: {self.task_desc}, messages: {messages}, response: {response.response_text}, reward: {reward}"
+            )
+            if isinstance(reward, dict):
+                if response.metrics is None:
+                    response.metrics = {}
+                response.metrics.update(reward)
+                reward = sum(reward.values())
+            response.reward = reward
+        return responses
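
For the base-model path (use_base=True), run() calls model.generate() on a single completion-style prompt built by format_prompt() rather than chat messages. A small illustration of what that prompt renders to, assuming a hypothetical task_desc of "What is 6 + 30?" and the default non-think system prompt:

    System:Please reason step by step, and put your final answer within \boxed{}.
    User:
    What is 6 + 30?
    Assistant:

With use_base=False (the default, and what the unit test exercises), the workflow instead builds chat messages via format_messages(), which comes from the parent SimpleWorkflow and is not shown in this diff, and passes them to model.chat().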
