forked from agentscope-ai/Trinity-RFT
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathaccuracy_reward.py
More file actions
68 lines (58 loc) · 2.38 KB
/
accuracy_reward.py
File metadata and controls
68 lines (58 loc) · 2.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# -*- coding: utf-8 -*-
"""Accuracy Reward Function Class."""
from typing import Callable, Optional
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS, RewardFn
from trinity.utils.log import get_logger
# Module-level logger; AccuracyReward uses it to report answers whose verification fails.
logger = get_logger(__name__)
@REWARD_FUNCTIONS.register_module("accuracy_reward")
class AccuracyReward(RewardFn):
    """A reward function that rewards correct answers.

    The model response and the ground truth are each parsed into comparable
    math expressions (by default with ``math_verify.parse``), then checked for
    equivalence with ``math_verify.verify``. The reward is ``1.0`` on a match
    and ``0.0`` otherwise, including when verification raises.

    Ref: https://github.com/huggingface/open-r1/blob/main/src/open_r1/rewards.py
    """

    # Extraction settings for the ground truth: default LaTeX extraction.
    # Built once at class definition instead of on every call.
    _TRUTH_EXTRACTION_CONFIG = [LatexExtractionConfig()]

    # Extraction settings for the model response. Normalization of malformed
    # operators is disabled, so the answer is expected in well-formed LaTeX.
    _ANSWER_EXTRACTION_CONFIG = [
        LatexExtractionConfig(
            normalization_config=NormalizationConfig(
                nits=False,
                malformed_operators=False,
                basic_latex=True,
                equations=True,
                boxed="all",
                units=True,
            ),
            # Ensures that boxed is tried first
            boxed_match_priority=0,
            try_extract_without_anchor=False,
        )
    ]

    def __init__(self, answer_parser: Optional[Callable[[str], str]] = None):
        """Create the reward function.

        Args:
            answer_parser: Optional custom parser applied to both the response
                and the ground truth. When given, it replaces the default
                ``math_verify.parse`` extraction entirely.
        """
        self.answer_parser = answer_parser

    def __call__(  # type: ignore
        self,
        response: str,
        prompt: Optional[str] = None,
        truth: Optional[str] = None,
    ) -> dict[str, float]:
        """Score ``response`` against the ground truth ``truth``.

        Args:
            response: The model's response text.
            prompt: Unused; accepted for interface compatibility.
            truth: The ground-truth answer. If ``None`` there is nothing to
                verify against and the reward is ``0.0``.

        Returns:
            ``{"accuracy": 1.0}`` if the parsed answer is verified equal to the
            parsed ground truth, otherwise ``{"accuracy": 0.0}``.
        """
        if truth is None:
            # Previously None flowed into parse()/answer_parser() and failed
            # there; degrade to a zero reward like other verification failures.
            logger.warning("No ground truth provided; accuracy reward is 0.")
            return {"accuracy": 0.0}

        if self.answer_parser:
            answer_parsed = self.answer_parser(response)
            truth_parsed = self.answer_parser(truth)
        else:
            truth_parsed = parse(
                truth,
                extraction_mode="first_match",
                extraction_config=self._TRUTH_EXTRACTION_CONFIG,
            )
            if not truth_parsed:
                # No LaTeX could be extracted from the ground truth: fall back
                # to comparing against the raw ground-truth string.
                truth_parsed = truth
            answer_parsed = parse(
                response,
                extraction_config=self._ANSWER_EXTRACTION_CONFIG,
                extraction_mode="first_match",
            )

        # Reward 1 if the content is the same as the ground truth, 0 otherwise.
        # math_verify.verify is not symmetric: it must be called as
        # verify(gold, target), i.e. ground truth first.
        try:
            reward = float(verify(truth_parsed, answer_parsed))
        except Exception as e:
            logger.info(
                "verify failed: %s, answer: %s, gold: %s", e, answer_parsed, truth_parsed
            )
            reward = 0.0
        return {"accuracy": reward}