Commit fa53d98

Add the implementation of evaluation workflow
1 parent: 8d6daea

File tree

13 files changed: +805, -138 lines changed


pyproject.toml

Lines changed: 2 additions & 1 deletion
@@ -66,7 +66,8 @@ dev = [
     "pytest>=8.0.0",
     "pytest-json-ctrf",
     "parameterized",
-    "matplotlib"
+    "matplotlib",
+    "word2number",
 ]
 
 doc = [
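
The new word2number dependency is only declared here; its role is not shown in this commit, but a plausible use (an assumption, not the repository's actual code) is normalizing spelled-out numbers before numeric comparison in the evaluation utilities. A minimal sketch:

# Hedged sketch of why a test suite might want word2number;
# `normalize_number` is a hypothetical helper, not part of this commit.
from word2number import w2n


def normalize_number(text: str):
    """Return a float from either digits ("36") or words ("thirty six")."""
    try:
        return float(text.replace(",", ""))
    except ValueError:
        pass
    try:
        return float(w2n.word_to_num(text))  # e.g. "thirty six" -> 36
    except ValueError:
        return None


assert normalize_number("36") == normalize_number("thirty six") == 36.0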

tests/explorer/workflow_test.py

Lines changed: 31 additions & 0 deletions
@@ -11,6 +11,7 @@
 from trinity.common.rewards import RMGalleryFn
 from trinity.common.workflows import (
     MathBoxedWorkflow,
+    MathEvalWorkflow,
     MathRMWorkflow,
     MathWorkflow,
     Workflow,
@@ -272,6 +273,36 @@ def test_rm_gallery_workflow(self) -> None:
         self.assertEqual(experiences[2].reward, 1.0)
         self.assertEqual(experiences[3].reward, 0.0)
 
+    def test_math_eval_workflow(self) -> None:
+        model = MagicMock()
+        model.chat.return_value = [
+            MockResponse("My step-by-step reasoning leads to the answer \\boxed{36}"),
+            MockResponse("Here is the answer of \\boxed{36.0}"),
+            MockResponse("I made a mistake, the answer is \\boxed{42}"),
+            MockResponse("The answer is 36, but I forgot the box."),
+        ]
+
+        taskset_config = get_unittest_dataset_config("countdown")
+        task = Task(
+            workflow=MathEvalWorkflow,
+            is_eval=True,
+            format_args=taskset_config.format,
+            raw_task={
+                taskset_config.format.prompt_key: "",
+                taskset_config.format.response_key: "36",
+            },
+        )
+
+        workflow = task.to_workflow(model=model)
+        experiences = workflow.run()
+        self.assertEqual(len(experiences), 4)
+        expected_accuracies = [1.0, 1.0, 0.0, 0.0]
+        for i, (exp, expected_acc) in enumerate(zip(experiences, expected_accuracies)):
+            with self.subTest(f"Response {i}"):
+                self.assertEqual(exp.reward, 0.0)
+                assert exp.metrics is not None, f"Metrics for response {i} should not be None"
+                self.assertEqual(exp.metrics["accuracy"], expected_acc)
+
     def test_workflow_resettable(self) -> None:
         model = MagicMock()
         json_task = Task(

tests/template/config.yaml

Lines changed: 1 addition & 0 deletions
@@ -34,6 +34,7 @@ buffer:
     path: 'placeholder'
     split: 'train'
     default_workflow_type: ''
+    default_eval_type: ''
     default_reward_fn_type: ''
 explorer:
   eval_interval: 100

tests/test_data/template.yaml

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@ buffer:
   storage_type: file
   path: ''
   default_workflow_type: ''
+  default_eval_type: ''
   default_reward_fn_type: ''
 explorer:
   runner_num: 8

tests/utils/eval_utils_test.py

Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
+# -*- coding: utf-8 -*-
+"""Test for the evaluation utils module."""
+
+import unittest
+
+from trinity.utils.eval_utils import is_equiv
+from trinity.utils.math_eval_utils import extract_answer, verify_math_answer
+
+
+class TestMathEvalUtils(unittest.TestCase):
+    def test_extract_answer(self):
+        test_cases = [
+            ("The answer is \\boxed{42}", "42", "Basic boxed extraction"),
+            ("The result is \\boxed{\\frac{1}{2}}", "\\frac{1}{2}", "Boxed with LaTeX"),
+            ("Therefore, the final answer is 100.", "100", "English 'answer is' extraction"),
+            ("My final answer is: 3.14", "3.14", "English 'answer is' with colon"),
+            ("所以,答案是x^2", "x^2", "Chinese 'answer is' extraction"),
+            (
+                "The cost is 10 dollars and the profit is 20 dollars.",
+                "20",
+                "Extract the last number",
+            ),
+            (
+                "There are 1,000 apples and 2,000 oranges.",
+                "2000",
+                "Extract the last number with commas",
+            ),
+            ("The probability is 0.75.", "0.75", "Extract the last decimal"),
+            ("This sentence has no answer.", None, "No answer case"),
+            ("The box is empty \\boxed{}", None, "Empty boxed"),
+            (12345, None, "Input is not a string"),
+        ]
+
+        for i, (input_str, expected_output, description) in enumerate(test_cases):
+            with self.subTest(f"Case {i+1}: {description}"):
+                actual_output = extract_answer(input_str)
+                self.assertEqual(
+                    actual_output,
+                    expected_output,
+                    f"Failed on input: '{input_str}'\nExpected: '{expected_output}', Got: '{actual_output}'",
+                )
+
+    def test_verify_math_answer(self):
+        test_cases = [
+            ("The answer is \\boxed{42}", "42", True, "Simple integer equality"),
+            ("The result is 1,000.", "1000", True, "Number with commas"),
+            ("The answer is -50.", "-50", True, "Negative number equality"),
+            ("The solution is 5", "x=5", True, "Equivalence of value and equation"),
+            ("The answer is \\boxed{42}", "43", False, "Simple numerical inequality"),
+            ("The answer is \\boxed{x+1}", "x-1", False, "Symbolic expression inequality"),
+            (
+                "The matrix is \\boxed{\\begin{pmatrix}1 & 1 \\\\ 0 & 1\\end{pmatrix}}",
+                "\\begin{pmatrix}1&0\\\\0&1\\end{pmatrix}",
+                False,
+                "Matrix inequality",
+            ),
+            ("The speed is 50 km/h", "50", True, "Judgment after stripping units"),
+        ]
+
+        for i, (response, ground_truth, expected_correct, description) in enumerate(test_cases):
+            with self.subTest(f"Case {i+1}: {description}"):
+                accuracy, details = verify_math_answer(response, ground_truth)
+                is_correct = accuracy == 1.0
+                self.assertEqual(
+                    is_correct,
+                    expected_correct,
+                    f"Failed on response: '{response}' with truth: '{ground_truth}'\n"
+                    f"Expected correct: {expected_correct}, Got: {is_correct}\nDetails: {details}",
+                )
+
+
+class TestEvalUtils(unittest.TestCase):
+    def test_is_equiv(self):
+        test_cases = [
+            # str1, str2, expected_output, description
+            (" 123 ", "123", True, "Equivalence with whitespace"),
+            ("50%", "50", True, "Equivalence with percentage sign"),
+            ("$50", "50", True, "Equivalence with dollar sign"),
+            ("hello", "world", False, "Basic inequality"),
+            ("123", "1234", False, "Numerical inequality"),
+            (None, None, True, "Both inputs are None"),
+            ("Some string", None, False, "One input is None (str1)"),
+            (None, "Some string", False, "One input is None (str2)"),
+        ]
+
+        for i, (str1, str2, expected_output, description) in enumerate(test_cases):
+            with self.subTest(f"Case {i+1}: {description}"):
+                actual_output = is_equiv(str1, str2)
+                self.assertEqual(
+                    actual_output,
+                    expected_output,
+                    f"Failed on inputs: ('{str1}', '{str2}')\nExpected: {expected_output}, Got: {actual_output}",
+                )
+
+
+if __name__ == "__main__":
+    unittest.main()
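
Taken together, these cases pin down an extraction order: \boxed{} content first, then an "answer is" / "答案是" clause, then the last number with thousands separators stripped, else None. The sketch below re-implements that order purely for illustration; it is not the code in trinity.utils.math_eval_utils.

# Illustrative re-implementation of the behaviour the test cases describe;
# the real extract_answer lives in trinity/utils/math_eval_utils.py.
import re
from typing import Optional


def sketch_extract_answer(text) -> Optional[str]:
    if not isinstance(text, str):
        return None
    # 1. Content of the last \boxed{...}, with brace matching for nested LaTeX.
    idx = text.rfind("\\boxed{")
    if idx != -1:
        depth, start = 0, idx + len("\\boxed{")
        for i in range(start, len(text)):
            if text[i] == "{":
                depth += 1
            elif text[i] == "}":
                if depth == 0:
                    boxed = text[start:i]
                    if boxed:
                        return boxed  # an empty \boxed{} falls through
                    break
                depth -= 1
    # 2. An explicit "answer is ..." / "答案是..." clause, trailing period stripped.
    m = re.search(r"(?:answer is:?|答案是)\s*(.+)", text, flags=re.IGNORECASE)
    if m:
        return m.group(1).strip().rstrip(".。") or None
    # 3. The last number in the text, dropping thousands separators.
    numbers = re.findall(r"-?\d[\d,]*(?:\.\d+)?", text)
    if numbers:
        return numbers[-1].replace(",", "")
    return None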

trinity/buffer/reader/file_reader.py

Lines changed: 11 additions & 5 deletions
@@ -287,6 +287,9 @@ def __init__(self, meta: StorageConfig, config: BufferConfig):
 
         self.task_type = meta.task_type
         self.default_workflow_cls = WORKFLOWS.get(meta.default_workflow_type)  # type: ignore
+        self.default_eval_workflow_cls = None
+        if getattr(meta, "default_eval_type", None):
+            self.default_eval_workflow_cls = WORKFLOWS.get(meta.default_eval_type)
         self.default_reward_fn_cls = REWARD_FUNCTIONS.get(meta.default_reward_fn_type)  # type: ignore
 
     def read(
@@ -296,11 +299,14 @@ def read(
         tasks = []
         samples = self.dataset.read_batch(batch_size)
         for sample in samples:
-            workflow_class = (
-                WORKFLOWS.get(sample[self.workflow_key])
-                if self.workflow_key in sample
-                else self.default_workflow_cls
-            )
+            if self.task_type == TaskType.EVAL and self.default_eval_workflow_cls:
+                workflow_class = self.default_eval_workflow_cls
+            else:
+                workflow_class = (
+                    WORKFLOWS.get(sample[self.workflow_key])
+                    if self.workflow_key in sample
+                    else self.default_workflow_cls
+                )
             reward_fn = (
                 REWARD_FUNCTIONS.get(sample[self.reward_fn_key])
                 if self.reward_fn_key in sample
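
Restated outside the reader, the precedence this hunk introduces is: a configured eval workflow wins for eval tasksets, otherwise a per-sample workflow key, otherwise the taskset default. A standalone sketch (the helper name and arguments are illustrative, not part of the module's API):

# Illustrative precedence helper mirroring the branch added to read();
# `pick_workflow_cls` is not a function in this commit.
def pick_workflow_cls(task_type, sample, default_eval_cls, default_cls, registry, key="workflow"):
    if task_type == "eval" and default_eval_cls is not None:
        return default_eval_cls               # eval tasksets always use the eval workflow
    if key in sample:
        return registry.get(sample[key])      # per-sample override
    return default_cls                        # taskset-wide default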

trinity/common/config.py

Lines changed: 8 additions & 0 deletions
@@ -97,6 +97,7 @@ class StorageConfig:
 
     # used for rollout tasks
     default_workflow_type: Optional[str] = None
+    default_eval_type: Optional[str] = None
     default_reward_fn_type: Optional[str] = None
     rollout_args: GenerationConfig = field(default_factory=GenerationConfig)
     workflow_args: dict = field(default_factory=dict)
@@ -275,6 +276,7 @@ class ExplorerInput:
     eval_tasksets: List[StorageConfig] = field(default_factory=list)
     # The following args provide default values for the corresponding args in `taskset` and `eval_tasksets`
     default_workflow_type: Optional[str] = None
+    default_eval_type: Optional[str] = None
    default_reward_fn_type: Optional[str] = None
     system_prompt: Optional[str] = None
     reply_prefix: Optional[str] = None
@@ -485,6 +487,10 @@ def _check_buffer(self) -> None:  # noqa: C901
             self.buffer.explorer_input.taskset.default_workflow_type = (
                 self.buffer.explorer_input.default_workflow_type
             )
+        if self.buffer.explorer_input.taskset.default_eval_type is None:
+            self.buffer.explorer_input.taskset.default_eval_type = (
+                self.buffer.explorer_input.default_eval_type
+            )
         if self.buffer.explorer_input.taskset.default_reward_fn_type is None:
             self.buffer.explorer_input.taskset.default_reward_fn_type = (
                 self.buffer.explorer_input.default_reward_fn_type
@@ -510,6 +516,8 @@ def _check_buffer(self) -> None:  # noqa: C901
             dataset.name = f"eval_taskset_{idx}"
             if dataset.default_workflow_type is None:
                 dataset.default_workflow_type = self.buffer.explorer_input.default_workflow_type
+            if dataset.default_eval_type is None:
+                dataset.default_eval_type = self.buffer.explorer_input.default_eval_type
             if dataset.default_reward_fn_type is None:
                 dataset.default_reward_fn_type = self.buffer.explorer_input.default_reward_fn_type
             if dataset.format.system_prompt is None:

trinity/common/workflows/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -5,6 +5,7 @@
 from .envs.alfworld.alfworld_workflow import AlfworldWorkflow
 from .envs.sciworld.sciworld_workflow import SciWorldWorkflow
 from .envs.webshop.webshop_workflow import WebShopWorkflow
+from .eval_workflow import MathEvalWorkflow
 from .math_rm_workflow import MathRMWorkflow
 from .workflow import WORKFLOWS, MathWorkflow, SimpleWorkflow, Task, Workflow
 
@@ -20,4 +21,5 @@
     "MathBoxedWorkflow",
     "MathRMWorkflow",
     "ToolCallWorkflow",
+    "MathEvalWorkflow",
 ]

trinity/common/workflows/eval_workflow.py

Lines changed: 83 additions & 0 deletions

@@ -0,0 +1,83 @@
+# -*- coding: utf-8 -*-
+"""Evaluation Workflow Class"""
+
+from dataclasses import asdict
+from typing import List, Optional
+
+import openai
+
+from trinity.common.config import GenerationConfig
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+from trinity.utils.log import get_logger
+from trinity.utils.math_eval_utils import verify_math_answer
+
+logger = get_logger(__name__)
+
+
+@WORKFLOWS.register_module("math_eval_workflow")
+class MathEvalWorkflow(Workflow):
+    """
+    A workflow for standard math evaluation.
+
+    The evaluation standard and prompting style follow the Qwen2.5-Math
+    model's evaluation methodology. For more details on their approach, see:
+    https://github.com/QwenLM/Qwen2.5-Math
+    """
+
+    def __init__(
+        self,
+        *,
+        task: Task,
+        model: ModelWrapper,
+        auxiliary_models: Optional[List[openai.OpenAI]] = None,
+    ):
+        super().__init__(
+            task=task,
+            model=model,
+            auxiliary_models=auxiliary_models,
+        )
+
+        self.raw_task = task.raw_task
+        self.truth = task.truth
+
+        # TODO: customize the config in the yaml
+        self.eval_gen_args = asdict(GenerationConfig(temperature=0.6, top_p=0.8, logprobs=0, n=1))
+
+    @property
+    def resettable(self):
+        return False
+
+    def format_messages(self):
+        """Format messages for evaluation in the qwen_boxed style."""
+        if not self.raw_task or "question" not in self.raw_task:
+            raise ValueError("Raw task data must contain a 'question' field for MathEvalWorkflow.")
+
+        problem_input = self.raw_task["question"]
+
+        system_prompt = "You are a helpful assistant."
+        user_prompt = f"{problem_input}\nPlease reason step by step, and put your final answer within \\boxed{{}}."
+
+        messages = [
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": user_prompt},
+        ]
+        return messages
+
+    def run(self) -> List[Experience]:
+        messages = self.format_messages()
+
+        responses: List[Experience] = self.model.chat(messages, **self.eval_gen_args)
+
+        for response in responses:
+            accuracy, eval_details = verify_math_answer(
+                response_text=response.response_text, ground_truth=self.truth
+            )
+
+            acc_metrics = {"accuracy": accuracy}
+            if response.metrics is None:
+                response.metrics = {}
+            response.metrics.update(acc_metrics)
+
+        return responses
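
As exercised by the new unit test, the class is driven through Task rather than constructed directly; a condensed sketch of that call path follows. Here `taskset_config` and `model` stand in for a real dataset config and ModelWrapper, and the dataset's prompt_key is assumed to be "question" since format_messages() reads raw_task["question"].

# Condensed from tests/explorer/workflow_test.py; not additional commit code.
from trinity.common.workflows import MathEvalWorkflow, Task

task = Task(
    workflow=MathEvalWorkflow,
    is_eval=True,
    format_args=taskset_config.format,  # supplies prompt_key / response_key
    raw_task={
        taskset_config.format.prompt_key: "What is 6 * 6?",
        taskset_config.format.response_key: "36",
    },
)
experiences = task.to_workflow(model=model).run()
for exp in experiences:
    # reward is left at 0.0 for eval runs; correctness is reported via metrics
    print(exp.metrics["accuracy"])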

trinity/explorer/explorer.py

Lines changed: 6 additions & 0 deletions
@@ -271,6 +271,12 @@ def eval(self):
             self.logger.warning("No evaluation data samples. Skip evaluation.")
             return
         self.logger.info(f"Evaluation at step {self.explore_step_num} started.")
+
+        if self.config.buffer.explorer_input.default_eval_type:
+            self.logger.info(
+                f"Using evaluation workflow: '{self.config.buffer.explorer_input.default_eval_type}'."
+            )
+
         for eval_taskset_config in self.config.buffer.explorer_input.eval_tasksets:
             self.logger.info(
                 f"Evaluation on {eval_taskset_config.name} at step {self.explore_step_num} started."
