3 changes: 2 additions & 1 deletion pyproject.toml
@@ -41,6 +41,7 @@ dependencies = [
"openai",
"jsonlines",
"sortedcontainers",
"word2number",
]

[project.scripts]
@@ -66,7 +67,7 @@ dev = [
"pytest>=8.0.0",
"pytest-json-ctrf",
"parameterized",
"matplotlib"
"matplotlib",
]

doc = [
31 changes: 31 additions & 0 deletions tests/explorer/workflow_test.py
@@ -12,6 +12,7 @@
from trinity.common.rewards import RMGalleryFn
from trinity.common.workflows import (
MathBoxedWorkflow,
MathEvalWorkflow,
MathRMWorkflow,
MathWorkflow,
Workflow,
@@ -274,6 +275,36 @@ def test_rm_gallery_workflow(self) -> None:
self.assertEqual(experiences[2].reward, 1.0)
self.assertEqual(experiences[3].reward, 0.0)

def test_math_eval_workflow(self) -> None:
model = MagicMock()
model.chat.return_value = [
MockResponse("My step-by-step reasoning leads to the answer \boxed{36}"),
MockResponse("Here is the answer of \boxed{36.0}"),
MockResponse("I made a mistake, the answer is \boxed{42}"),
MockResponse("The answer is 36, but I forgot the box."),
]

taskset_config = get_unittest_dataset_config("countdown")
task = Task(
workflow=MathEvalWorkflow,
is_eval=True,
format_args=taskset_config.format,
raw_task={
taskset_config.format.prompt_key: "",
taskset_config.format.response_key: "36",
},
)

workflow = task.to_workflow(model=model)
experiences = workflow.run()
self.assertEqual(len(experiences), 4)
expected_accuracies = [1.0, 1.0, 0.0, 0.0]
for i, (exp, expected_acc) in enumerate(zip(experiences, expected_accuracies)):
with self.subTest(f"Response {i}"):
self.assertEqual(exp.reward, 0.0)
assert exp.metrics is not None, f"Metrics for response {i} should not be None"
self.assertEqual(exp.metrics["accuracy"], expected_acc)

def test_workflow_resettable(self) -> None:
model = MagicMock()
json_task = Task(
1 change: 1 addition & 0 deletions tests/template/config.yaml
@@ -34,6 +34,7 @@ buffer:
path: 'placeholder'
split: 'train'
default_workflow_type: ''
default_eval_workflow_type: ''
default_reward_fn_type: ''
explorer:
eval_interval: 100
1 change: 1 addition & 0 deletions tests/test_data/template.yaml
@@ -11,6 +11,7 @@ buffer:
storage_type: file
path: ''
default_workflow_type: ''
default_eval_workflow_type: ''
default_reward_fn_type: ''
explorer:
runner_num: 8
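As a usage sketch only: the new key takes the registry name of an evaluation workflow. The only name introduced by this PR is 'math_eval_workflow' (registered in trinity/common/workflows/eval_workflow.py below); the remaining values are the same placeholders used in the two templates above, and the nesting follows those templates.

path: 'placeholder'
split: 'train'
default_workflow_type: ''
default_eval_workflow_type: 'math_eval_workflow'
default_reward_fn_type: ''

If the key is left empty, evaluation tasks fall back to the per-sample workflow or default_workflow_type, as implemented in the trinity/buffer/reader/file_reader.py change below.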
97 changes: 97 additions & 0 deletions tests/utils/eval_utils_test.py
@@ -0,0 +1,97 @@
# -*- coding: utf-8 -*-
"""Test for the evaluation utils module."""

import unittest

from trinity.utils.eval_utils import is_equiv
from trinity.utils.math_eval_utils import extract_answer, verify_math_answer


class TestMathEvalUtils(unittest.TestCase):
def test_extract_answer(self):
test_cases = [
("The answer is \\boxed{42}", "42", "Basic boxed extraction"),
("The result is \\boxed{\\frac{1}{2}}", "\\frac{1}{2}", "Boxed with LaTeX"),
("Therefore, the final answer is 100.", "100", "English 'answer is' extraction"),
("My final answer is: 3.14", "3.14", "English 'answer is' with colon"),
("所以,答案是x^2", "x^2", "Chinese 'answer is' extraction"),
(
"The cost is 10 dollars and the profit is 20 dollars.",
"20",
"Extract the last number",
),
(
"There are 1,000 apples and 2,000 oranges.",
"2000",
"Extract the last number with commas",
),
("The probability is 0.75.", "0.75", "Extract the last decimal"),
("This sentence has no answer.", None, "No answer case"),
("The box is empty \\boxed{}", None, "Empty boxed"),
(12345, None, "Input is not a string"),
]

for i, (input_str, expected_output, description) in enumerate(test_cases):
with self.subTest(f"Case {i+1}: {description}"):
actual_output = extract_answer(input_str)
self.assertEqual(
actual_output,
expected_output,
f"Failed on input: '{input_str}'\nExpected: '{expected_output}', Got: '{actual_output}'",
)

def test_verify_math_answer(self):
test_cases = [
("The answer is \\boxed{42}", "42", True, "Simple integer equality"),
("The result is 1,000.", "1000", True, "Number with commas"),
("The answer is -50.", "-50", True, "Negative number equality"),
("The solution is 5", "x=5", True, "Equivalence of value and equation"),
("The answer is \\boxed{42}", "43", False, "Simple numerical inequality"),
("The answer is \\boxed{x+1}", "x-1", False, "Symbolic expression inequality"),
(
"The matrix is \\boxed{\\begin{pmatrix}1 & 1 \\\\ 0 & 1\\end{pmatrix}}",
"\\begin{pmatrix}1&0\\\\0&1\\end{pmatrix}",
False,
"Matrix inequality",
),
("The speed is 50 km/h", "50", True, "Judgment after stripping units"),
]

for i, (response, ground_truth, expected_correct, description) in enumerate(test_cases):
with self.subTest(f"Case {i+1}: {description}"):
accuracy, details = verify_math_answer(response, ground_truth)
is_correct = accuracy == 1.0
self.assertEqual(
is_correct,
expected_correct,
f"Failed on response: '{response}' with truth: '{ground_truth}'\n"
f"Expected correct: {expected_correct}, Got: {is_correct}\nDetails: {details}",
)


class TestEvalUtils(unittest.TestCase):
def test_is_equiv(self):
test_cases = [
# str1, str2, expected_output, description
(" 123 ", "123", True, "Equivalence with whitespace"),
("50%", "50", True, "Equivalence with percentage sign"),
("$50", "50", True, "Equivalence with dollar sign"),
("hello", "world", False, "Basic inequality"),
("123", "1234", False, "Numerical inequality"),
(None, None, True, "Both inputs are None"),
("Some string", None, False, "One input is None (str1)"),
(None, "Some string", False, "One input is None (str2)"),
]

for i, (str1, str2, expected_output, description) in enumerate(test_cases):
with self.subTest(f"Case {i+1}: {description}"):
actual_output = is_equiv(str1, str2)
self.assertEqual(
actual_output,
expected_output,
f"Failed on inputs: ('{str1}', '{str2}')\nExpected: {expected_output}, Got: {actual_output}",
)


if __name__ == "__main__":
unittest.main()
16 changes: 11 additions & 5 deletions trinity/buffer/reader/file_reader.py
@@ -288,6 +288,9 @@ def __init__(self, meta: StorageConfig, config: BufferConfig):

self.task_type = meta.task_type
self.default_workflow_cls = WORKFLOWS.get(meta.default_workflow_type) # type: ignore
self.default_eval_workflow_cls = None
if getattr(meta, "default_eval_workflow_type", None):
self.default_eval_workflow_cls = WORKFLOWS.get(meta.default_eval_workflow_type)
self.default_reward_fn_cls = REWARD_FUNCTIONS.get(meta.default_reward_fn_type) # type: ignore

def read(
@@ -297,11 +300,14 @@ def read(
tasks = []
samples = self.dataset.read_batch(batch_size)
for sample in samples:
workflow_class = (
WORKFLOWS.get(sample[self.workflow_key])
if self.workflow_key in sample
else self.default_workflow_cls
)
if self.task_type == TaskType.EVAL and self.default_eval_workflow_cls:
workflow_class = self.default_eval_workflow_cls
else:
workflow_class = (
WORKFLOWS.get(sample[self.workflow_key])
if self.workflow_key in sample
else self.default_workflow_cls
)
reward_fn = (
REWARD_FUNCTIONS.get(sample[self.reward_fn_key])
if self.reward_fn_key in sample
10 changes: 10 additions & 0 deletions trinity/common/config.py
@@ -98,6 +98,7 @@ class StorageConfig:

# used for rollout tasks
default_workflow_type: Optional[str] = None
default_eval_workflow_type: Optional[str] = None
default_reward_fn_type: Optional[str] = None
rollout_args: GenerationConfig = field(default_factory=GenerationConfig)
workflow_args: dict = field(default_factory=dict)
@@ -276,6 +277,7 @@ class ExplorerInput:
eval_tasksets: List[StorageConfig] = field(default_factory=list)
# The following args provide default values for the corresponding args in `taskset` and `eval_tasksets`
default_workflow_type: Optional[str] = None
default_eval_workflow_type: Optional[str] = None
default_reward_fn_type: Optional[str] = None
system_prompt: Optional[str] = None
reply_prefix: Optional[str] = None
@@ -479,6 +481,10 @@ def _check_buffer(self) -> None: # noqa: C901
self.buffer.explorer_input.taskset.default_workflow_type = (
self.buffer.explorer_input.default_workflow_type
)
if self.buffer.explorer_input.taskset.default_eval_workflow_type is None:
self.buffer.explorer_input.taskset.default_eval_workflow_type = (
self.buffer.explorer_input.default_eval_workflow_type
)
if self.buffer.explorer_input.taskset.default_reward_fn_type is None:
self.buffer.explorer_input.taskset.default_reward_fn_type = (
self.buffer.explorer_input.default_reward_fn_type
@@ -504,6 +510,10 @@ def _check_buffer(self) -> None: # noqa: C901
dataset.name = f"eval_taskset_{idx}"
if dataset.default_workflow_type is None:
dataset.default_workflow_type = self.buffer.explorer_input.default_workflow_type
if dataset.default_eval_workflow_type is None:
dataset.default_eval_workflow_type = (
self.buffer.explorer_input.default_eval_workflow_type
)
if dataset.default_reward_fn_type is None:
dataset.default_reward_fn_type = self.buffer.explorer_input.default_reward_fn_type
if dataset.format.system_prompt is None:
2 changes: 2 additions & 0 deletions trinity/common/workflows/__init__.py
@@ -5,6 +5,7 @@
from .envs.alfworld.alfworld_workflow import AlfworldWorkflow
from .envs.sciworld.sciworld_workflow import SciWorldWorkflow
from .envs.webshop.webshop_workflow import WebShopWorkflow
from .eval_workflow import MathEvalWorkflow
from .math_rm_workflow import MathRMWorkflow
from .workflow import WORKFLOWS, MathWorkflow, SimpleWorkflow, Task, Workflow

@@ -20,4 +21,5 @@
"MathBoxedWorkflow",
"MathRMWorkflow",
"ToolCallWorkflow",
"MathEvalWorkflow",
]
86 changes: 86 additions & 0 deletions trinity/common/workflows/eval_workflow.py
@@ -0,0 +1,86 @@
# -*- coding: utf-8 -*-
"""Evaluation Workflow Class"""

from dataclasses import asdict
from typing import List, Optional

import openai

from trinity.common.config import GenerationConfig
from trinity.common.experience import Experience
from trinity.common.models.model import ModelWrapper
from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
from trinity.utils.log import get_logger
from trinity.utils.math_eval_utils import verify_math_answer

logger = get_logger(__name__)


@WORKFLOWS.register_module("math_eval_workflow")
class MathEvalWorkflow(Workflow):
"""
A workflow for standard math evaluation.

The evaluation standard and prompting style follow the Qwen2.5-Math
model's evaluation methodology. For more details on their approach, see:
https://github.com/QwenLM/Qwen2.5-Math
"""

def __init__(
self,
*,
task: Task,
model: ModelWrapper,
auxiliary_models: Optional[List[openai.OpenAI]] = None,
):
super().__init__(
task=task,
model=model,
auxiliary_models=auxiliary_models,
)

self.raw_task = task.raw_task
self.truth = task.truth

# TODO: customize the config in the yaml
self.eval_gen_args = asdict(GenerationConfig(temperature=0.6, top_p=0.8, logprobs=0, n=1))

@property
def resettable(self):
return False

def format_messages(self):
"""Format message for the evaluation of qwen_boxed type."""
if not self.raw_task or "question" not in self.raw_task:
raise ValueError("Raw task data must contain a 'question' field for MathEvalWorkflow.")

problem_input = self.raw_task["question"]

system_prompt = "You are a helpful assistant."
user_prompt = f"{problem_input}\nPlease reason step by step, and put your final answer within \\boxed{{}}."

messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
]
return messages

def run(self) -> List[Experience]:
messages = self.format_messages()

responses: List[Experience] = self.model.chat(messages, **self.eval_gen_args)

for response in responses:
if response.response_text is None or self.truth is None:
continue

accuracy, _ = verify_math_answer(
response_text=response.response_text, ground_truth=self.truth
)

acc_metrics = {"accuracy": accuracy}
if response.metrics is None:
response.metrics = {}
response.metrics.update(acc_metrics)

return responses
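
To make the data flow concrete, the sketch below rebuilds the same chat payload as format_messages() and scores a hand-written completion with verify_math_answer(), which is what run() does for each sampled response. The question text and the completion are made-up examples, not fixtures from this PR.

from trinity.utils.math_eval_utils import verify_math_answer

question = "What is 6 * 6?"  # made-up example question
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {
        "role": "user",
        "content": f"{question}\nPlease reason step by step, and put your final answer within \\boxed{{}}.",
    },
]

# In the workflow, completions come back from self.model.chat(messages, **self.eval_gen_args);
# here a hand-written completion is scored against the ground truth instead.
completion = "6 * 6 = 36, so the final answer is \\boxed{36}."
accuracy, _ = verify_math_answer(response_text=completion, ground_truth="36")
print(accuracy)  # 1.0 when the extracted boxed answer matches the ground truth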
6 changes: 6 additions & 0 deletions trinity/explorer/explorer.py
@@ -276,6 +276,12 @@ def eval(self):
self.logger.warning("No evaluation data samples. Skip evaluation.")
return
self.logger.info(f"Evaluation at step {self.explore_step_num} started.")

if self.config.buffer.explorer_input.default_eval_workflow_type:
self.logger.info(
f"Use '{self.config.buffer.explorer_input.default_eval_workflow_type}' for evaluation."
)

for eval_taskset_config in self.config.buffer.explorer_input.eval_tasksets:
self.logger.info(
f"Evaluation on {eval_taskset_config.name} at step {self.explore_step_num} started."
9 changes: 7 additions & 2 deletions trinity/manager/config_manager.py
@@ -142,7 +142,9 @@ def beginner_mode(self):
if st.session_state["sft_warmup_steps"] > 0:
self.get_configs("sft_warmup_dataset_args")

self.get_configs("default_workflow_type", "default_reward_fn_type")
self.get_configs(
"default_workflow_type", "default_eval_workflow_type", "default_reward_fn_type"
)

self.get_configs(
"actor_ppo_micro_batch_size_per_gpu",
@@ -166,7 +168,9 @@ def _expert_model_part(self):
def _expert_buffer_part(self):
self.get_configs("total_epochs", "train_batch_size")

self.get_configs("default_workflow_type", "default_reward_fn_type")
self.get_configs(
"default_workflow_type", "default_eval_workflow_type", "default_reward_fn_type"
)
self.get_configs("system_prompt")
self.get_configs("reply_prefix")

@@ -544,6 +548,7 @@ def _gen_buffer_config(self):
},
"eval_tasksets": [],
"default_workflow_type": st.session_state["default_workflow_type"],
"default_eval_workflow_type": st.session_state["default_eval_workflow_type"],
"default_reward_fn_type": st.session_state["default_reward_fn_type"],
"system_prompt": st.session_state["system_prompt"],
"reply_prefix": st.session_state["reply_prefix"],