|
14 | 14 |
|
15 | 15 | """Math utils for evaluating on Math Dataset like Math500 and AIME2024.""" |
16 | 16 |
|
| 17 | +from decimal import Decimal, ROUND_HALF_UP |
17 | 18 | import re |
18 | 19 | from absl import logging |
19 | 20 | from pylatexenc import latex2text |
@@ -438,6 +439,193 @@ def extract_boxed_answer(solution: str): |
438 | 439 | return solution |
439 | 440 |
|
440 | 441 |
|
| 442 | +def _cleanup_invalid_empty_sqrt(expr: str) -> str: |
| 443 | + """Fix malformed latex like `\\sqrt{}{3}` -> `\\sqrt{3}`.""" |
| 444 | + return re.sub(r"sqrt\{\}", r"sqrt", expr) |
| 445 | + |
| 446 | + |
| 447 | +def _parse_special_decimal_interval(expr: str): |
| 448 | + """Parse known recurring-decimal special cases to numeric intervals.""" |
| 449 | + expr = expr.replace("$", "").replace(" ", "") |
| 450 | + m = re.fullmatch(r"([+-]?\d+)\.([0-9]*)\\overline\{([0-9])\}", expr) |
| 451 | + if m is not None: |
| 452 | + int_part = m.group(1) |
| 453 | + non_repeating_decimals = m.group(2) |
| 454 | + recurring_digit = m.group(3) |
| 455 | + |
| 456 | + # Only support single-digit recurring blocks, e.g. `16.\overline{6}`. |
| 457 | + # Map to the interval formed by 1-decimal and 2-decimal rounded values, |
| 458 | + # so answers like `16.7` and `16.67` can both match. |
| 459 | + decimal_places = len(non_repeating_decimals) |
| 460 | + scale = Decimal(10) ** decimal_places |
| 461 | + value = ( |
| 462 | + Decimal(int_part) |
| 463 | + + Decimal(non_repeating_decimals or "0") / scale |
| 464 | + + Decimal(recurring_digit) / (Decimal(9) * scale) |
| 465 | + ) |
| 466 | + |
| 467 | + rounded_1 = float(value.quantize(Decimal("0.1"), rounding=ROUND_HALF_UP)) |
| 468 | + rounded_2 = float(value.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)) |
| 469 | + return (min(rounded_1, rounded_2), max(rounded_1, rounded_2)) |
| 470 | + |
| 471 | + try: |
| 472 | + value = float(expr) |
| 473 | + return (value, value) |
| 474 | + except Exception: |
| 475 | + return None |
| 476 | + |
| 477 | + |
| 478 | +def _intervals_overlap( |
| 479 | + interval_a: tuple[float, float], interval_b: tuple[float, float] |
| 480 | +): |
| 481 | + return not (interval_a[1] < interval_b[0] or interval_b[1] < interval_a[0]) |
| 482 | + |
| 483 | + |
| 484 | +def _parse_interval_set(expr: str): |
| 485 | + """Parse interval unions from either inequality or bracket notation.""" |
| 486 | + expr = expr.lower().strip() |
| 487 | + expr = expr.replace("$", "") |
| 488 | + expr = expr.replace("≤", "\\le") |
| 489 | + expr = expr.replace("\\leq", "\\le") |
| 490 | + expr = expr.replace("<=", "\\le") |
| 491 | + expr = expr.replace("\\cup", "|") |
| 492 | + expr = expr.replace("∪", "|") |
| 493 | + expr = expr.replace("or", "|") |
| 494 | + expr = expr.replace(" ", "") |
| 495 | + |
| 496 | + if not expr: |
| 497 | + return None |
| 498 | + |
| 499 | + parts = [part for part in expr.split("|") if part] |
| 500 | + if not parts: |
| 501 | + return None |
| 502 | + |
| 503 | + # First try interval notation: [a,b], (a,b], etc. |
| 504 | + intervals = [] |
| 505 | + all_interval_notation = True |
| 506 | + for part in parts: |
| 507 | + m = re.fullmatch( |
| 508 | + r"([\[(])([+-]?(?:\d+(?:\.\d+)?|\.\d+)),([+-]?(?:\d+(?:\.\d+)?|\.\d+))([\])])", |
| 509 | + part, |
| 510 | + ) |
| 511 | + if m is None: |
| 512 | + all_interval_notation = False |
| 513 | + break |
| 514 | + left = float(m.group(2)) |
| 515 | + right = float(m.group(3)) |
| 516 | + left_closed = m.group(1) == "[" |
| 517 | + right_closed = m.group(4) == "]" |
| 518 | + |
| 519 | + if left > right: |
| 520 | + left, right = right, left |
| 521 | + left_closed, right_closed = right_closed, left_closed |
| 522 | + intervals.append((left, right, left_closed, right_closed)) |
| 523 | + |
| 524 | + if all_interval_notation: |
| 525 | + return sorted(intervals) |
| 526 | + |
| 527 | + # Then try inequalities: -5\lex\le1, -5\lex\le1, etc. |
| 528 | + intervals = [] |
| 529 | + for part in parts: |
| 530 | + m = re.fullmatch( |
| 531 | + r"([+-]?(?:\d+(?:\.\d+)?|\.\d+))\\le[a-z]?\\le([+-]?(?:\d+(?:\.\d+)?|\.\d+))", |
| 532 | + part, |
| 533 | + ) |
| 534 | + if m is None: |
| 535 | + return None |
| 536 | + left = float(m.group(1)) |
| 537 | + right = float(m.group(2)) |
| 538 | + if left > right: |
| 539 | + left, right = right, left |
| 540 | + intervals.append((left, right, True, True)) |
| 541 | + |
| 542 | + return sorted(intervals) |
| 543 | + |
| 544 | + |
| 545 | +def _match_recurring_decimal_special_case( |
| 546 | + given_clean: str, ground_truth_clean: str |
| 547 | +) -> bool: |
| 548 | + """Handle recurring decimal overlaps for single-digit overline forms.""" |
| 549 | + if not ( |
| 550 | + re.search(r"[0-9]+\.\s*\\overline\{[0-9]\}", given_clean) |
| 551 | + or re.search(r"[0-9]+\.\s*\\overline\{[0-9]\}", ground_truth_clean) |
| 552 | + ): |
| 553 | + return False |
| 554 | + |
| 555 | + given_interval = _parse_special_decimal_interval(given_clean) |
| 556 | + ground_truth_interval = _parse_special_decimal_interval(ground_truth_clean) |
| 557 | + return ( |
| 558 | + given_interval is not None |
| 559 | + and ground_truth_interval is not None |
| 560 | + and _intervals_overlap(given_interval, ground_truth_interval) |
| 561 | + ) |
| 562 | + |
| 563 | + |
def _match_interval_union_special_case(
    given_clean: str, ground_truth_clean: str
) -> bool:
    """Check whether two answers parse to the same set of intervals.

    Both answers are run through `_parse_interval_set`; they match only when
    both parse successfully and the parsed results compare equal.

    Args:
      given_clean: The answer being graded.
      ground_truth_clean: The reference answer.

    Returns:
      True when both answers parse and yield equal results, False otherwise.
    """
    parsed_given = _parse_interval_set(given_clean)
    parsed_truth = _parse_interval_set(ground_truth_clean)
    if parsed_given is None or parsed_truth is None:
        return False
    return parsed_given == parsed_truth
