53 changes: 52 additions & 1 deletion tests/explorer/workflow_test.py
@@ -5,7 +5,7 @@
from unittest.mock import MagicMock

from tests.tools import get_unittest_dataset_config
from trinity.common.workflows import MathWorkflow, Workflow
from trinity.common.workflows import MathBoxedWorkflow, MathWorkflow, Workflow
from trinity.common.workflows.workflow import Task


@@ -134,6 +134,57 @@ def test_math_complex_workflow(self) -> None:
self.assertEqual(len(experiences), 1)
self.assertEqual(experiences[0].reward, 0.9)

def test_math_boxed_workflow(self) -> None:
model = MagicMock()
model.chat.return_value = [
MockResponse("<think> balabalabala 99 </think>\n \\boxed{36}"),
MockResponse("answer is \\boxed{36 }"),
MockResponse("Kim's total points are 6 + 30 =\\boxed{36}"),
MockResponse("<think> balalaba </think> \\boxed{35.00}"),
]
taskset_config = get_unittest_dataset_config("countdown")
task = Task(
workflow=MathBoxedWorkflow,
format_args=taskset_config.format,
rollout_args=taskset_config.rollout_args,
workflow_args={
"with_think": False,
"format_score_coef": 0.2,
},
is_eval=False,
raw_task={
taskset_config.format.prompt_key: "",
taskset_config.format.response_key: r"36",
},
)
workflow = task.to_workflow(model=model)
experiences = workflow.run()
self.assertEqual(experiences[0].reward, 1.0)
self.assertEqual(experiences[1].reward, 1.0)
self.assertEqual(experiences[2].reward, 1.0)
self.assertEqual(experiences[3].reward, 0.0)
task_new = Task(
workflow=MathBoxedWorkflow,
format_args=taskset_config.format,
rollout_args=taskset_config.rollout_args,
workflow_args={
"with_think": True,
"format_score_coef": 0.2,
},
is_eval=False,
raw_task={
taskset_config.format.prompt_key: "",
taskset_config.format.response_key: r"36",
},
)
workflow.reset(task_new)
workflow_new = task_new.to_workflow(model=model)
experiences = workflow_new.run()
self.assertEqual(experiences[0].reward, 1.0)
self.assertEqual(experiences[1].reward, 0.8)
self.assertEqual(experiences[2].reward, 0.8)
self.assertEqual(experiences[3].reward, 0.0)

def test_gsm8k_workflow(self) -> None:
model = MagicMock()
model.chat.return_value = [
32 changes: 32 additions & 0 deletions trinity/common/rewards/reward_fn.py
@@ -9,10 +9,12 @@
from math_verify import LatexExtractionConfig, parse, verify

from trinity.utils.eval_utils import (
compute_score,
evaluate_equation,
extract_solution,
simple_answer_parser,
validate_equation,
validate_think_pattern,
)
from trinity.utils.log import get_logger
from trinity.utils.registry import Registry
@@ -195,3 +197,33 @@ def __call__(
return format_score
except Exception as e: # noqa: F841
return format_score


@REWARD_FUNCTIONS.register_module("math_boxed_reward")
class MathBoxedRewardFn(RewardFn):
"""A reward function that rewards for math task."""

def __init__(
self,
) -> None:
pass

def __call__( # type: ignore
self,
response: str,
prompt: Optional[str] = None,
truth: Optional[str] = None,
return_dict: Optional[bool] = False,
with_think: Optional[bool] = False,
format_score_coef: Optional[float] = 0.1,
) -> Union[float, dict]:
accuracy_score = compute_score(response, truth)

# Penalize responses that omit the required <think>...</think> block when with_think is enabled.
format_score = 0.0
if with_think and not validate_think_pattern(response):
format_score = (format_score_coef or 0.1) * -1.0

if return_dict:
return {"accuracy": accuracy_score, "format_score": format_score}

return accuracy_score + format_score
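
For quick reference, a minimal usage sketch of the new reward function; the response strings and expected scores below are illustrative and mirror the unit test at the top of this diff, not additional API surface.

from trinity.common.rewards.reward_fn import MathBoxedRewardFn

reward_fn = MathBoxedRewardFn()

# Correct boxed answer, no <think> requirement -> accuracy score only (1.0 here).
reward_fn(response=r"Kim's total points are 6 + 30 = \boxed{36}", truth="36")

# <think> required but missing -> the format penalty is subtracted (1.0 - 0.2 = 0.8 here).
reward_fn(
    response=r"answer is \boxed{36}",
    truth="36",
    with_think=True,
    format_score_coef=0.2,
)

# return_dict=True (eval mode) -> separate {"accuracy": ..., "format_score": ...} components.
reward_fn(
    response=r"<think> reasoning </think> \boxed{36}",
    truth="36",
    return_dict=True,
    with_think=True,
)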
2 changes: 2 additions & 0 deletions trinity/common/workflows/__init__.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
"""Workflow module"""
from .customized_math_workflows import MathBoxedWorkflow
from .envs.alfworld.alfworld_workflow import AlfworldWorkflow
from .envs.sciworld.sciworld_workflow import SciWorldWorkflow
from .envs.webshop.webshop_workflow import WebShopWorkflow
@@ -14,4 +15,5 @@
"WebShopWorkflow",
"AlfworldWorkflow",
"SciWorldWorkflow",
"MathBoxedWorkflow",
]
92 changes: 92 additions & 0 deletions trinity/common/workflows/customized_math_workflows.py
@@ -0,0 +1,92 @@
# -*- coding: utf-8 -*-
"""We include the customized math workflows in this file."""

from dataclasses import asdict
from typing import List

from trinity.common.experience import Experience
from trinity.common.rewards.reward_fn import MathBoxedRewardFn
from trinity.common.workflows.workflow import WORKFLOWS, SimpleWorkflow, Task
from trinity.utils.log import get_logger

logger = get_logger(__name__)


@WORKFLOWS.register_module("math_boxed_workflow")
class MathBoxedWorkflow(SimpleWorkflow):
"""A workflow for math tasks that give answers in boxed format."""

def reset(self, task: Task):
self.format_args = task.format_args
self.system_prompt = task.format_args.system_prompt
self.reply_prefix = task.format_args.reply_prefix

self.raw_task = task.raw_task
self.task_desc = task.task_desc
self.truth = task.truth

# Rollout args
rollout_args = asdict(task.rollout_args)
self.rollout_args = rollout_args
self.is_eval = task.is_eval

self.workflow_args = task.workflow_args

self.use_base = self.workflow_args.get("use_base", False)
self.with_think = self.workflow_args.get("with_think", False)
self.format_score_coef = self.workflow_args.get("format_score_coef", 0.1)

default_prompt = (
"""Please reason step by step, and put your final answer within \\boxed{}."""
)

default_prompt_with_think = """You are a helpful assistant that solves MATH problems. You should first think about the reasoning process in your mind and then provide the user with the answer. You should present your reasoning process using the format: <think>\n ...your reasoning process here... </think>\n first. You should always include your final answer in \\boxed{} as closed-form results."""

if self.system_prompt is None:
if self.with_think:
self.system_prompt = default_prompt_with_think
else:
self.system_prompt = default_prompt

self.reward_fn = MathBoxedRewardFn()

def format_prompt(self):
prompt_text = ""
if self.system_prompt:
prompt_text += "System:" + self.system_prompt
prompt_text += "\nUser:\n" + self.task_desc + "\nAssistant:\n"
else:
prompt_text += "User:\n" + self.task_desc + "\nAssistant:\n"
return prompt_text

def run(self) -> List[Experience]:
# TODO: Optimize the generate function
if not self.use_base:
messages = self.format_messages()
else:
prompt_text = self.format_prompt()

logger.debug("start chat")
if not self.use_base:
responses = self.model.chat(messages, **self.rollout_args)
else:
responses = self.model.generate([prompt_text], **self.rollout_args)

for response in responses:
reward = self.reward_fn(  # type: ignore [misc]
response=response.response_text, # type: ignore [arg-type]
truth=self.truth,
return_dict=self.is_eval,
with_think=self.with_think,
format_score_coef=self.format_score_coef,
)
logger.debug(
f"self.task_desc: {self.task_desc}, messages: {messages}, response: {response.response_text}, reward: {reward}"
)
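# In eval mode the reward comes back as a dict of components; record them as metrics and sum for the scalar reward.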
if isinstance(reward, dict):
if response.metrics is None:
response.metrics = {}
response.metrics.update(reward)
reward = sum(reward.values())
response.reward = reward
return responses
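
A minimal sketch of driving the new workflow end to end through a Task; it mirrors the unit test earlier in this diff, and the dataset config helper, mocked model, and prompt text are illustrative assumptions rather than a prescribed setup.

from unittest.mock import MagicMock

from tests.tools import get_unittest_dataset_config
from trinity.common.workflows import MathBoxedWorkflow
from trinity.common.workflows.workflow import Task

taskset_config = get_unittest_dataset_config("countdown")
task = Task(
    workflow=MathBoxedWorkflow,
    format_args=taskset_config.format,
    rollout_args=taskset_config.rollout_args,
    workflow_args={"with_think": True, "format_score_coef": 0.2},
    is_eval=False,
    raw_task={
        taskset_config.format.prompt_key: "What is 6 + 30?",  # illustrative prompt
        taskset_config.format.response_key: "36",
    },
)

model = MagicMock()  # stands in for a real rollout model
model.chat.return_value = []  # a real model would return one response object per rollout

workflow = task.to_workflow(model=model)
experiences = workflow.run()  # each experience carries .reward (plus metrics in eval mode)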