
Commit 7bc465c

Merge branch 'main' into algorithm_dev
2 parents: eddf4e4 + c05296d

File tree: 10 files changed, +428 -6 lines

README.md (2 additions, 2 deletions)

@@ -12,7 +12,7 @@

 [![paper](http://img.shields.io/badge/cs.LG-2505.17826-B31B1B?logo=arxiv&logoColor=red)](https://arxiv.org/abs/2505.17826)
 [![doc](https://img.shields.io/badge/Docs-blue?logo=markdown)](https://modelscope.github.io/Trinity-RFT/)
-[![pypi](https://img.shields.io/pypi/v/trinity-rft?logo=pypi&color=026cad)](https://pypi.org/project/trinity-rft/0.1.0/)
+[![pypi](https://img.shields.io/pypi/v/trinity-rft?logo=pypi&color=026cad)](https://pypi.org/project/trinity-rft/0.1.1/)
 ![license](https://img.shields.io/badge/license-Apache--2.0-000000.svg)

 </div>

@@ -159,7 +159,7 @@ pip install -e .\[flash_attn\]
 Installation using pip:

 ```shell
-pip install trinity-rft==0.1.0
+pip install trinity-rft==0.1.1
 ```

 Installation from docker:

docs/sphinx_doc/source/main.md (1 addition, 1 deletion)

@@ -130,7 +130,7 @@ pip install flash-attn -v
 Installation using pip:

 ```shell
-pip install trinity-rft==0.1.0
+pip install trinity-rft==0.1.1
 ```

 Installation from docker:

pyproject.toml (1 addition, 1 deletion)

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

 [project]
 name = "trinity-rft"
-version = "0.1.0"
+version = "0.1.1"
 authors = [
     {name="Trinity-RFT Team", email="[email protected]"},
 ]

tests/explorer/workflow_test.py (52 additions, 1 deletion)

@@ -5,7 +5,7 @@
 from unittest.mock import MagicMock

 from tests.tools import get_unittest_dataset_config
-from trinity.common.workflows import MathWorkflow, Workflow
+from trinity.common.workflows import MathBoxedWorkflow, MathWorkflow, Workflow
 from trinity.common.workflows.workflow import Task


@@ -134,6 +134,57 @@ def test_math_complex_workflow(self) -> None:
         self.assertEqual(len(experiences), 1)
         self.assertEqual(experiences[0].reward, 0.9)

+    def test_math_boxed_workflow(self) -> None:
+        model = MagicMock()
+        model.chat.return_value = [
+            MockResponse("<think> balabalabala 99 </think>\n \\boxed{36}"),
+            MockResponse("answer is \\boxed{36 }"),
+            MockResponse("Kim's total points are 6 + 30 =\\boxed{36}"),
+            MockResponse("<think> balalaba </think> \\boxed{35.00}"),
+        ]
+        taskset_config = get_unittest_dataset_config("countdown")
+        task = Task(
+            workflow=MathBoxedWorkflow,
+            format_args=taskset_config.format,
+            rollout_args=taskset_config.rollout_args,
+            workflow_args={
+                "with_think": False,
+                "format_score_coef": 0.2,
+            },
+            is_eval=False,
+            raw_task={
+                taskset_config.format.prompt_key: "",
+                taskset_config.format.response_key: r"36",
+            },
+        )
+        workflow = task.to_workflow(model=model)
+        experiences = workflow.run()
+        self.assertEqual(experiences[0].reward, 1.0)
+        self.assertEqual(experiences[1].reward, 1.0)
+        self.assertEqual(experiences[2].reward, 1.0)
+        self.assertEqual(experiences[3].reward, 0.0)
+        task_new = Task(
+            workflow=MathBoxedWorkflow,
+            format_args=taskset_config.format,
+            rollout_args=taskset_config.rollout_args,
+            workflow_args={
+                "with_think": True,
+                "format_score_coef": 0.2,
+            },
+            is_eval=False,
+            raw_task={
+                taskset_config.format.prompt_key: "",
+                taskset_config.format.response_key: r"36",
+            },
+        )
+        workflow.reset(task_new)
+        workflow_new = task_new.to_workflow(model=model)
+        experiences = workflow_new.run()
+        self.assertEqual(experiences[0].reward, 1.0)
+        self.assertEqual(experiences[1].reward, 0.8)
+        self.assertEqual(experiences[2].reward, 0.8)
+        self.assertEqual(experiences[3].reward, 0.0)
+
     def test_gsm8k_workflow(self) -> None:
         model = MagicMock()
         model.chat.return_value = [
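The new test pins down the boxed-answer scoring rule: accuracy is 1.0 when the \boxed{} answer matches the truth (0.0 otherwise), and when with_think is enabled a missing <think>...</think> block costs format_score_coef. A minimal self-contained sketch of that arithmetic (boxed_reward is an illustrative helper, not part of the repository):

```python
import math

def boxed_reward(correct: bool, has_think: bool, with_think: bool, coef: float = 0.1) -> float:
    """Accuracy score plus an optional penalty for a missing <think> block."""
    accuracy = 1.0 if correct else 0.0
    penalty = -coef if (with_think and not has_think) else 0.0
    return accuracy + penalty

# The second run above uses with_think=True and format_score_coef=0.2:
assert math.isclose(boxed_reward(True, True, True, 0.2), 1.0)   # <think> present, boxed 36
assert math.isclose(boxed_reward(True, False, True, 0.2), 0.8)  # boxed 36, no <think>
assert math.isclose(boxed_reward(False, True, True, 0.2), 0.0)  # <think> present, wrong answer
```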

trinity/__init__.py (1 addition, 1 deletion)

@@ -1,4 +1,4 @@
 # -*- coding: utf-8 -*-
 """Trinity-RFT (Reinforcement Fine-Tuning)"""

-__version__ = "0.1.0"
+__version__ = "0.1.1"

trinity/buffer/reader/file_reader.py (12 additions, 0 deletions)

@@ -5,6 +5,7 @@
 import datasets
 import transformers
 from datasets import Dataset, load_dataset
+from tqdm import tqdm

 from trinity.algorithm.algorithm import DPOAlgorithm, SFTAlgorithm
 from trinity.buffer.buffer_reader import BufferReader

@@ -35,11 +36,20 @@ def __init__(self, dataset: Dataset, max_epoch: int = 1, offset: int = 0):
         for _ in range(self.current_offset):
             next(self.iter)

+        # Initialize tqdm progress bar
+        self.total_steps = self.dataset_size * self.max_epoch
+        self.progress_bar = tqdm(
+            total=self.total_steps,
+            initial=self.current_epoch * self.dataset_size + self.current_offset,
+            desc="Dataset Progressing",
+        )
+
     def read_batch(self, batch_size: int) -> List:
         batch = []

         while len(batch) < batch_size:
             try:
+                self.progress_bar.update(1)
                 item = next(self.iter)
                 batch.append(item)
                 self.current_offset += 1

@@ -49,7 +59,9 @@ def read_batch(self, batch_size: int) -> List:
                 self.current_offset = 0

                 if self.current_epoch >= self.max_epoch:
+                    self.progress_bar.close()
                     raise StopIteration
+                # Step to the next epoch
                 self.iter = iter(self.dataset)
         return batch
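The progress bar added here sizes itself to the full run (dataset_size * max_epoch items) and starts from current_epoch * dataset_size + current_offset, so a reader restored with a nonzero offset resumes the bar mid-run rather than from zero. A small standalone sketch of that accounting with plain tqdm (the numbers are invented for illustration):

```python
from tqdm import tqdm

dataset_size, max_epoch = 1000, 3        # 3000 items across the whole run
current_epoch, current_offset = 1, 250   # resuming: one full epoch plus 250 items done

bar = tqdm(
    total=dataset_size * max_epoch,
    initial=current_epoch * dataset_size + current_offset,  # bar starts at 1250/3000
    desc="Dataset Progressing",
)
for _ in range(5):
    bar.update(1)  # read_batch ticks the bar once per item it pulls
bar.close()
```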

trinity/common/rewards/reward_fn.py (32 additions, 0 deletions)

@@ -9,10 +9,12 @@
 from math_verify import LatexExtractionConfig, parse, verify

 from trinity.utils.eval_utils import (
+    compute_score,
     evaluate_equation,
     extract_solution,
     simple_answer_parser,
     validate_equation,
+    validate_think_pattern,
 )
 from trinity.utils.log import get_logger
 from trinity.utils.registry import Registry

@@ -195,3 +197,33 @@ def __call__(
             return format_score
         except Exception as e:  # noqa: F841
             return format_score
+
+
+@REWARD_FUNCTIONS.register_module("math_boxed_reward")
+class MathBoxedRewardFn(RewardFn):
+    """A reward function that rewards for math task."""
+
+    def __init__(
+        self,
+    ) -> None:
+        pass
+
+    def __call__(  # type: ignore
+        self,
+        response: str,
+        prompt: Optional[str] = None,
+        truth: Optional[str] = None,
+        return_dict: Optional[bool] = False,
+        with_think: Optional[bool] = False,
+        format_score_coef: Optional[float] = 0.1,
+    ) -> Union[float, dict]:
+        accuracy_score = compute_score(response, truth)
+
+        format_score = 0.0
+        if with_think and not validate_think_pattern(response):
+            format_score = (format_score_coef or 0.1) * -1.0
+
+        if return_dict:
+            return {"accuracy": accuracy_score, "format_score": format_score}
+
+        return accuracy_score + format_score
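The class is registered under the key "math_boxed_reward" alongside the existing reward functions and can also be called directly. A hedged usage sketch, assuming a Trinity-RFT checkout where compute_score awards 1.0 for a \boxed{} answer matching the truth:

```python
from trinity.common.rewards.reward_fn import MathBoxedRewardFn

fn = MathBoxedRewardFn()

# Scalar form: accuracy plus a (possibly negative) format score.
reward = fn(response=r"answer is \boxed{36}", truth="36",
            with_think=True, format_score_coef=0.2)
print(reward)  # expected 0.8: correct answer, no <think>...</think> block

# Dict form, as the workflow requests during evaluation (return_dict=True).
parts = fn(response=r"<think> steps </think> \boxed{36}", truth="36",
           return_dict=True, with_think=True)
print(parts)  # expected {'accuracy': 1.0, 'format_score': 0.0}
```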

trinity/common/workflows/__init__.py (2 additions, 0 deletions)

@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 """Workflow module"""
+from .customized_math_workflows import MathBoxedWorkflow
 from .envs.alfworld.alfworld_workflow import AlfworldWorkflow
 from .envs.sciworld.sciworld_workflow import SciWorldWorkflow
 from .envs.webshop.webshop_workflow import WebShopWorkflow

@@ -14,4 +15,5 @@
     "WebShopWorkflow",
     "AlfworldWorkflow",
     "SciWorldWorkflow",
+    "MathBoxedWorkflow",
 ]
trinity/common/workflows/customized_math_workflows.py (new file, 92 additions, 0 deletions)

@@ -0,0 +1,92 @@
+# -*- coding: utf-8 -*-
+"""We include the customized math workflows in this file."""
+
+from dataclasses import asdict
+from typing import List
+
+from trinity.common.experience import Experience
+from trinity.common.rewards.reward_fn import MathBoxedRewardFn
+from trinity.common.workflows.workflow import WORKFLOWS, SimpleWorkflow, Task
+from trinity.utils.log import get_logger
+
+logger = get_logger(__name__)
+
+
+@WORKFLOWS.register_module("math_boxed_workflow")
+class MathBoxedWorkflow(SimpleWorkflow):
+    """A workflow for math tasks that give answers in boxed format."""
+
+    def reset(self, task: Task):
+        self.format_args = task.format_args
+        self.system_prompt = task.format_args.system_prompt
+        self.reply_prefix = task.format_args.reply_prefix
+
+        self.raw_task = task.raw_task
+        self.task_desc = task.task_desc
+        self.truth = task.truth
+
+        # Rollout args
+        rollout_args = asdict(task.rollout_args)
+        self.rollout_args = rollout_args
+        self.is_eval = task.is_eval
+
+        self.workflow_args = task.workflow_args
+
+        self.use_base = self.workflow_args.get("use_base", False)
+        self.with_think = self.workflow_args.get("with_think", False)
+        self.format_score_coef = self.workflow_args.get("format_score_coef", 0.1)
+
+        default_prompt = (
+            """Please reason step by step, and put your final answer within \\boxed{}."""
+        )
+
+        default_prompt_with_think = """You are a helpful assistant that solves MATH problems. You should first thinks about the reasoning process in mind and then provides the user with the answer. You should present your reasoning process using the format: <think>\n ...your reasoning process here... </think>\n first. You should always include your final answer in \\boxed{} as closed-form results."""
+
+        if self.system_prompt is None:
+            if self.with_think:
+                self.system_prompt = default_prompt_with_think
+            else:
+                self.system_prompt = default_prompt
+
+        self.reward_fn = MathBoxedRewardFn()
+
+    def format_prompt(self):
+        prompt_text = ""
+        if self.system_prompt:
+            prompt_text += "System:" + self.system_prompt
+            prompt_text += "\nUser:\n" + self.task_desc + "\nAssistant:\n"
+        else:
+            prompt_text += "User:\n" + self.task_desc + "\nAssistant:\n"
+        return prompt_text
+
+    def run(self) -> List[Experience]:
+        # TODO: Optimize the generate function
+        if not self.use_base:
+            messages = self.format_messages()
+        else:
+            prompt_text = self.format_prompt()
+
+        logger.debug("start chat")
+        if not self.use_base:
+            responses = self.model.chat(messages, **self.rollout_args)
+        else:
+            responses = self.model.generate([prompt_text], **self.rollout_args)
+
+        for response in responses:
+            reward = MathBoxedRewardFn()(  # type: ignore [misc]
+                response=response.response_text,  # type: ignore [arg-type]
+                truth=self.truth,
+                return_dict=self.is_eval,
+                with_think=self.with_think,
+                format_score_coef=self.format_score_coef,
+            )
+            logger.debug(
+                f"self.task_desc: {self.task_desc}, messages: {messages}, response: {response.response_text}, reward: {reward}"
+            )
+            if isinstance(reward, dict):
+                if response.metrics is None:
+                    response.metrics = {}
+                response.metrics.update(reward)
+                reward = sum(reward.values())
+            response.reward = reward
+        return responses
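Selecting the workflow from training code mirrors the unit test above. A hedged wiring sketch (taskset_config and model are placeholders for a dataset config and a rollout model handle, not values defined in this diff):

```python
from trinity.common.workflows import MathBoxedWorkflow
from trinity.common.workflows.workflow import Task

task = Task(
    workflow=MathBoxedWorkflow,
    format_args=taskset_config.format,          # placeholder dataset config
    rollout_args=taskset_config.rollout_args,
    workflow_args={
        "use_base": False,          # True routes through model.generate with a raw prompt
        "with_think": True,         # demand a <think>...</think> block
        "format_score_coef": 0.2,   # penalty when the block is missing
    },
    is_eval=False,
    raw_task={
        taskset_config.format.prompt_key: "What is 6 + 30?",
        taskset_config.format.response_key: "36",
    },
)
experiences = task.to_workflow(model=model).run()  # model: placeholder rollout handle
```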
