Skip to content

Commit 3b31371

Browse files
wang2yn84 and The tunix Authors
authored and committed
[Tunix] Add special handling for math answer grading.
PiperOrigin-RevId: 889627979
1 parent 53d9c15 commit 3b31371

File tree

4 files changed

+272
-6
lines changed

4 files changed

+272
-6
lines changed

examples/deepscaler/math_eval_nb.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -152,9 +152,11 @@ def evaluate_correctness(response: Any, ground_truths: Any) -> bool:
152152
return False
153153
# Check against all possible correct answers
154154
for ground_truth in processed_ground_truths:
155-
is_correct = math_utils.grade_answer_mathd(
156-
model_answer, ground_truth
157-
) or math_utils.grade_answer_sympy(model_answer, ground_truth)
155+
is_correct = (
156+
math_utils.grade_answer_mathd(model_answer, ground_truth)
157+
or math_utils.grade_answer_sympy(model_answer, ground_truth)
158+
or math_utils.grade_answer_special_handling(model_answer, ground_truth)
159+
)
158160
if is_correct:
159161
print(f" {model_answer=} {ground_truth=} IS CORRECT")
160162
return True

tests/utils/math_utils_test.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for tunix.utils.math_utils special handling."""
16+
17+
from absl.testing import absltest
18+
from absl.testing import parameterized
19+
from tunix.utils import math_utils
20+
21+
22+
class MathUtilsSpecialHandlingTest(parameterized.TestCase):
  """Exercises `math_utils.grade_answer_special_handling` case by case."""

  @parameterized.named_parameters(
      # Each row: (testcase_name, given_answer, ground_truth, expected).
      ("recurring_decimal_overlap", "16.67", r"16.\overline{6}", True),
      (
          "recurring_decimal_all_single_digit_pattern",
          "2.33",
          r"2.\overline{3}",
          True,
      ),
      (
          "recurring_decimal_all_single_digit_pattern2",
          "2.3",
          r"2.\overline{3}",
          True,
      ),
      (
          "invalid_sqrt_cleanup_equivalent",
          r"\frac{3\sqrt{3}}{2}",
          r"\frac{3\sqrt{}{3}}{2}",
          True,
      ),
      (
          "interval_union_equivalence",
          r"$-5\lex\le1$or$3\lex\le9$",
          r"[-5,1]\cup[3,9]",
          True,
      ),
      (
          # The ground truth is missing its opening bracket, so it must not
          # be accepted as an interval union.
          "partial_interval_not_tolerated",
          r"$-5\lex\le1$or$3\lex\le9$",
          r"-5,1]\cup[3,9]",
          False,
      ),
  )
  def test_grade_answer_special_handling(
      self, given_answer: str, ground_truth: str, expected: bool
  ):
    graded = math_utils.grade_answer_special_handling(
        given_answer, ground_truth
    )
    self.assertEqual(graded, expected)
69+
70+
71+
# Hand control to the absl test runner when this file is executed directly.
if __name__ == "__main__":
  absltest.main()

tunix/utils/math_rewards.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,13 @@ def math_reward(prompts: List[str], completions: List[str], answer: List[str], *
8787
for ground_truth in processed_ground_truths:
8888
if found_correct_answer:
8989
break
90-
is_correct = math_utils.grade_answer_mathd(
91-
model_answer, ground_truth
92-
) or math_utils.grade_answer_sympy(model_answer, ground_truth)
90+
is_correct = (
91+
math_utils.grade_answer_mathd(model_answer, ground_truth)
92+
or math_utils.grade_answer_sympy(model_answer, ground_truth)
93+
or math_utils.grade_answer_special_handling(
94+
model_answer, ground_truth
95+
)
96+
)
9397
if is_correct:
9498
found_correct_answer = True
9599
reward_value: float = 1.0 # Base reward for a correct answer.

tunix/utils/math_utils.py

Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
"""Math utils for evaluating on Math Dataset like Math500 and AIME2024."""
1616

17+
from decimal import Decimal, ROUND_HALF_UP
1718
import re
1819
from absl import logging
1920
from pylatexenc import latex2text
@@ -438,6 +439,193 @@ def extract_boxed_answer(solution: str):
438439
return solution
439440

440441

442+
def _cleanup_invalid_empty_sqrt(expr: str) -> str:
  """Strip the stray empty braces after `sqrt`: `\\sqrt{}{3}` -> `\\sqrt{3}`.

  Args:
    expr: A LaTeX expression string.

  Returns:
    `expr` with every literal `sqrt{}` replaced by `sqrt`; all other text is
    returned unchanged.
  """
  empty_sqrt = re.compile(r"sqrt\{\}")
  cleaned = empty_sqrt.sub(r"sqrt", expr)
  return cleaned
445+
446+
447+
def _parse_special_decimal_interval(expr: str):
  r"""Parse a plain or single-digit recurring decimal into a numeric interval.

  `$` and spaces are removed first. Two forms are recognised:

  * `<int>.<digits>\overline{<d>}` with exactly one recurring digit, e.g.
    `16.\overline{6}`. The exact value is rounded (half up) to one and to two
    decimal places, and the interval spanned by those two roundings is
    returned, so answers like `16.7` and `16.67` can both match.
  * Any other string that `float()` parses to a finite number, which yields
    the degenerate interval `(value, value)`.

  Args:
    expr: The answer string to parse.

  Returns:
    A `(low, high)` tuple of floats with `low <= high`, or None when `expr` is
    neither of the supported forms or is not a finite number.
  """
  expr = expr.replace("$", "").replace(" ", "")
  m = re.fullmatch(r"([+-]?\d+)\.([0-9]*)\\overline\{([0-9])\}", expr)
  if m is not None:
    int_part, non_repeating_decimals, recurring_digit = m.groups()

    scale = Decimal(10) ** len(non_repeating_decimals)
    # Build the magnitude from the unsigned parts and apply the sign last.
    # Adding the (positive) fractional part to a negative integer part would
    # move the value toward zero, e.g. `-2.\overline{3}` would become -1.67.
    magnitude = (
        Decimal(int_part.lstrip("+-"))
        + Decimal(non_repeating_decimals or "0") / scale
        + Decimal(recurring_digit) / (Decimal(9) * scale)
    )
    value = -magnitude if int_part.startswith("-") else magnitude

    rounded = [
        float(value.quantize(Decimal(step), rounding=ROUND_HALF_UP))
        for step in ("0.1", "0.01")
    ]
    return (min(rounded), max(rounded))

  try:
    value = float(expr)
  except ValueError:
    return None
  # `float()` also accepts "nan" and "inf". NaN compares False against
  # everything, which would defeat interval comparisons downstream, and an
  # infinity is not a decimal answer, so both are rejected here.
  if not float("-inf") < value < float("inf"):
    return None
  return (value, value)
476+
477+
478+
def _intervals_overlap(
    interval_a: tuple[float, float], interval_b: tuple[float, float]
) -> bool:
  """Return whether two closed intervals share at least one point.

  Args:
    interval_a: `(low, high)` bounds of the first interval.
    interval_b: `(low, high)` bounds of the second interval.

  Returns:
    True when the intervals intersect (touching endpoints count), otherwise
    False. Always False if any compared bound is NaN.
  """
  # Written in the positive form on purpose: every comparison with NaN is
  # False, so the negated form `not (a < b or c < d)` would report an overlap
  # whenever a bound is NaN.
  return interval_a[0] <= interval_b[1] and interval_b[0] <= interval_a[1]
482+
483+
484+
# A signed decimal literal such as `-5`, `1.25` or `.5`.
_INTERVAL_NUMBER = r"[+-]?(?:\d+(?:\.\d+)?|\.\d+)"
# Bracket notation: `[a,b]`, `(a,b]`, `[a,b)` or `(a,b)`.
_BRACKET_INTERVAL_RE = re.compile(
    r"([\[(])(" + _INTERVAL_NUMBER + r"),(" + _INTERVAL_NUMBER + r")([\])])"
)
# Chained non-strict inequality with an optional one-letter variable, e.g.
# `-5\lex\le1` (spaces have already been removed at this point).
_INEQUALITY_INTERVAL_RE = re.compile(
    r"(" + _INTERVAL_NUMBER + r")\\le[a-z]?\\le(" + _INTERVAL_NUMBER + r")"
)


def _parse_single_interval(part: str):
  """Parse one normalised union member into an interval tuple, or None."""
  bracket = _BRACKET_INTERVAL_RE.fullmatch(part)
  if bracket is not None:
    left = float(bracket.group(2))
    right = float(bracket.group(3))
    left_closed = bracket.group(1) == "["
    right_closed = bracket.group(4) == "]"
  else:
    inequality = _INEQUALITY_INTERVAL_RE.fullmatch(part)
    if inequality is None:
      return None
    left = float(inequality.group(1))
    right = float(inequality.group(2))
    # `\le` on both sides means both endpoints are included.
    left_closed = right_closed = True

  if left > right:
    # Reversed bounds are reordered; closedness travels with its endpoint.
    left, right = right, left
    left_closed, right_closed = right_closed, left_closed
  return (left, right, left_closed, right_closed)


def _parse_interval_set(expr: str):
  r"""Parse interval unions from either inequality or bracket notation.

  The input is lower-cased, `$` and spaces are dropped, the `≤`, `\leq` and
  `<=` spellings are unified to `\le`, and `\cup`, `∪` and `or` are all
  treated as the union separator. Each member of the union may then be
  written either as a bracketed interval (`[a,b]`, `(a,b]`, ...) or as a
  chained inequality (`a\le x\le b`); the two notations may be mixed within
  one union.

  Args:
    expr: The answer string to parse.

  Returns:
    A sorted list of `(left, right, left_closed, right_closed)` tuples, or
    None when the string is empty or any member cannot be parsed. Strict
    inequalities are not recognised.
  """
  expr = expr.lower().strip()
  expr = expr.replace("$", "")
  expr = expr.replace("≤", "\\le")
  expr = expr.replace("\\leq", "\\le")
  expr = expr.replace("<=", "\\le")
  expr = expr.replace("\\cup", "|")
  expr = expr.replace("∪", "|")
  expr = expr.replace("or", "|")
  expr = expr.replace(" ", "")

  parts = [part for part in expr.split("|") if part]
  if not parts:
    return None

  intervals = []
  for part in parts:
    interval = _parse_single_interval(part)
    if interval is None:
      # One unparseable member invalidates the whole union.
      return None
    intervals.append(interval)
  return sorted(intervals)
543+
544+
545+
def _match_recurring_decimal_special_case(
    given_clean: str, ground_truth_clean: str
) -> bool:
  r"""Compare two answers when either uses a single-digit `\overline` form.

  Args:
    given_clean: The model's answer.
    ground_truth_clean: The reference answer.

  Returns:
    False when neither string contains `<digits>.\overline{<digit>}`.
    Otherwise both strings are parsed with `_parse_special_decimal_interval`,
    and the result is True only if both parse and `_intervals_overlap`
    reports that the two intervals overlap.
  """
  overline_pattern = r"[0-9]+\.\s*\\overline\{[0-9]\}"
  mentions_overline = any(
      re.search(overline_pattern, candidate)
      for candidate in (given_clean, ground_truth_clean)
  )
  if not mentions_overline:
    return False

  given_interval = _parse_special_decimal_interval(given_clean)
  ground_truth_interval = _parse_special_decimal_interval(ground_truth_clean)
  if given_interval is None or ground_truth_interval is None:
    return False
  return _intervals_overlap(given_interval, ground_truth_interval)
562+
563+
564+
def _match_interval_union_special_case(
    given_clean: str, ground_truth_clean: str
) -> bool:
  """Treat inequality unions and interval unions as equivalent sets.

  Args:
    given_clean: The model's answer.
    ground_truth_clean: The reference answer.

  Returns:
    True only when `_parse_interval_set` succeeds on both strings and yields
    equal results; False if either side fails to parse.
  """
  parsed_given = _parse_interval_set(given_clean)
  parsed_truth = _parse_interval_set(ground_truth_clean)
  if parsed_given is None or parsed_truth is None:
    return False
  return parsed_given == parsed_truth
575+
576+
577+
def _match_invalid_sqrt_special_case(
    given_answer: str,
    ground_truth: str,
    given_clean: str,
    ground_truth_clean: str,
) -> bool:
  """Re-grade a pair of answers after malformed `sqrt{}` cleanup.

  Args:
    given_answer: The model's answer before cleanup.
    ground_truth: The reference answer before cleanup.
    given_clean: The model's answer after cleanup.
    ground_truth_clean: The reference answer after cleanup.

  Returns:
    False when cleanup changed neither string, or when `_normalize` returns
    None for either cleaned string. True when the two normalized strings are
    identical. Otherwise the result of `are_equal_under_sympy` on them,
    provided the normalized given answer is non-empty.
  """
  cleanup_changed_something = (
      given_clean != given_answer or ground_truth_clean != ground_truth
  )
  if not cleanup_changed_something:
    # Nothing was repaired, so this special case has nothing to add.
    return False

  given_normalized = _normalize(given_clean)
  ground_truth_normalized = _normalize(ground_truth_clean)
  if given_normalized is None or ground_truth_normalized is None:
    return False
  if given_normalized == ground_truth_normalized:
    return True
  return len(given_normalized) > 0 and are_equal_under_sympy(
      ground_truth_normalized, given_normalized
  )
601+
602+
603+
def grade_answer_special_handling(given_answer: str, ground_truth: str) -> bool:
  """Grade an answer using a few special-case equivalence rules.

  Only the ground truth is cleaned of malformed `sqrt{}`; the given answer is
  used as-is. After an exact string comparison, three cases are tried in
  order: recurring-decimal overlap, equivalence after the `sqrt{}` cleanup,
  and inequality-union vs. interval-union equivalence.

  Args:
    given_answer: The model's answer, or None.
    ground_truth: The reference answer, or None.

  Returns:
    True if the exact comparison or any special case matches; False
    otherwise, including when either argument is None.
  """
  if given_answer is None or ground_truth is None:
    return False

  # Only clean the ground truth for latex errors.
  ground_truth_clean = _cleanup_invalid_empty_sqrt(ground_truth)
  if given_answer == ground_truth_clean:
    return True

  return bool(
      # Case 1: recurring decimal overlap special handling.
      _match_recurring_decimal_special_case(given_answer, ground_truth_clean)
      # Case 2: malformed sqrt{} cleanups should still evaluate as equivalent.
      or _match_invalid_sqrt_special_case(
          given_answer, ground_truth, given_answer, ground_truth_clean
      )
      # Case 3: inequality union vs interval union equivalence.
      or _match_interval_union_special_case(given_answer, ground_truth_clean)
  )
627+
628+
441629
def grade_answer_sympy(given_answer: str, ground_truth: str) -> bool:
442630
"""Grades a given answer against a ground truth using sympy for evaluation."""
443631
ground_truth_normalized = _normalize(ground_truth)

0 commit comments

Comments
 (0)