90 changes: 90 additions & 0 deletions tests/explorer/workflow_test.py
@@ -6,6 +6,7 @@

from tests.tools import get_unittest_dataset_config
from trinity.common.workflows import MathWorkflow
from trinity.common.workflows.math_workflows import PREDEFINED_MATH_SYSTEM_PROMPTS
from trinity.common.workflows.workflow import Task


@@ -150,3 +151,92 @@ def test_gsm8k_workflow(self) -> None:
self.assertEqual(experiences[1].reward, -0.1)
self.assertEqual(experiences[2].reward, -0.1)
self.assertEqual(experiences[3].reward, 1.1)

def test_math_workflow_with_different_system_prompt(self) -> None:
model = MagicMock()
model.chat.return_value = [
MockResponse("<think> balabalabala 99 </think>\n<answer> 36 </answer>"),
MockResponse("<answer> 36.0 </answer>"),
MockResponse("<answer>Kim's total points are 6 + 30 = 36 </answer>"),
MockResponse("<think> balalaba </think><answer> 35.00 </answer>"),
MockResponse("<think> balabalabala 99 </think>\n \\boxed{36}"),
MockResponse("\\boxed{36.0}"),
MockResponse("Kim's total points are 6 + 30 =\\boxed{36}"),
MockResponse("<think> balalaba </think> \\boxed{35.00}"),
]
taskset_config = get_unittest_dataset_config("countdown")
task = Task(
workflow=MathWorkflow,
format_args=taskset_config.format,
rollout_args=taskset_config.rollout_args,
is_eval=False,
raw_task={
taskset_config.format.system_prompt: PREDEFINED_MATH_SYSTEM_PROMPTS[
"deepseek_like"
],
taskset_config.format.prompt_key: "",
taskset_config.format.response_key: r"36",
},
)
task.format_args.system_prompt = PREDEFINED_MATH_SYSTEM_PROMPTS["deepseek_like"]
workflow = task.to_workflow(model=model)
experiences = workflow.run()
# self.assertEqual(len(experiences), 1)
self.assertEqual(experiences[0].reward, 1.1)
self.assertEqual(experiences[1].reward, 0.9)
self.assertEqual(experiences[2].reward, 0.9)
self.assertEqual(experiences[3].reward, 0.1)
self.assertEqual(experiences[4].reward, 0.9)
self.assertEqual(experiences[5].reward, 0.9)
self.assertEqual(experiences[6].reward, 0.9)
self.assertEqual(experiences[7].reward, -0.1)
task_new = Task(
workflow=MathWorkflow,
format_args=taskset_config.format,
rollout_args=taskset_config.rollout_args,
is_eval=False,
raw_task={
taskset_config.format.system_prompt: PREDEFINED_MATH_SYSTEM_PROMPTS[
"boxed_with_think"
],
taskset_config.format.prompt_key: "",
taskset_config.format.response_key: r"36",
},
)
task_new.format_args.system_prompt = PREDEFINED_MATH_SYSTEM_PROMPTS["boxed_with_think"]
workflow.reset(task_new)
workflow_new = task_new.to_workflow(model=model)
experiences = workflow_new.run()
self.assertEqual(experiences[0].reward, -0.1)
self.assertEqual(experiences[1].reward, -0.1)
self.assertEqual(experiences[2].reward, -0.1)
self.assertEqual(experiences[3].reward, -0.1)
self.assertEqual(experiences[4].reward, 1.0)
self.assertEqual(experiences[5].reward, 0.9)
self.assertEqual(experiences[6].reward, 0.9)
self.assertEqual(experiences[7].reward, 0.0)
task_new2 = Task(
workflow=MathWorkflow,
format_args=taskset_config.format,
rollout_args=taskset_config.rollout_args,
is_eval=False,
raw_task={
taskset_config.format.system_prompt: PREDEFINED_MATH_SYSTEM_PROMPTS[
"boxed_no_think"
],
taskset_config.format.prompt_key: "",
taskset_config.format.response_key: r"36",
},
)
task_new2.format_args.system_prompt = PREDEFINED_MATH_SYSTEM_PROMPTS["boxed_no_think"]
workflow.reset(task_new2)
workflow_new2 = task_new2.to_workflow(model=model)
experiences = workflow_new2.run()
self.assertEqual(experiences[0].reward, -0.1)
self.assertEqual(experiences[1].reward, -0.1)
self.assertEqual(experiences[2].reward, -0.1)
self.assertEqual(experiences[3].reward, -0.1)
self.assertEqual(experiences[4].reward, 1.0)
self.assertEqual(experiences[5].reward, 1.0)
self.assertEqual(experiences[6].reward, 1.0)
self.assertEqual(experiences[7].reward, 0.0)
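For reference, the expectations for the two boxed prompts follow directly from the MathBoxedRewardFn added below in trinity/common/rewards/reward_fn.py. A sketch of the arithmetic (not part of the test):

```python
# "boxed_with_think" (have_think_pattern=True):
#   "<think>...</think>\n \boxed{36}"  -> verify ok (1.0) + think tags ok     ( 0.0) =  1.0
#   "\boxed{36.0}"                      -> verify ok (1.0) + think tags missing (-0.1) =  0.9
#   "<think>...</think> \boxed{35.00}" -> verify fails (0.0) + think tags ok   ( 0.0) =  0.0
#   responses without \boxed{...}       -> no parsable answer                          = -0.1
# "boxed_no_think" (have_think_pattern=False): the -0.1 think penalty is dropped,
# so the three correct boxed answers all score 1.0 and the wrong one scores 0.0.
```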
45 changes: 45 additions & 0 deletions trinity/common/rewards/reward_fn.py
@@ -11,8 +11,10 @@
from trinity.utils.eval_utils import (
evaluate_equation,
extract_solution,
find_boxed_answer,
simple_answer_parser,
validate_equation,
validate_think_pattern,
)
from trinity.utils.log import get_logger
from trinity.utils.registry import Registry
@@ -152,6 +154,49 @@ def __call__( # type: ignore
return accuracy_score + format_score


class MathBoxedRewardFn(RewardFn):
"""Math reward function that parses the boxed answer."""

def __init__(
self,
have_think_pattern: Optional[bool] = True,
):
self.have_think_pattern = have_think_pattern

def __call__( # type: ignore
self,
response: str,
prompt: Optional[str] = None,
truth: Optional[str] = None,
return_dict: Optional[bool] = False,
) -> Union[float, dict]:
answer = find_boxed_answer(response)
if answer is None:
if return_dict:
return {"accuracy": 0.0, "format_score": -0.1}
return -0.1

try:
reward = float(verify(parse(answer), parse(truth)))
except Exception as e:
logger.info(f"verify failed: {e}, answer: {answer}, gold: {truth}")
reward = 0.0

if self.have_think_pattern:
if validate_think_pattern(response):
reward += 0.0
else:
reward -= 0.1

if return_dict:
return {
"accuracy": 1.0 if reward > 0.9 else 0.0,
"format_score": 0.0 if reward >= 0.0 else -0.1,
}
return reward


@REWARD_FUNCTIONS.register_module("countdown_reward")
class CountDownRewardFn(RewardFn):
"""A reward function that rewards for countdown task."""
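A minimal usage sketch of the new MathBoxedRewardFn (illustrative only; it assumes the parse/verify helpers already used in reward_fn.py resolve the comparisons as commented):

```python
from trinity.common.rewards.reward_fn import MathBoxedRewardFn

fn = MathBoxedRewardFn(have_think_pattern=False)

# Scalar reward, e.g. for training rollouts
fn(response="\\boxed{36}", truth="36")            # -> 1.0
fn(response="<answer> 36 </answer>", truth="36")  # -> -0.1 (no \boxed{...} to parse)

# Dict reward for evaluation; the workflow sums the values into the final reward
# and records them as metrics (see BaseModelWorkflow.run later in this PR)
fn(response="\\boxed{36}", truth="36", return_dict=True)  # -> {"accuracy": 1.0, "format_score": 0.0}
```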
5 changes: 4 additions & 1 deletion trinity/common/workflows/__init__.py
@@ -3,13 +3,16 @@
from .envs.alfworld.alfworld_workflow import AlfworldWorkflow
from .envs.sciworld.sciworld_workflow import SciWorldWorkflow
from .envs.webshop.webshop_workflow import WebShopWorkflow
from .workflow import WORKFLOWS, MathWorkflow, SimpleWorkflow, Task
from .math_workflows import MathBasedModelWorkflow, MathWorkflow
from .workflow import WORKFLOWS, BaseModelWorkflow, SimpleWorkflow, Task

__all__ = [
"Task",
"WORKFLOWS",
"SimpleWorkflow",
"BaseModelWorkflow",
"MathWorkflow",
"MathBasedModelWorkflow",
"WebShopWorkflow",
"AlfworldWorkflow",
"SciWorldWorkflow",
101 changes: 101 additions & 0 deletions trinity/common/workflows/math_workflows.py
@@ -0,0 +1,101 @@
# -*- coding: utf-8 -*-
"""We separate the math workflows into this file."""
from functools import partial
from typing import List, Optional

import openai

from trinity.common.models.model import ModelWrapper
from trinity.common.rewards.reward_fn import MathBoxedRewardFn, MathRewardFn
from trinity.common.workflows.workflow import (
WORKFLOWS,
BaseModelWorkflow,
SimpleWorkflow,
Task,
)

PREDEFINED_MATH_SYSTEM_PROMPTS = {
"deepseek_like": """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e.,
<think> reasoning process here </think>
<answer> answer here </answer>.""",
"boxed_with_think": """You are a helpful assistant that solves MATH problems. You should first think about the reasoning process and then provide the user with the answer. Present your reasoning process using the format: <think>\n ...your reasoning process here... </think>\n first. Always include your final answer in \\boxed{} as a closed-form result.""",
"boxed_no_think": """Please reason step by step, and put your final answer within \\boxed{}.""",
}


@WORKFLOWS.register_module("math_workflow")
class MathWorkflow(SimpleWorkflow):
"""A workflow for math tasks"""

def __init__(
self,
model: ModelWrapper,
task: Task,
auxiliary_models: Optional[List[openai.OpenAI]] = None,
Collaborator: We should add a workflow_args: Dict field to the __init__ method of the Workflow interface.
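For example, something along these lines (a sketch only; `workflow_args` does not exist in this PR):

```python
def __init__(
    self,
    model: ModelWrapper,
    task: Task,
    auxiliary_models: Optional[List[openai.OpenAI]] = None,
    workflow_args: Optional[dict] = None,  # hypothetical: per-workflow options, e.g. {"system_prompt": "deepseek_like"}
):
    self.workflow_args = workflow_args or {}
```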

):
self.reset(task)
super().__init__(
model=model,
task=task,
auxiliary_models=auxiliary_models,
)

def reset(self, task: Task):
if task.format_args.system_prompt is None:
Collaborator: The system_prompt here is a bit confusing.

Collaborator (Author): It is basically the same as the old MathWorkflow. Any recommendation on how to improve it?

task.format_args.system_prompt = PREDEFINED_MATH_SYSTEM_PROMPTS["deepseek_like"]
if task.format_args.system_prompt in PREDEFINED_MATH_SYSTEM_PROMPTS.keys():
task.format_args.system_prompt = PREDEFINED_MATH_SYSTEM_PROMPTS[
task.format_args.system_prompt
]

have_boxed_pattern = "boxed{" in task.format_args.system_prompt
if not have_boxed_pattern:
task.reward_fn = MathRewardFn
else:
have_think_pattern = (
"<think>" in task.format_args.system_prompt
and "</think>" in task.format_args.system_prompt
)
task.reward_fn = partial(MathBoxedRewardFn, have_think_pattern=have_think_pattern)

# call the SimpleWorkflow.reset
super().reset(task)


@WORKFLOWS.register_module("math_based_model_workflow")
class MathBasedModelWorkflow(BaseModelWorkflow):
"""A workflow for math tasks, using base model"""

def __init__(
self,
model: ModelWrapper,
task: Task,
auxiliary_models: Optional[List[openai.OpenAI]] = None,
):
self.reset(task)
super().__init__(
model=model,
task=task,
auxiliary_models=auxiliary_models,
)

def reset(self, task: Task):
if task.format_args.system_prompt is None:
task.format_args.system_prompt = PREDEFINED_MATH_SYSTEM_PROMPTS["deepseek_like"]
if task.format_args.system_prompt in PREDEFINED_MATH_SYSTEM_PROMPTS.keys():
task.format_args.system_prompt = PREDEFINED_MATH_SYSTEM_PROMPTS[
task.format_args.system_prompt
]

have_boxed_pattern = "boxed{" in task.format_args.system_prompt
if not have_boxed_pattern:
task.reward_fn = MathRewardFn
else:
have_think_pattern = (
"<think>" in task.format_args.system_prompt
and "</think>" in task.format_args.system_prompt
)
task.reward_fn = partial(MathBoxedRewardFn, have_think_pattern=have_think_pattern)

# call the SimpleWorkflow.reset
super().reset(task)
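For reference, reset() above accepts either a full system prompt or one of the predefined keys. A minimal sketch of the key-based path, with the Task built the same way as in the unit test earlier in this PR (illustrative only):

```python
from unittest.mock import MagicMock

from tests.tools import get_unittest_dataset_config
from trinity.common.workflows import MathWorkflow
from trinity.common.workflows.workflow import Task

taskset_config = get_unittest_dataset_config("countdown")
task = Task(
    workflow=MathWorkflow,
    format_args=taskset_config.format,
    rollout_args=taskset_config.rollout_args,
    is_eval=False,
    raw_task={
        taskset_config.format.prompt_key: "",
        taskset_config.format.response_key: "36",
    },
)
# Passing a key instead of a full prompt: reset() swaps in the predefined text and,
# because that text contains "boxed{" but no <think> tags, selects
# MathBoxedRewardFn with have_think_pattern=False.
task.format_args.system_prompt = "boxed_no_think"
workflow = task.to_workflow(model=MagicMock())
```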
65 changes: 39 additions & 26 deletions trinity/common/workflows/workflow.py
@@ -5,6 +5,7 @@

from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass, field
from functools import partial
from typing import Any, List, Optional, Type, Union

import openai
@@ -13,7 +14,7 @@
from trinity.common.config import FormatConfig, GenerationConfig
from trinity.common.experience import Experience
from trinity.common.models.model import ModelWrapper
from trinity.common.rewards.reward_fn import MathRewardFn, RewardFn
from trinity.common.rewards.reward_fn import RewardFn
from trinity.utils.log import get_logger
from trinity.utils.registry import Registry

@@ -176,6 +177,11 @@ def reset(self, task: Task):
reward_fn = task.reward_fn
if isinstance(reward_fn, type) and issubclass(reward_fn, RewardFn):
self.reward_fn: RewardFn = reward_fn()
elif isinstance(reward_fn, partial):
if isinstance(reward_fn.func, type) and issubclass(reward_fn.func, RewardFn):
self.reward_fn = reward_fn()
else:
raise ValueError("`reward_fn` as partial must wrap a subclass of `RewardFn`")
else:
raise ValueError("`reward_fn` must be a subclass of `RewardFn`")
# Rollout args
@@ -216,30 +222,37 @@ def run(self) -> List[Experience]:
return responses


@WORKFLOWS.register_module("math_workflow")
class MathWorkflow(SimpleWorkflow):
"""A workflow for math tasks as introduced in DeepSeek-R1."""
@WORKFLOWS.register_module("basemodel_workflow")
class BaseModelWorkflow(SimpleWorkflow):
"""A workflow for simple single-round tasks, using a base model."""

def __init__(
self,
model: ModelWrapper,
task: Task,
auxiliary_models: Optional[List[openai.OpenAI]] = None,
):
self.reset(task)
super().__init__(
model=model,
task=task,
auxiliary_models=auxiliary_models,
)
def format_prompt(self):
prompt_text = ""
if self.system_prompt:
prompt_text += self.system_prompt
prompt_text += "\nTask:\n" + self.task_desc + "\nResponse:\n"
return prompt_text

def reset(self, task: Task):
if task.reward_fn is None:
task.reward_fn = MathRewardFn
if task.reward_fn == MathRewardFn and task.format_args.system_prompt is None:
task.format_args.system_prompt = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e.,
<think> reasoning process here </think>
<answer> answer here </answer>.
"""
# call the SimpleWorkflow.reset
super().reset(task)
def run(self) -> List[Experience]:
prompt_text = self.format_prompt()

logger.debug("start generation")
responses = self.model.generate([prompt_text], **self.rollout_args)
for response in responses:
reward = self.reward_fn(
response=response.response_text, # type: ignore [arg-type]
truth=self.truth,
return_dict=self.is_eval,
)
logger.debug(
f"self.task_desc: {self.task_desc}, prompt_text: {prompt_text}, response: {response.response_text}, reward: {reward}"
)
if isinstance(reward, dict):
if response.metrics is None:
response.metrics = {}
response.metrics.update(reward)
reward = sum(reward.values())
response.reward = reward
return responses
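For intuition, the prompt that format_prompt builds for a base (non-chat) model looks roughly like this (question text illustrative):

```python
system_prompt = "Please reason step by step, and put your final answer within \\boxed{}."
task_desc = "Kim scored 6 points in round one and 30 points in round two. What is her total?"

prompt_text = system_prompt + "\nTask:\n" + task_desc + "\nResponse:\n"
# Please reason step by step, and put your final answer within \boxed{}.
# Task:
# Kim scored 6 points in round one and 30 points in round two. What is her total?
# Response:
```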
16 changes: 16 additions & 0 deletions trinity/utils/eval_utils.py
@@ -76,3 +76,19 @@ def evaluate_equation(equation_str):
return result
except Exception as e: # noqa: F841
return None


def validate_think_pattern(text):
"""Validate whether the <think> </think> tags are properly formatted."""
start_tag = "<think>"
end_tag = "</think>"

start_count = text.count(start_tag)
end_count = text.count(end_tag)

if start_count == 1 and end_count == 1:
start_pos = text.find(start_tag)
end_pos = text.find(end_tag)
if start_pos < end_pos:
return True
return False
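A few quick examples of the expected behaviour (illustrative, not part of the diff):

```python
validate_think_pattern("<think> 6 + 30 </think> \\boxed{36}")  # True: one pair of tags, in order
validate_think_pattern("\\boxed{36}")                           # False: tags missing
validate_think_pattern("</think> oops <think>")                 # False: closing tag comes first
validate_think_pattern("<think> a <think> b </think>")          # False: duplicated opening tag
```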