diff --git a/benchmark/plugins/guru_math/reward.py b/benchmark/plugins/guru_math/reward.py index 9cc22d9c73..d30d60de44 100644 --- a/benchmark/plugins/guru_math/reward.py +++ b/benchmark/plugins/guru_math/reward.py @@ -14,10 +14,10 @@ def __call__( # type: ignore format_score_coef: Optional[float] = 0.1, **kwargs, ) -> dict[str, float]: - from .naive_dapo import compute_score + from trinity.common.rewards.naive_dapo_score import compute_score - ret = compute_score(response, truth, None) # type: ignore - return {"accuracy": ret["score"], "format_score": 0} + score = compute_score(response, truth) # type: ignore + return {"accuracy": score, "format_score": 0} @REWARD_FUNCTIONS.register_module("math_boxed_reward_prime_math") @@ -32,5 +32,5 @@ def __call__( # type: ignore ) -> dict[str, float]: from verl.utils.reward_score.prime_math import compute_score - ret = compute_score(response, truth) - return {"accuracy": ret["score"], "format_score": 0} + res = compute_score(response, truth) + return {"accuracy": res["score"], "format_score": 0} diff --git a/examples/bots/workflow/bots_math_boxed_reward.py b/examples/bots/workflow/bots_math_boxed_reward.py index a7890f8584..c49c5a36a2 100644 --- a/examples/bots/workflow/bots_math_boxed_reward.py +++ b/examples/bots/workflow/bots_math_boxed_reward.py @@ -1,7 +1,7 @@ from typing import Optional +from trinity.common.rewards.eval_utils import validate_think_pattern from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS, RewardFn -from trinity.utils.eval_utils import validate_think_pattern @REWARD_FUNCTIONS.register_module("bots_math_boxed_reward") @@ -22,9 +22,9 @@ def __call__( # type: ignore format_score_coef: Optional[float] = 0.1, **kwargs, ) -> dict[str, float]: - from trinity.plugins.bots_reward import compute_score + from trinity.plugins.bots_reward import compute_score_bots - accuracy_score = compute_score(response, truth) + accuracy_score = compute_score_bots(response, truth) format_score = 0.0 if with_think and 
not validate_think_pattern(response): diff --git a/examples/bots/workflow/bots_reward.py b/examples/bots/workflow/bots_reward.py index e4bdf4b98e..1664b6f1b6 100644 --- a/examples/bots/workflow/bots_reward.py +++ b/examples/bots/workflow/bots_reward.py @@ -1,20 +1,23 @@ # Adapted from Reasoning360: https://github.com/LLM360/Reasoning360/blob/main/verl/utils/reward_score/naive_dapo.py -import concurrent import contextlib import math import re -import resource from math import isclose from typing import Optional, Union -import sympy -from pylatexenc import latex2text from sympy import N, simplify -from sympy.parsing import sympy_parser from sympy.parsing.latex import parse_latex from sympy.parsing.sympy_parser import parse_expr from verl.utils.py_functional import timeout_limit +from verl.utils.reward_score.prime_math.grader import math_equal as verl_math_equal + +from trinity.common.rewards.eval_utils import remove_right_units +from trinity.common.rewards.naive_dapo_score import grade_answer, match_answer +from trinity.common.rewards.qwen25_eval import fix_fracs + +_fix_fracs = fix_fracs +_remove_right_units = remove_right_units def handle_base(x): @@ -303,362 +306,6 @@ def math_equal( # noqa return symbolic_equal(prediction, reference, tolerance, timeout) -# Constants for normalization -SUBSTITUTIONS = [ - ("an ", ""), - ("a ", ""), - (".$", "$"), - ("\\$", ""), - (r"\ ", ""), - (" ", ""), - ("mbox", "text"), - (",\\text{and}", ","), - ("\\text{and}", ","), - ("\\text{m}", "\\text{}"), -] - -REMOVED_EXPRESSIONS = [ - "square", - "ways", - "integers", - "dollars", - "mph", - "inches", - "hours", - "km", - "units", - "\\ldots", - "sue", - "points", - "feet", - "minutes", - "digits", - "cents", - "degrees", - "cm", - "gm", - "pounds", - "meters", - "meals", - "edges", - "students", - "childrentickets", - "multiples", - "\\text{s}", - "\\text{.}", - "\\text{\ns}", - "\\text{}^2", - "\\text{}^3", - "\\text{\n}", - "\\text{}", - r"\mathrm{th}", - r"^\circ", - 
r"^{\circ}", - r"\;", - r",\!", - "{,}", - '"', - "\\dots", -] - - -def normalize_final_answer(final_answer: str) -> str: - """Normalize a final answer to a quantitative reasoning question. - - Args: - final_answer: The answer string to normalize - - Returns: - Normalized answer string - """ - final_answer = final_answer.split("=")[-1] - - # Apply substitutions and removals - for before, after in SUBSTITUTIONS: - final_answer = final_answer.replace(before, after) - for expr in REMOVED_EXPRESSIONS: - final_answer = final_answer.replace(expr, "") - - # Extract and normalize LaTeX math - final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer) - final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer) - final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer) - final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer) - final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer) - - # Normalize shorthand TeX: - # \fracab -> \frac{a}{b} - # \frac{abc}{bef} -> \frac{abc}{bef} - # \fracabc -> \frac{a}{b}c - # \sqrta -> \sqrt{a} - # \sqrtab -> sqrt{a}b - final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer) - final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer) - final_answer = final_answer.replace("$", "") - - # Normalize numbers - if final_answer.replace(",", "").isdigit(): - final_answer = final_answer.replace(",", "") - - return final_answer.strip() - - -# sympy might hang -- we don't care about trying to be lenient in these cases -BAD_SUBSTRINGS = ["^{", "^("] -BAD_REGEXES = [r"\^[0-9]+\^", r"\^[0-9][0-9]+"] -TUPLE_CHARS = "()[]" - - -def _sympy_parse(expr: str): - """Parses an expression with sympy.""" - py_expr = expr.replace("^", "**") - return sympy_parser.parse_expr( - py_expr, - transformations=( - sympy_parser.standard_transformations - + (sympy_parser.implicit_multiplication_application,) - ), - ) - - -def _parse_latex(expr: str) -> str: - """Attempts to parse latex to 
an expression sympy can read.""" - expr = expr.replace("\\tfrac", "\\frac") - expr = expr.replace("\\dfrac", "\\frac") - expr = expr.replace("\\frac", " \\frac") # Play nice with mixed numbers. - expr = latex2text.LatexNodes2Text().latex_to_text(expr) - - # Replace the specific characters that this parser uses. - expr = expr.replace("√", "sqrt") - expr = expr.replace("π", "pi") - expr = expr.replace("∞", "inf") - expr = expr.replace("∪", "U") - expr = expr.replace("·", "*") - expr = expr.replace("×", "*") - - return expr.strip() - - -def _is_float(num: str) -> bool: - try: - float(num) - return True - except ValueError: - return False - - -def _is_int(x: float) -> bool: - try: - return abs(x - int(round(x))) <= 1e-7 - except Exception: - return False - - -def _is_frac(expr: str) -> bool: - return bool(re.search(r"^-?[0-9]+.?/0*[1-9][0-9]*.?$", expr)) - - -def _str_is_int(x: str) -> bool: - try: - x = _strip_properly_formatted_commas(x) - return abs(float(x) - int(round(float(x)))) <= 1e-7 - except Exception: - return False - - -def _str_to_int(x: str) -> int: - x = x.replace(",", "") - x = float(x) - return int(x) - - -def _inject_implicit_mixed_number(step: str): - """ - Automatically make a mixed number evalable - e.g. 7 3/4 => 7+3/4 - """ - p1 = re.compile("([0-9]) +([0-9])") - step = p1.sub("\\1+\\2", step) # implicit mults - return step - - -def _strip_properly_formatted_commas(expr: str): - # We want to be careful because we don't want to strip tuple commas - p1 = re.compile(r"(\d)(,)(\d\d\d)($|\D)") - while True: - next_expr = p1.sub("\\1\\3\\4", expr) - if next_expr == expr: - break - expr = next_expr - return next_expr - - -def _normalize(expr: str) -> str: - """Normalize answer expressions.""" - if expr is None: - return None - - # Remove enclosing `\text{}`. 
- m = re.search(r"^\\\\text\{(?P.+?)\}$", expr) - if m is not None: - expr = m.group("text") - - expr = expr.replace("\\%", "%") - expr = expr.replace("\\$", "$") - expr = expr.replace("$", "") - expr = expr.replace("%", "") - expr = expr.replace(" or ", " , ") - expr = expr.replace(" and ", " , ") - - expr = expr.replace("million", "*10^6") - expr = expr.replace("billion", "*10^9") - expr = expr.replace("trillion", "*10^12") - - for unit in [ - "degree", - "cm", - "centimeter", - "meter", - "mile", - "second", - "minute", - "hour", - "day", - "week", - "month", - "year", - "foot", - "feet", - "inch", - "yard", - "liter", - ]: - expr = re.sub(r"{}(es)?(s)? *(\^[0-9]+)?".format(unit), "", expr) - expr = re.sub(r"\^ *\\\\circ", "", expr) - - if len(expr) > 0 and expr[0] == "{" and expr[-1] == "}": - expr = expr[1:-1] - - expr = re.sub(",\\\\! *", "", expr) - if _is_float(expr) and _is_int(float(expr)): - expr = str(int(round(float(expr)))) - if "\\" in expr: - try: - expr = _parse_latex(expr) - except Exception: - pass - - # edge case with mixed numbers and negative signs - expr = re.sub("- *", "-", expr) - - expr = _inject_implicit_mixed_number(expr) - - # don't be case sensitive for text answers - expr = expr.lower() - - if _str_is_int(expr): - expr = str(_str_to_int(expr)) - - return expr - - -def count_unknown_letters_in_expr(expr: str): - expr = expr.replace("sqrt", "") - expr = expr.replace("frac", "") - letters_in_expr = set([x for x in expr if x.isalpha()]) - return len(letters_in_expr) - - -def should_allow_eval(expr: str): - # we don't want to try parsing unknown text or functions of more than two variables - if count_unknown_letters_in_expr(expr) > 2: - return False - - for bad_string in BAD_SUBSTRINGS: - if bad_string in expr: - return False - - for bad_regex in BAD_REGEXES: - if re.search(bad_regex, expr) is not None: - return False - - return True - - -# @timeout(timeout_seconds=10) -def are_equal_under_sympy(ground_truth_normalized: str, 
given_normalized: str): - def check_equal(): - memory_size = 1024**3 - resource.setrlimit(resource.RLIMIT_AS, (memory_size, memory_size)) - - expr = f"({ground_truth_normalized})-({given_normalized})" - if should_allow_eval(expr): - sympy_diff = _sympy_parse(expr) - simplified = sympy.simplify(sympy_diff) - if simplified == 0: - return True - return False - - with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor: - future = executor.submit(check_equal) - try: - return future.result(timeout=10) - except (concurrent.futures.TimeoutError, Exception): - future.cancel() - return False - - -def split_tuple(expr: str): - """ - Split the elements in a tuple/interval, while handling well-formatted commas in large numbers - """ - expr = _strip_properly_formatted_commas(expr) - if len(expr) == 0: - return [] - if ( - len(expr) > 2 - and expr[0] in TUPLE_CHARS - and expr[-1] in TUPLE_CHARS - and all([ch not in expr[1:-1] for ch in TUPLE_CHARS]) - ): - elems = [elem.strip() for elem in expr[1:-1].split(",")] - else: - elems = [expr] - return elems - - -def _fix_fracs(string): - substrs = string.split("\\frac") - new_str = substrs[0] - if len(substrs) > 1: - substrs = substrs[1:] - for substr in substrs: - new_str += "\\frac" - if substr[0] == "{": - new_str += substr - else: - try: - assert len(substr) >= 2 - except: # noqa: E722 - return string - a = substr[0] - b = substr[1] - if b != "{": - if len(substr) > 2: - post_substr = substr[2:] - new_str += "{" + a + "}{" + b + "}" + post_substr - else: - new_str += "{" + a + "}{" + b + "}" - else: - if len(substr) > 2: - post_substr = substr[2:] - new_str += "{" + a + "}" + b + post_substr - else: - new_str += "{" + a + "}" + b - string = new_str - return string - - def _fix_a_slash_b(string): if len(string.split("/")) != 2: return string @@ -674,16 +321,6 @@ def _fix_a_slash_b(string): return string -def _remove_right_units(string): - # "\\text{ " only ever occurs (at least in the val set) when describing units - 
if "\\text{ " in string: - splits = string.split("\\text{ ") - assert len(splits) == 2 - return splits[0] - else: - return string - - def _fix_sqrt(string): if "\\sqrt" not in string: return string @@ -777,106 +414,7 @@ def normalize_answer(answer: Optional[str]) -> str: return answer -def grade_answer(given_answer: str, ground_truth: str) -> tuple[bool, str]: - """ - The answer will be considered correct if: - (a) it normalizes to the same string as the ground truth answer - OR - (b) sympy can simplify the difference between the expressions to 0 - """ - if given_answer is None: - return False - - ground_truth_normalized_mathd = normalize_answer(ground_truth) - given_answer_normalized_mathd = normalize_answer(given_answer) - - # be at least as lenient as mathd - if ground_truth_normalized_mathd == given_answer_normalized_mathd: - return True, given_answer_normalized_mathd - - ground_truth_normalized = _normalize(ground_truth) - given_normalized = _normalize(given_answer) - - if ground_truth_normalized is None: - return False, given_normalized - - if ground_truth_normalized == given_normalized: - return True, given_normalized - - if len(given_normalized) == 0: - return False, given_normalized - - ground_truth_elems = split_tuple(ground_truth_normalized) - given_elems = split_tuple(given_normalized) - - if len(ground_truth_elems) > 1 and ( - ground_truth_normalized[0] != given_normalized[0] - or ground_truth_normalized[-1] != given_normalized[-1] - ): - is_correct = False - elif len(ground_truth_elems) != len(given_elems): - is_correct = False - else: - for ground_truth_elem, given_elem in zip(ground_truth_elems, given_elems): - if _is_frac(ground_truth_elem) and _is_frac(given_elem): - # if fractions aren't reduced, then shouldn't be marked as correct - # so, we don't want to allow sympy.simplify in this case - is_correct = ground_truth_elem == given_elem - elif _str_is_int(ground_truth_elem) != _str_is_int(given_elem): - # if the ground truth answer is an integer, 
we require the given answer to be a strict match (no sympy.simplify) - is_correct = False - else: - is_correct = are_equal_under_sympy(ground_truth_elem, given_elem) - if not is_correct: - break - - return is_correct, given_normalized - - -def _last_boxed_only_string(string): - idx = string.rfind("\\boxed") - if idx < 0: - idx = string.rfind("\\fbox") - if idx < 0: - return None - - i = idx - left_brace_idx = None - right_brace_idx = None - num_left_braces_open = 0 - while i < len(string): - if string[i] == "{": - num_left_braces_open += 1 - if left_brace_idx is None: - left_brace_idx = i - elif string[i] == "}": - num_left_braces_open -= 1 - if num_left_braces_open == 0: - right_brace_idx = i - break - - i += 1 - - if left_brace_idx is None or right_brace_idx is None: - return None - - return string[left_brace_idx + 1 : right_brace_idx].strip() - - -def match_answer(response): - is_matched = False - response = response.split("")[-1] - - # Find boxed - ans_boxed = _last_boxed_only_string(response) - if ans_boxed: - is_matched = True - response = ans_boxed - - return is_matched, response - - -def compute_score(solution_str: str, ground_truth: Optional[str]) -> float: +def compute_score_bots(solution_str: str, ground_truth: Optional[str]) -> float: """Compute the reward score for a solution. 
This draws heavily from the LLM-as-judge and PRIME reward functions Args: @@ -901,19 +439,30 @@ def compute_score(solution_str: str, ground_truth: Optional[str]) -> float: if not correct: try: + # Use verl's math_equal for additional verification if "\\pi" in extracted_model_output or "\\pi" in ground_truth: + # Try with different pi values equivs = [] - for pi in [math.pi, 3.14]: - equivs.append( - math_equal(extracted_model_output, ground_truth, timeout=True, pi=pi) - ) - correct = any(equivs) + for pi_val in [math.pi, 3.14]: + try: + equivs.append(verl_math_equal(extracted_model_output, ground_truth)) + except Exception: + # Fallback to local math_equal if verl's doesn't work + equivs.append( + math_equal( + extracted_model_output, ground_truth, timeout=True, pi=pi_val + ) + ) + correct = any(equivs) else: - correct = math_equal(extracted_model_output, ground_truth, timeout=True) + try: + correct = verl_math_equal(extracted_model_output, ground_truth) + except Exception: + # Fallback to local math_equal + correct = math_equal(extracted_model_output, ground_truth, timeout=True) except Exception: correct = False - # reward = 1.0 if correct else -1.0 reward = 1.0 if correct else 0.0 return reward diff --git a/tests/utils/eval_utils_test.py b/tests/utils/eval_utils_test.py index 533c04b836..cfef476fb6 100644 --- a/tests/utils/eval_utils_test.py +++ b/tests/utils/eval_utils_test.py @@ -3,8 +3,8 @@ import unittest -from trinity.utils.eval_utils import compute_score, is_equiv -from trinity.utils.math_eval_utils import extract_answer, verify_math_answer +from trinity.common.rewards.eval_utils import compute_score_v0, is_equiv +from trinity.common.rewards.qwen25_eval import extract_answer, verify_math_answer class TestComputeScore(unittest.TestCase): @@ -19,7 +19,7 @@ def test_both_boxed_and_equivalent(self): """ solution = "The final answer is \\boxed{42}" truth = "The correct result is \\boxed{42}" - self.assertEqual(compute_score(solution, truth), 1.0) + 
self.assertEqual(compute_score_v0(solution, truth), 1.0) def test_solution_raw_and_ground_truth_boxed_equivalent(self): """ @@ -28,7 +28,7 @@ def test_solution_raw_and_ground_truth_boxed_equivalent(self): """ solution = "The answer is \\boxed{42}" truth = "The answer is \\boxed{42}" - self.assertEqual(compute_score(solution, truth), 1.0) + self.assertEqual(compute_score_v0(solution, truth), 1.0) def test_solution_boxed_truth_raw_and_equivalent(self): """ @@ -37,7 +37,7 @@ def test_solution_boxed_truth_raw_and_equivalent(self): """ solution = "Let's see, the result is \\boxed{100}" truth = "100" - self.assertEqual(compute_score(solution, truth), 1.0) + self.assertEqual(compute_score_v0(solution, truth), 1.0) def test_both_boxed_and_not_equivalent(self): """ @@ -46,7 +46,7 @@ def test_both_boxed_and_not_equivalent(self): """ solution = "I think the answer is \\boxed{-1}" truth = "The answer is \\boxed{1}" - self.assertEqual(compute_score(solution, truth), 0.0) + self.assertEqual(compute_score_v0(solution, truth), 0.0) def test_solution_boxed_truth_raw_and_not_equivalent(self): """ @@ -55,7 +55,7 @@ def test_solution_boxed_truth_raw_and_not_equivalent(self): """ solution = "The answer is \\boxed{apple}" truth = "orange" - self.assertEqual(compute_score(solution, truth), 0.0) + self.assertEqual(compute_score_v0(solution, truth), 0.0) def test_solution_not_boxed(self): """ @@ -65,8 +65,8 @@ def test_solution_not_boxed(self): solution = "The answer is 42, but I'm not boxing it." 
truth_boxed = "The answer is \\boxed{42}" truth_raw = "42" - self.assertEqual(compute_score(solution, truth_boxed), 0.0) - self.assertEqual(compute_score(solution, truth_raw), 0.0) + self.assertEqual(compute_score_v0(solution, truth_boxed), 0.0) + self.assertEqual(compute_score_v0(solution, truth_raw), 0.0) def test_empty_solution_string(self): """ @@ -75,7 +75,7 @@ def test_empty_solution_string(self): """ solution = "" truth = "\\boxed{10}" - self.assertEqual(compute_score(solution, truth), 0.0) + self.assertEqual(compute_score_v0(solution, truth), 0.0) def test_empty_ground_truth(self): """ @@ -85,8 +85,8 @@ def test_empty_ground_truth(self): solution_correct = "The answer is \\boxed{}" solution_incorrect = "The answer is \\boxed{1}" truth = "" - self.assertEqual(compute_score(solution_correct, truth), 1.0) - self.assertEqual(compute_score(solution_incorrect, truth), 0.0) + self.assertEqual(compute_score_v0(solution_correct, truth), 1.0) + self.assertEqual(compute_score_v0(solution_incorrect, truth), 0.0) def test_multiple_boxed_answers_in_solution(self): """ @@ -95,8 +95,8 @@ def test_multiple_boxed_answers_in_solution(self): solution = "First I thought it was \\boxed{A}, but then I realized it is \\boxed{B}" truth_correct = "\\boxed{B}" truth_incorrect = "\\boxed{A}" - self.assertEqual(compute_score(solution, truth_correct), 1.0) - self.assertEqual(compute_score(solution, truth_incorrect), 0.0) + self.assertEqual(compute_score_v0(solution, truth_correct), 1.0) + self.assertEqual(compute_score_v0(solution, truth_incorrect), 0.0) class TestMathEvalUtils(unittest.TestCase): diff --git a/trinity/common/rewards/accuracy_reward.py b/trinity/common/rewards/accuracy_reward.py index ef75f18742..9bc1b14e5c 100644 --- a/trinity/common/rewards/accuracy_reward.py +++ b/trinity/common/rewards/accuracy_reward.py @@ -5,8 +5,8 @@ from latex2sympy2_extended import NormalizationConfig from math_verify import LatexExtractionConfig +from trinity.common.rewards.eval_utils import 
parse_with_timeout, verify_with_timeout from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS, RewardFn -from trinity.utils.eval_utils import parse_with_timeout, verify_with_timeout from trinity.utils.log import get_logger diff --git a/trinity/common/rewards/countdown_reward.py b/trinity/common/rewards/countdown_reward.py index a267b629f6..1ec20ea68a 100644 --- a/trinity/common/rewards/countdown_reward.py +++ b/trinity/common/rewards/countdown_reward.py @@ -2,12 +2,12 @@ import json from typing import Optional -from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS, RewardFn -from trinity.utils.eval_utils import ( +from trinity.common.rewards.eval_utils import ( evaluate_equation, extract_solution, validate_equation, ) +from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS, RewardFn @REWARD_FUNCTIONS.register_module("countdown_reward") diff --git a/trinity/common/rewards/dapo_reward.py b/trinity/common/rewards/dapo_reward.py index a6b6023aa2..85438e4333 100644 --- a/trinity/common/rewards/dapo_reward.py +++ b/trinity/common/rewards/dapo_reward.py @@ -4,8 +4,8 @@ import torch +from trinity.common.rewards.naive_dapo_score import compute_score from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS, RewardFn -from trinity.utils.eval_utils import compute_score @REWARD_FUNCTIONS.register_module("math_dapo_reward") @@ -28,7 +28,7 @@ def __call__( # type: ignore self, response: str, response_token: torch.Tensor, - truth: Optional[str] = None, + truth: str, **kwargs, ) -> dict[str, float]: accuracy_score = compute_score(response, truth) diff --git a/trinity/utils/eval_utils.py b/trinity/common/rewards/eval_utils.py similarity index 95% rename from trinity/utils/eval_utils.py rename to trinity/common/rewards/eval_utils.py index eb05ca51c8..21af31a773 100644 --- a/trinity/utils/eval_utils.py +++ b/trinity/common/rewards/eval_utils.py @@ -4,7 +4,7 @@ import regex as re from math_verify import parse, verify -from trinity.utils.math_eval_utils 
import strip_string +from trinity.common.rewards.qwen25_eval import strip_string ANS_RE = re.compile(r"#### (\-?[0-9\.\,]+)") INVALID_ANS = "[invalid]" @@ -125,7 +125,11 @@ def validate_think_pattern(text): # Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py -def compute_score(solution_str, ground_truth) -> float: +def compute_score_v0(solution_str, ground_truth) -> float: + """ + Compute the score of the solution string against the ground truth. + This function suits easily-verifiable problems; the answer is put within `\\boxed{}`. + """ retval = 0.0 try: string_in_last_boxed = last_boxed_only_string(solution_str) @@ -143,7 +147,6 @@ def compute_score(solution_str, ground_truth) -> float: ground_truth = original_ground_truth if is_equiv(answer, ground_truth): retval = 1.0 - # logger.warning(answer, " ", ground_truth, " ", retval) except Exception as e: print(e) @@ -186,6 +189,7 @@ def remove_boxed(s): # Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py def last_boxed_only_string(string): + """Extracts the last `\\boxed{...}` or `\\fbox{...}` substring from the input string.""" idx = string.rfind("\\boxed") if "\\boxed " in string: return "\\boxed " + string.split("\\boxed ")[-1].split("$")[0] diff --git a/trinity/common/rewards/math_reward.py b/trinity/common/rewards/math_reward.py index 4772a54263..b6b65416f6 100644 --- a/trinity/common/rewards/math_reward.py +++ b/trinity/common/rewards/math_reward.py @@ -3,13 +3,13 @@ from typing import Optional from trinity.common.rewards.accuracy_reward import AccuracyReward -from trinity.common.rewards.format_reward import FormatReward -from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS, RewardFn -from trinity.utils.eval_utils import ( - compute_score, +from trinity.common.rewards.eval_utils import ( + compute_score_v0, simple_answer_parser, validate_think_pattern, ) +from 
trinity.common.rewards.format_reward import FormatReward +from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS, RewardFn @REWARD_FUNCTIONS.register_module("math_reward") @@ -58,7 +58,7 @@ def __call__( # type: ignore format_score_coef: Optional[float] = 0.1, **kwargs, ) -> dict[str, float]: - accuracy_score = compute_score(response, truth) + accuracy_score = compute_score_v0(response, truth) format_score = 0.0 if with_think and not validate_think_pattern(response): diff --git a/benchmark/plugins/guru_math/naive_dapo.py b/trinity/common/rewards/naive_dapo_score.py similarity index 93% rename from benchmark/plugins/guru_math/naive_dapo.py rename to trinity/common/rewards/naive_dapo_score.py index b1dce7c597..8f7ba2c73d 100644 --- a/benchmark/plugins/guru_math/naive_dapo.py +++ b/trinity/common/rewards/naive_dapo_score.py @@ -1,17 +1,7 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py +""" +This file contains the naive dapo reward function for math tasks. 
+Adapted from https://github.com/LLM360/Reasoning360/blob/main/verl/utils/reward_score/naive_dapo.py +""" import concurrent import math @@ -434,6 +424,7 @@ def grade_answer(given_answer: str, ground_truth: str) -> tuple[bool, str]: def _last_boxed_only_string(string): + """Strictly extract content from \\boxed{}.""" idx = string.rfind("\\boxed") if idx < 0: idx = string.rfind("\\fbox") @@ -476,7 +467,7 @@ def match_answer(response): return is_matched, response -def compute_score(solution_str: str, ground_truth: str, extra_info: dict) -> dict: +def compute_score(solution_str: str, ground_truth: str) -> float: """Compute the reward score for a solution. This draws heavily from the LLM-as-judge and PRIME reward functions Args: @@ -513,10 +504,4 @@ def compute_score(solution_str: str, ground_truth: str, extra_info: dict) -> dic except Exception: correct = False - reward = 1.0 if correct else 0.0 - acc = correct - - return { - "score": reward, - "acc": acc, - } + return 1.0 if correct else 0.0 diff --git a/trinity/utils/math_eval_utils.py b/trinity/common/rewards/qwen25_eval.py similarity index 100% rename from trinity/utils/math_eval_utils.py rename to trinity/common/rewards/qwen25_eval.py diff --git a/trinity/common/workflows/customized_toolcall_workflows.py b/trinity/common/workflows/customized_toolcall_workflows.py index f723c0e347..2da3daa281 100644 --- a/trinity/common/workflows/customized_toolcall_workflows.py +++ b/trinity/common/workflows/customized_toolcall_workflows.py @@ -98,7 +98,7 @@ def validate_format(tool_call_list): # Adapted from https://github.com/NVlabs/Tool-N1 -def extract_solution_v0(tool_call_str): +def extract_solution_v0_toolN1(tool_call_str): output_string = tool_call_str pattern = r"(.*?)" @@ -112,14 +112,15 @@ def extract_solution_v0(tool_call_str): return None, output_string -def compute_score_v0( # noqa: C901 +# Adapted from https://github.com/NVlabs/Tool-N1 +def compute_score_v0_toolN1( # noqa: C901 solution_str, ground_truth, 
do_print=False, ): answer = json.loads(ground_truth) - result, output_string = extract_solution_v0(solution_str) + result, output_string = extract_solution_v0_toolN1(solution_str) if isinstance(result, str): try: @@ -199,7 +200,7 @@ def compute_toolcall_reward( solution_str: str, ground_truth: str, ) -> float: - res = compute_score_v0(solution_str, ground_truth) + res = compute_score_v0_toolN1(solution_str, ground_truth) if isinstance(res, (int, float, bool)): return float(res) else: diff --git a/trinity/common/workflows/eval_workflow.py b/trinity/common/workflows/eval_workflow.py index 55ea78e74e..837bc12881 100644 --- a/trinity/common/workflows/eval_workflow.py +++ b/trinity/common/workflows/eval_workflow.py @@ -9,8 +9,8 @@ from trinity.common.config import GenerationConfig from trinity.common.experience import Experience from trinity.common.models.model import ModelWrapper +from trinity.common.rewards.qwen25_eval import verify_math_answer from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow -from trinity.utils.math_eval_utils import verify_math_answer @WORKFLOWS.register_module("math_eval_workflow")