
Commit cecbaed

Refactor and Merge several compute_score functions (#430)
1 parent e159707 commit cecbaed

File tree

13 files changed: +79 −540 lines


benchmark/plugins/guru_math/reward.py

Lines changed: 5 additions & 5 deletions
@@ -14,10 +14,10 @@ def __call__( # type: ignore
         format_score_coef: Optional[float] = 0.1,
         **kwargs,
     ) -> dict[str, float]:
-        from .naive_dapo import compute_score
+        from trinity.common.rewards.naive_dapo_score import compute_score
 
-        ret = compute_score(response, truth, None)  # type: ignore
-        return {"accuracy": ret["score"], "format_score": 0}
+        score = compute_score(response, truth)  # type: ignore
+        return {"accuracy": score, "format_score": 0}
 
 
 @REWARD_FUNCTIONS.register_module("math_boxed_reward_prime_math")

@@ -32,5 +32,5 @@ def __call__( # type: ignore
     ) -> dict[str, float]:
         from verl.utils.reward_score.prime_math import compute_score
 
-        ret = compute_score(response, truth)
-        return {"accuracy": ret["score"], "format_score": 0}
+        res = compute_score(response, truth)
+        return {"accuracy": res["score"], "format_score": 0}

examples/bots/workflow/bots_math_boxed_reward.py

Lines changed: 3 additions & 3 deletions
@@ -1,7 +1,7 @@
 from typing import Optional
 
+from trinity.common.rewards.eval_utils import validate_think_pattern
 from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS, RewardFn
-from trinity.utils.eval_utils import validate_think_pattern
 
 
 @REWARD_FUNCTIONS.register_module("bots_math_boxed_reward")

@@ -22,9 +22,9 @@ def __call__( # type: ignore
         format_score_coef: Optional[float] = 0.1,
         **kwargs,
     ) -> dict[str, float]:
-        from trinity.plugins.bots_reward import compute_score
+        from trinity.plugins.bots_reward import compute_score_bots
 
-        accuracy_score = compute_score(response, truth)
+        accuracy_score = compute_score_bots(response, truth)
 
         format_score = 0.0
         if with_think and not validate_think_pattern(response):
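
Similarly, a sketch of the bots reward after the rename. Both helpers below are hypothetical stand-ins: the real compute_score_bots lives in trinity.plugins.bots_reward and the real validate_think_pattern in trinity.common.rewards.eval_utils (the new import path above). The hunk cuts off at the if branch, so the format_score_coef penalty applied inside it is an assumption:

import re
from typing import Optional

def validate_think_pattern(response: str) -> bool:
    # Assumed check: the response contains exactly one <think>...</think> block.
    return len(re.findall(r"<think>.*?</think>", response, flags=re.DOTALL)) == 1

def compute_score_bots(response: str, truth: str) -> float:
    # Stand-in for trinity.plugins.bots_reward.compute_score_bots;
    # the substring-match rule is an assumption.
    return 1.0 if truth in response else 0.0

def bots_math_boxed_reward(
    response: str,
    truth: str,
    with_think: bool = True,
    format_score_coef: Optional[float] = 0.1,
    **kwargs,
) -> dict[str, float]:
    accuracy_score = compute_score_bots(response, truth)
    format_score = 0.0
    if with_think and not validate_think_pattern(response):
        # The diff hunk ends at this branch; deducting
        # format_score_coef here is a guess.
        format_score = -(format_score_coef or 0.0)
    return {"accuracy": accuracy_score, "format_score": format_score}

print(bots_math_boxed_reward("\\boxed{7}", "\\boxed{7}"))
# {'accuracy': 1.0, 'format_score': -0.1}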
