Skip to content

Commit 63d4920

Browse files
authored
Refactor RewardFn (#118)
1 parent d1c56dc commit 63d4920

File tree

18 files changed

+444
-314
lines changed

18 files changed

+444
-314
lines changed

examples/grpo_math/math.yaml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,12 @@ buffer:
2323
prompt_key: 'question'
2424
response_key: 'gt_answer'
2525
rollout_args:
26-
n: 8
2726
temperature: 1.0
2827
logprobs: 0
29-
default_workflow_type: 'math_workflow'
28+
reward_fn_args:
29+
reward_name: math_verify_reward
30+
default_workflow_type: 'math_rm_workflow'
31+
default_reward_fn_type: 'rm_gallery_reward'
3032
trainer_input:
3133
experience_buffer:
3234
name: math_buffer

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,9 @@ data = [
5353
agent = [
5454
"agentscope"
5555
]
56+
rm_gallery = [
57+
"rm-gallery"
58+
]
5659
dev = [
5760
"pre-commit>=2.17.0",
5861
"black>=23.7.0",

tests/explorer/workflow_test.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,26 @@
22
"""Test for the workflow module"""
33
import unittest
44
from dataclasses import dataclass
5+
from typing import Dict, Optional
56
from unittest.mock import MagicMock
67

78
from tests.tools import get_unittest_dataset_config
8-
from trinity.common.workflows import MathBoxedWorkflow, MathWorkflow, Workflow
9+
from trinity.common.rewards import RMGalleryFn
10+
from trinity.common.workflows import (
11+
MathBoxedWorkflow,
12+
MathRMWorkflow,
13+
MathWorkflow,
14+
Workflow,
15+
)
916
from trinity.common.workflows.workflow import Task
1017

1118

1219
@dataclass
1320
class MockResponse:
1421
response_text: str
1522
reward: float = 0.0
23+
metrics: Optional[Dict[str, float]] = None
24+
info: Optional[Dict] = None
1625

1726

1827
class DummyWorkflow(Workflow):
@@ -206,7 +215,6 @@ def test_gsm8k_workflow(self) -> None:
206215
)
207216
workflow = task.to_workflow(model=model)
208217
experiences = workflow.run()
209-
# self.assertEqual(len(experiences), 1)
210218
self.assertEqual(experiences[0].reward, 1.1)
211219
self.assertEqual(experiences[1].reward, 0.9)
212220
self.assertEqual(experiences[2].reward, 0.9)
@@ -229,6 +237,37 @@ def test_gsm8k_workflow(self) -> None:
229237
self.assertEqual(experiences[2].reward, -0.1)
230238
self.assertEqual(experiences[3].reward, 1.1)
231239

@unittest.skip("Skip for now, need to fix import issues of RM-Gallery")
def test_rm_gallery_workflow(self) -> None:
    """MathRMWorkflow scored by RMGalleryFn: every \\boxed{36} answer earns 1.0."""
    model = MagicMock()
    model.chat.return_value = [
        MockResponse("<think> balabalabala 99 </think>\n \\boxed{36}"),
        MockResponse("answer is \\boxed{36 }"),
        MockResponse("Kim's total points are 6 + 30 =\\boxed{36}"),
        MockResponse("<think> balalaba </think> \\boxed{35.00}"),
    ]
    taskset_config = get_unittest_dataset_config("countdown")
    task = Task(
        workflow=MathRMWorkflow,
        reward_fn=RMGalleryFn,
        format_args=taskset_config.format,
        rollout_args=taskset_config.rollout_args,
        reward_fn_args={"reward_name": "math_verify_reward"},
        is_eval=False,
        raw_task={
            taskset_config.format.prompt_key: "",
            taskset_config.format.response_key: r"36",
        },
    )
    experiences = task.to_workflow(model=model).run()
    # The first three responses box the correct answer (36); the last does not.
    for experience, expected_reward in zip(experiences, (1.0, 1.0, 1.0, 0.0)):
        self.assertEqual(experience.reward, expected_reward)
232271
def test_workflow_resettable(self) -> None:
233272
model = MagicMock()
234273
json_task = Task(

trinity/buffer/reader/file_reader.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,7 @@ def read(
294294
format_args=self.meta.format,
295295
rollout_args=self.meta.rollout_args,
296296
workflow_args=self.meta.workflow_args,
297+
reward_fn_args=self.meta.reward_fn_args,
297298
is_eval=self.meta.task_type == TaskType.EVAL,
298299
reward_fn=reward_fn,
299300
raw_task=sample,

trinity/common/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ class StorageConfig:
9999
default_reward_fn_type: Optional[str] = None
100100
rollout_args: GenerationConfig = field(default_factory=GenerationConfig)
101101
workflow_args: dict = field(default_factory=dict)
102+
reward_fn_args: dict = field(default_factory=dict)
102103

103104
# get storage from existing experiment
104105
ray_namespace: Optional[str] = None

trinity/common/rewards/__init__.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,23 @@
11
# -*- coding: utf-8 -*-
22
"""Reward functions for RFT"""
33

4-
from .reward_fn import REWARD_FUNCTIONS, AccuracyReward, FormatReward, RewardFn
4+
# isort: off
5+
from .reward_fn import REWARD_FUNCTIONS, RewardFn, RMGalleryFn
6+
7+
from .accuracy_reward import AccuracyReward
8+
from .countdown_reward import CountDownRewardFn
9+
from .format_reward import FormatReward
10+
from .math_reward import MathBoxedRewardFn, MathRewardFn
11+
12+
# isort: on
513

614
__all__ = [
715
"RewardFn",
16+
"RMGalleryFn",
817
"REWARD_FUNCTIONS",
918
"AccuracyReward",
19+
"CountDownRewardFn",
1020
"FormatReward",
21+
"MathRewardFn",
22+
"MathBoxedRewardFn",
1123
]
Lines changed: 59 additions & 24 deletions
Original file line numberDiff line numberDiff line change
# -*- coding: utf-8 -*-
"""Accuracy Reward Function Class."""
from typing import Callable, Optional

from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify

from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS, RewardFn
from trinity.utils.log import get_logger

logger = get_logger(__name__)


@REWARD_FUNCTIONS.register_module("accuracy_reward")
class AccuracyReward(RewardFn):
    """A reward function that rewards correct answers.
    Ref: https://github.com/huggingface/open-r1/blob/main/src/open_r1/rewards.py
    """

    def __init__(self, answer_parser: Optional[Callable[[str], str]] = None):
        # Optional custom parser; when absent, math_verify's LaTeX extraction is used.
        self.answer_parser = answer_parser

    def __call__(  # type: ignore
        self,
        response: str,
        prompt: Optional[str] = None,
        truth: Optional[str] = None,
    ) -> dict[str, float]:
        """Return ``{"accuracy": reward}`` — 1.0 when *response* matches *truth*, else 0.0.

        Args:
            response: The model's raw response text.
            prompt: Unused; kept for the common RewardFn call signature.
            truth: The ground-truth answer to compare against.
        """
        if not self.answer_parser:
            truth_parsed = parse(
                truth,
                extraction_mode="first_match",
                extraction_config=[LatexExtractionConfig()],
            )
            # Fall back to the raw ground-truth string when nothing was extracted.
            if len(truth_parsed) == 0:
                truth_parsed = truth

            response_extraction_config = LatexExtractionConfig(
                normalization_config=NormalizationConfig(
                    nits=False,
                    malformed_operators=False,
                    basic_latex=True,
                    equations=True,
                    boxed="all",
                    units=True,
                ),
                # Ensures that boxed is tried first
                boxed_match_priority=0,
                try_extract_without_anchor=False,
            )
            answer_parsed = parse(
                response,
                extraction_config=[response_extraction_config],
                extraction_mode="first_match",
            )
        else:
            answer_parsed = self.answer_parser(response)
            truth_parsed = self.answer_parser(truth)  # type: ignore [arg-type]

        # Reward 1 if the content is the same as the ground truth, 0 otherwise
        try:
            reward = float(verify(answer_parsed, truth_parsed))
        except Exception as e:
            logger.info(f"verify failed: {e}, answer: {answer_parsed}, gold: {truth_parsed}")
            reward = 0.0
        return {"accuracy": reward}

trinity/common/rewards/base.py

Lines changed: 0 additions & 24 deletions
This file was deleted.

trinity/common/rewards/composite_reward.py

Lines changed: 0 additions & 24 deletions
This file was deleted.
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
"""Countdown Reward Function Class."""
import json
from typing import Optional

from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS, RewardFn
from trinity.utils.eval_utils import (
    evaluate_equation,
    extract_solution,
    validate_equation,
)
from trinity.utils.log import get_logger

logger = get_logger(__name__)


@REWARD_FUNCTIONS.register_module("countdown_reward")
class CountDownRewardFn(RewardFn):
    """A reward function that rewards for countdown task.
    Ref: Jiayi-Pan/TinyZero verl/utils/reward_score/countdown.py

    Scoring scheme:
      * 0.0 — no equation could be extracted from the response
      * 0.1 — an equation was found but is invalid or wrong (format score)
      * 1.0 — the equation uses the given numbers and evaluates to the target
    """

    def __init__(self):
        pass

    def __call__(  # type: ignore
        self,
        response: str,
        prompt: Optional[str] = None,
        truth: Optional[str] = None,
    ) -> dict[str, float]:
        """Score *response* against *truth*.

        Args:
            response: The model's raw response text containing a candidate equation.
            prompt: Unused; kept for the common RewardFn call signature.
            truth: JSON string with keys ``target`` (the number to reach) and
                ``numbers`` (the operands the equation must use).

        Returns:
            A single-entry dict ``{"score": reward}``.
        """
        truth = json.loads(truth)  # type: ignore
        target = truth["target"]  # type: ignore
        numbers = truth["numbers"]  # type: ignore

        equation = extract_solution(solution_str=response)
        format_score = 0.1
        score = 1.0

        # No recognizable equation in the response at all.
        if equation is None:
            return {"score": 0.0}

        # Validate equation uses correct numbers
        if not validate_equation(equation, numbers):
            return {"score": format_score}

        # Evaluate equation; any failure earns only the format score.
        try:
            result = evaluate_equation(equation)
            if result is None:
                return {"score": format_score}

            if abs(result - target) < 1e-5:  # Account for floating point precision
                return {"score": score}
            return {"score": format_score}
        except Exception:
            return {"score": format_score}

0 commit comments

Comments
 (0)