diff --git a/tests/explorer/workflow_test.py b/tests/explorer/workflow_test.py
index 8cce2f9e85..586579439f 100644
--- a/tests/explorer/workflow_test.py
+++ b/tests/explorer/workflow_test.py
@@ -6,6 +6,7 @@
from tests.tools import get_unittest_dataset_config
from trinity.common.workflows import MathWorkflow
+from trinity.common.workflows.math_workflows import PREDEFINED_MATH_SYSTEM_PROMPTS
from trinity.common.workflows.workflow import Task
@@ -150,3 +151,92 @@ def test_gsm8k_workflow(self) -> None:
self.assertEqual(experiences[1].reward, -0.1)
self.assertEqual(experiences[2].reward, -0.1)
self.assertEqual(experiences[3].reward, 1.1)
+
+ def test_math_workflow_with_different_system_prompt(self) -> None:
+ model = MagicMock()
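+ # The first four mocked responses answer via <answer> tags (no \boxed{}); the last four wrap the final answer in \boxed{}.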
+ model.chat.return_value = [
+ MockResponse(" balabalabala 99 \n 36 "),
+ MockResponse(" 36.0 "),
+ MockResponse("Kim's total points are 6 + 30 = 36 "),
+ MockResponse(" balalaba 35.00 "),
+ MockResponse(" balabalabala 99 \n \\boxed{36}"),
+ MockResponse("\\boxed{36.0}"),
+ MockResponse("Kim's total points are 6 + 30 =\\boxed{36}"),
+ MockResponse(" balalaba \\boxed{35.00}"),
+ ]
+ taskset_config = get_unittest_dataset_config("countdown")
+ task = Task(
+ workflow=MathWorkflow,
+ format_args=taskset_config.format,
+ rollout_args=taskset_config.rollout_args,
+ is_eval=False,
+ raw_task={
+ taskset_config.format.system_prompt: PREDEFINED_MATH_SYSTEM_PROMPTS[
+ "deepseek_like"
+ ],
+ taskset_config.format.prompt_key: "",
+ taskset_config.format.response_key: r"36",
+ },
+ )
+ task.format_args.system_prompt = PREDEFINED_MATH_SYSTEM_PROMPTS["deepseek_like"]
+ workflow = task.to_workflow(model=model)
+ experiences = workflow.run()
+ self.assertEqual(len(experiences), 8)
+ self.assertEqual(experiences[0].reward, 1.1)
+ self.assertEqual(experiences[1].reward, 0.9)
+ self.assertEqual(experiences[2].reward, 0.9)
+ self.assertEqual(experiences[3].reward, 0.1)
+ self.assertEqual(experiences[4].reward, 0.9)
+ self.assertEqual(experiences[5].reward, 0.9)
+ self.assertEqual(experiences[6].reward, 0.9)
+ self.assertEqual(experiences[7].reward, -0.1)
+ task_new = Task(
+ workflow=MathWorkflow,
+ format_args=taskset_config.format,
+ rollout_args=taskset_config.rollout_args,
+ is_eval=False,
+ raw_task={
+ taskset_config.format.system_prompt: PREDEFINED_MATH_SYSTEM_PROMPTS[
+ "boxed_with_think"
+ ],
+ taskset_config.format.prompt_key: "",
+ taskset_config.format.response_key: r"36",
+ },
+ )
+ task_new.format_args.system_prompt = PREDEFINED_MATH_SYSTEM_PROMPTS["boxed_with_think"]
+ workflow.reset(task_new)
+ workflow_new = task_new.to_workflow(model=model)
+ experiences = workflow_new.run()
+ self.assertEqual(experiences[0].reward, -0.1)
+ self.assertEqual(experiences[1].reward, -0.1)
+ self.assertEqual(experiences[2].reward, -0.1)
+ self.assertEqual(experiences[3].reward, -0.1)
+ self.assertEqual(experiences[4].reward, 1.0)
+ self.assertEqual(experiences[5].reward, 0.9)
+ self.assertEqual(experiences[6].reward, 0.9)
+ self.assertEqual(experiences[7].reward, 0.0)
+ task_new2 = Task(
+ workflow=MathWorkflow,
+ format_args=taskset_config.format,
+ rollout_args=taskset_config.rollout_args,
+ is_eval=False,
+ raw_task={
+ taskset_config.format.system_prompt: PREDEFINED_MATH_SYSTEM_PROMPTS[
+ "boxed_no_think"
+ ],
+ taskset_config.format.prompt_key: "",
+ taskset_config.format.response_key: r"36",
+ },
+ )
+ task_new2.format_args.system_prompt = PREDEFINED_MATH_SYSTEM_PROMPTS["boxed_no_think"]
+ workflow.reset(task_new2)
+ workflow_new2 = task_new2.to_workflow(model=model)
+ experiences = workflow_new2.run()
+ self.assertEqual(experiences[0].reward, -0.1)
+ self.assertEqual(experiences[1].reward, -0.1)
+ self.assertEqual(experiences[2].reward, -0.1)
+ self.assertEqual(experiences[3].reward, -0.1)
+ self.assertEqual(experiences[4].reward, 1.0)
+ self.assertEqual(experiences[5].reward, 1.0)
+ self.assertEqual(experiences[6].reward, 1.0)
+ self.assertEqual(experiences[7].reward, 0.0)
diff --git a/trinity/common/rewards/reward_fn.py b/trinity/common/rewards/reward_fn.py
index 1b5906b8be..9b0c19a9bb 100644
--- a/trinity/common/rewards/reward_fn.py
+++ b/trinity/common/rewards/reward_fn.py
@@ -11,8 +11,10 @@
from trinity.utils.eval_utils import (
evaluate_equation,
extract_solution,
+ find_boxed_answer,
simple_answer_parser,
validate_equation,
+ validate_think_pattern,
)
from trinity.utils.log import get_logger
from trinity.utils.registry import Registry
@@ -152,6 +154,49 @@ def __call__( # type: ignore
return accuracy_score + format_score
+class MathBoxedRewardFn(RewardFn):
+ """Math Reward function that parse the boxed answer"""
+
+ def __init__(
+ self,
+ have_think_pattern: Optional[bool] = True,
+ ):
+ self.have_think_pattern = have_think_pattern
+
+ def __call__( # type: ignore
+ self,
+ response: str,
+ prompt: Optional[str] = None,
+ truth: Optional[str] = None,
+ return_dict: Optional[bool] = False,
+ ) -> Union[float, dict]:
+ answer = find_boxed_answer(response)
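+ # A missing \boxed{} answer is treated as a format failure.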
+ if answer is None:
+ if return_dict:
+ return {"accuracy": 0.0, "format_score": -0.1}
+ return -0.1
+
+ try:
+ reward = float(verify(parse(answer), parse(truth)))
+ except Exception as e:
+ print(f"verify failed: {e}, answer: {answer}, gold: {truth}")
+ logger.info(f"verify failed: {e}, answer: {answer}, gold: {truth}")
+ reward = 0.0
+
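+ # If the prompt asks for <think> reasoning, penalize responses that do not follow the pattern.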
+ if self.have_think_pattern and not validate_think_pattern(response):
+ reward -= 0.1
+
+ if return_dict:
+ return {
+ "accuracy": 1.0 if reward > 0.9 else 0.0,
+ "format_score": 0.0 if reward >= 0.0 else -0.1,
+ }
+ return reward
+
+
@REWARD_FUNCTIONS.register_module("countdown_reward")
class CountDownRewardFn(RewardFn):
"""A reward function that rewards for countdown task."""
diff --git a/trinity/common/workflows/__init__.py b/trinity/common/workflows/__init__.py
index 92bf29a64e..32eaaab3cb 100644
--- a/trinity/common/workflows/__init__.py
+++ b/trinity/common/workflows/__init__.py
@@ -3,13 +3,16 @@
from .envs.alfworld.alfworld_workflow import AlfworldWorkflow
from .envs.sciworld.sciworld_workflow import SciWorldWorkflow
from .envs.webshop.webshop_workflow import WebShopWorkflow
-from .workflow import WORKFLOWS, MathWorkflow, SimpleWorkflow, Task
+from .math_workflows import MathBasedModelWorkflow, MathWorkflow
+from .workflow import WORKFLOWS, BaseModelWorkflow, SimpleWorkflow, Task
__all__ = [
"Task",
"WORKFLOWS",
"SimpleWorkflow",
+ "BaseModelWorkflow",
"MathWorkflow",
+ "MathBasedModelWorkflow",
"WebShopWorkflow",
"AlfworldWorkflow",
"SciWorldWorkflow",
diff --git a/trinity/common/workflows/math_workflows.py b/trinity/common/workflows/math_workflows.py
new file mode 100644
index 0000000000..290d895678
--- /dev/null
+++ b/trinity/common/workflows/math_workflows.py
@@ -0,0 +1,101 @@
+# -*- coding: utf-8 -*-
+"""We include seprate the math workflows in this file."""
+from functools import partial
+from typing import List, Optional
+
+import openai
+
+from trinity.common.models.model import ModelWrapper
+from trinity.common.rewards.reward_fn import MathBoxedRewardFn, MathRewardFn
+from trinity.common.workflows.workflow import (
+ WORKFLOWS,
+ BaseModelWorkflow,
+ SimpleWorkflow,
+ Task,
+)
+
+PREDEFINED_MATH_SYSTEM_PROMPTS = {
+ "deepseek_like": """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within and tags, respectively, i.e.,
+ reasoning process here
+ answer here .""",
+ "boxed_with_think": """You are a helpful assistant that solves MATH problems. You should first thinks about the reasoning process in mind and then provides the user with the answer. You should present your reasoning process using the format: \n ...your reasoning process here... \n first. You should always include your final answer in \\boxed{} as closed-form results.""",
+ "boxed_no_think": """Please reason step by step, and put your final answer within \\boxed{}.""",
+}
+
+
+@WORKFLOWS.register_module("math_workflow")
+class MathWorkflow(SimpleWorkflow):
+ """A workflow for math tasks"""
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List[openai.OpenAI]] = None,
+ ):
+ self.reset(task)
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+
+ def reset(self, task: Task):
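+ # system_prompt may name a predefined prompt (a key of PREDEFINED_MATH_SYSTEM_PROMPTS) or be a literal prompt; it defaults to "deepseek_like".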
+ if task.format_args.system_prompt is None:
+ task.format_args.system_prompt = PREDEFINED_MATH_SYSTEM_PROMPTS["deepseek_like"]
+ if task.format_args.system_prompt in PREDEFINED_MATH_SYSTEM_PROMPTS:
+ task.format_args.system_prompt = PREDEFINED_MATH_SYSTEM_PROMPTS[
+ task.format_args.system_prompt
+ ]
+
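+ # Pick the reward function from the prompt: \boxed{}-style prompts use MathBoxedRewardFn, others fall back to MathRewardFn.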
+ have_boxed_pattern = "boxed{" in task.format_args.system_prompt
+ if not have_boxed_pattern:
+ task.reward_fn = MathRewardFn
+ else:
+ have_think_pattern = (
+ "" in task.format_args.system_prompt
+ and "" in task.format_args.system_prompt
+ )
+ task.reward_fn = partial(MathBoxedRewardFn, have_think_pattern=have_think_pattern)
+
+ # call the SimpleWorkflow.reset
+ super().reset(task)
+
+
+@WORKFLOWS.register_module("math_based_model_workflow")
+class MathBasedModelWorkflow(BaseModelWorkflow):
+ """A workflow for math tasks, using base model"""
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List[openai.OpenAI]] = None,
+ ):
+ self.reset(task)
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+
+ def reset(self, task: Task):
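+ # Mirrors MathWorkflow.reset: resolve the system prompt and pick the matching reward function.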
+ if task.format_args.system_prompt is None:
+ task.format_args.system_prompt = PREDEFINED_MATH_SYSTEM_PROMPTS["deepseek_like"]
+ if task.format_args.system_prompt in PREDEFINED_MATH_SYSTEM_PROMPTS:
+ task.format_args.system_prompt = PREDEFINED_MATH_SYSTEM_PROMPTS[
+ task.format_args.system_prompt
+ ]
+
+ have_boxed_pattern = "boxed{" in task.format_args.system_prompt
+ if not have_boxed_pattern:
+ task.reward_fn = MathRewardFn
+ else:
+ have_think_pattern = (
+ "" in task.format_args.system_prompt
+ and "" in task.format_args.system_prompt
+ )
+ task.reward_fn = partial(MathBoxedRewardFn, have_think_pattern=have_think_pattern)
+
+ # call the parent reset (inherited from SimpleWorkflow)
+ super().reset(task)
diff --git a/trinity/common/workflows/workflow.py b/trinity/common/workflows/workflow.py
index 9786bd6b77..620b42bc53 100644
--- a/trinity/common/workflows/workflow.py
+++ b/trinity/common/workflows/workflow.py
@@ -5,6 +5,7 @@
from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass, field
+from functools import partial
from typing import Any, List, Optional, Type, Union
import openai
@@ -13,7 +14,7 @@
from trinity.common.config import FormatConfig, GenerationConfig
from trinity.common.experience import Experience
from trinity.common.models.model import ModelWrapper
-from trinity.common.rewards.reward_fn import MathRewardFn, RewardFn
+from trinity.common.rewards.reward_fn import RewardFn
from trinity.utils.log import get_logger
from trinity.utils.registry import Registry
@@ -176,6 +177,11 @@ def reset(self, task: Task):
reward_fn = task.reward_fn
if isinstance(reward_fn, type) and issubclass(reward_fn, RewardFn):
self.reward_fn: RewardFn = reward_fn()
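+ # Also accept a functools.partial that wraps a RewardFn subclass (e.g. to pre-bind have_think_pattern).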
+ elif isinstance(reward_fn, partial):
+ if isinstance(reward_fn.func, type) and issubclass(reward_fn.func, RewardFn):
+ self.reward_fn = reward_fn()
+ else:
+ raise ValueError("`reward_fn` as partial must wrap a subclass of `RewardFn`")
else:
raise ValueError("`reward_fn` must be a subclass of `RewardFn`")
# Rollout args
@@ -216,30 +222,37 @@ def run(self) -> List[Experience]:
return responses
-@WORKFLOWS.register_module("math_workflow")
-class MathWorkflow(SimpleWorkflow):
- """A workflow for math tasks as introduced in DeepSeek-R1."""
+@WORKFLOWS.register_module("basemodel_workflow")
+class BaseModelWorkflow(SimpleWorkflow):
+ """A workflow for simple single-round task, using base model"""
- def __init__(
- self,
- model: ModelWrapper,
- task: Task,
- auxiliary_models: Optional[List[openai.OpenAI]] = None,
- ):
- self.reset(task)
- super().__init__(
- model=model,
- task=task,
- auxiliary_models=auxiliary_models,
- )
+ def format_prompt(self):
+ prompt_text = ""
+ if self.system_prompt:
+ prompt_text += self.system_prompt
+ prompt_text += "\nTask:\n" + self.task_desc + "\nResponse:\n"
+ return prompt_text
- def reset(self, task: Task):
- if task.reward_fn is None:
- task.reward_fn = MathRewardFn
- if task.reward_fn == MathRewardFn and task.format_args.system_prompt is None:
- task.format_args.system_prompt = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e.,
- <think> reasoning process here </think>
- <answer> answer here </answer>.
-"""
- # call the SimpleWorkflow.reset
- super().reset(task)
+ def run(self) -> List[Experience]:
+ prompt_text = self.format_prompt()
+
+ logger.debug("start generation")
+ responses = self.model.generate([prompt_text], **self.rollout_args)
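+ # Score each sampled response; in eval mode (return_dict=True) the reward function returns a metrics dict instead of a float.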
+ for response in responses:
+ reward = self.reward_fn(
+ response=response.response_text, # type: ignore [arg-type]
+ truth=self.truth,
+ return_dict=self.is_eval,
+ )
+ logger.debug(
+ f"self.task_desc: {self.task_desc}, prompt_text: {prompt_text}, response: {response.response_text}, reward: {reward}"
+ )
+ if isinstance(reward, dict):
+ if response.metrics is None:
+ response.metrics = {}
+ response.metrics.update(reward)
+ reward = sum(reward.values())
+ response.reward = reward
+ return responses
diff --git a/trinity/utils/eval_utils.py b/trinity/utils/eval_utils.py
index e3aa216eda..162603b036 100644
--- a/trinity/utils/eval_utils.py
+++ b/trinity/utils/eval_utils.py
@@ -76,3 +76,19 @@ def evaluate_equation(equation_str):
return result
except Exception as e: # noqa: F841
return None
+
+
+def validate_think_pattern(text):
+ """Validate whether the tag is properly formatted."""
+ start_tag = ""
+ end_tag = ""
+
+ start_count = text.count(start_tag)
+ end_count = text.count(end_tag)
+
+ if start_count == 1 and end_count == 1:
+ start_pos = text.find(start_tag)
+ end_pos = text.find(end_tag)
+ if start_pos < end_pos:
+ return True
+ return False