forked from agentscope-ai/Trinity-RFT
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcountdown_reward.py
More file actions
58 lines (47 loc) · 1.67 KB
/
countdown_reward.py
File metadata and controls
58 lines (47 loc) · 1.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
"""Base Reward Function Class."""
import json
from typing import Optional
from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS, RewardFn
from trinity.utils.eval_utils import (
evaluate_equation,
extract_solution,
validate_equation,
)
from trinity.utils.log import get_logger
logger = get_logger(__name__)
@REWARD_FUNCTIONS.register_module("countdown_reward")
class CountDownRewardFn(RewardFn):
    """A reward function that rewards for countdown task.

    Ref: Jiayi-Pan/TinyZero verl/utils/reward_score/countdown.py
    """

    def __init__(self):
        pass

    def __call__(  # type: ignore
        self,
        response: str,
        prompt: Optional[str] = None,
        truth: Optional[str] = None,
    ) -> dict[str, float]:
        """Score a countdown-task response.

        Args:
            response: Model output expected to contain an equation.
            prompt: Unused; kept for interface compatibility with RewardFn.
            truth: JSON string with keys "target" (the number to reach) and
                "numbers" (the operands the equation must use).

        Returns:
            ``{"score": s}`` where ``s`` is 1.0 for a correct equation,
            0.1 (format credit) for an extractable but wrong/invalid
            equation, and 0.0 when no equation can be extracted.
        """
        format_score = 0.1  # partial credit: equation present but wrong/invalid
        full_score = 1.0

        truth_dict = json.loads(truth)  # type: ignore
        target = truth_dict["target"]
        numbers = truth_dict["numbers"]

        equation = extract_solution(solution_str=response)
        if equation is None:
            # No parsable equation at all -> zero reward.
            # Use a float literal to honor the dict[str, float] contract.
            return {"score": 0.0}

        # The equation must use exactly the provided numbers.
        if not validate_equation(equation, numbers):
            return {"score": format_score}

        # Keep the try body minimal: only the call that can raise.
        try:
            result = evaluate_equation(equation)
        except Exception:
            # Log instead of silently swallowing; still award format credit.
            logger.exception("Failed to evaluate equation: %s", equation)
            return {"score": format_score}

        if result is None:
            return {"score": format_score}
        # Tolerance accounts for floating point precision.
        if abs(result - target) < 1e-5:
            return {"score": full_score}
        return {"score": format_score}