| 575 | + |
| 576 | + |
| 577 | +def _match_invalid_sqrt_special_case( |
| 578 | + given_answer: str, |
| 579 | + ground_truth: str, |
| 580 | + given_clean: str, |
| 581 | + ground_truth_clean: str, |
| 582 | +) -> bool: |
| 583 | + """Handle malformed `sqrt{}` cleanup equivalence checks.""" |
| 584 | + if given_clean == given_answer and ground_truth_clean == ground_truth: |
| 585 | + return False |
| 586 | + |
| 587 | + given_normalized = _normalize(given_clean) |
| 588 | + ground_truth_normalized = _normalize(ground_truth_clean) |
| 589 | + if ( |
| 590 | + given_normalized is not None |
| 591 | + and ground_truth_normalized is not None |
| 592 | + and given_normalized == ground_truth_normalized |
| 593 | + ): |
| 594 | + return True |
| 595 | + return ( |
| 596 | + given_normalized is not None |
| 597 | + and ground_truth_normalized is not None |
| 598 | + and len(given_normalized) > 0 |
| 599 | + and are_equal_under_sympy(ground_truth_normalized, given_normalized) |
| 600 | + ) |
| 601 | + |
| 602 | + |
def grade_answer_special_handling(given_answer: str, ground_truth: str) -> bool:
    """Grade an answer using a few special-case equivalence rules.

    Only the ground truth is cleaned of malformed `sqrt{}` latex; the given
    answer is used as-is. The answer is accepted when it equals the cleaned
    ground truth, or when one of these checks passes, tried in order:

    1. recurring-decimal overlap,
    2. equivalence after the `sqrt{}` cleanup,
    3. equality of interval unions written as inequalities or brackets.

    Args:
      given_answer: The answer being graded; `None` is graded as incorrect.
      ground_truth: The reference answer; `None` is graded as incorrect.

    Returns:
      True when the answer is accepted by any rule above, False otherwise.
    """
    if given_answer is None or ground_truth is None:
        return False

    # Only clean the ground truth for latex errors.
    cleaned_truth = _cleanup_invalid_empty_sqrt(ground_truth)
    if given_answer == cleaned_truth:
        return True

    # `or` short-circuits, so the checks still run in order and stop at the
    # first one that accepts the answer.
    return bool(
        _match_recurring_decimal_special_case(given_answer, cleaned_truth)
        or _match_invalid_sqrt_special_case(
            given_answer, ground_truth, given_answer, cleaned_truth
        )
        or _match_interval_union_special_case(given_answer, cleaned_truth)
    )
| 627 | + |
| 628 | + |
441 | 629 | def grade_answer_sympy(given_answer: str, ground_truth: str) -> bool: |
442 | 630 | """Grades a given answer against a ground truth using sympy for evaluation.""" |
443 | 631 | ground_truth_normalized = _normalize(ground_truth) |
|
0 commit comments