|
1 | | -import numpy as np |
2 | | -import pandas as pd |
3 | | -import json |
4 | | -import sympy |
5 | | -from sympy.parsing.latex import parse_latex |
6 | | -import os |
7 | 1 | import re |
8 | 2 | from fractions import Fraction |
| 3 | +import sympy |
| 4 | +from sympy.parsing.sympy_parser import parse_expr |
| 5 | + |
| 6 | +# Если вдруг antlr4-python3-runtime не установлен, то не используем его |
| 7 | +try: |
| 8 | + from sympy.parsing.latex import parse_latex # type: ignore |
| 9 | + |
| 10 | + _HAS_PARSE_LATEX = True |
| 11 | +except Exception: |
| 12 | + _HAS_PARSE_LATEX = False |
9 | 13 |
|
10 | 14 |
|
11 | 15 | class DoomSlayer: |
12 | | - def __init__(self, EPS=1e-2): |
| 16 | + def __init__(self, EPS: float = 1e-2) -> None: |
13 | 17 | self.EPS = EPS |
14 | 18 | self.num_pattern = re.compile(r"-?\d+(?:[.,]\d+)?$") |
15 | 19 | self.frac_pattern = re.compile(r"-?\d+\s*/\s*\d+$") |
16 | | - self.latex_frac_pattern = re.compile(r"\\frac\{(-?\d+)\}\{(\d+)\}") |
17 | | - self.minus_map = {"\u2212": "-", "\u2013": "-", "\u2014": "-"} # ADDED |
| 20 | + # допускаем обрамления \( … \), $ … $, $$ … $$ |
| 21 | + self.latex_frac_pattern = re.compile( |
| 22 | + r"(?:\\\(|\$\$?)?\s*\\frac\{(-?\d+)\}\{(\d+)\}\s*(?:\\\)|\$\$?)?" |
| 23 | + ) |
| 24 | + self.minus_map = {"\u2212": "-", "\u2013": "-", "\u2014": "-"} |
| 25 | + |
| 26 | + # ──────────────────────────── service ──────────────────────────── |
| 27 | + def _strip_delims(self, s: str) -> str: |
| 28 | + s = s.strip() |
| 29 | + if s.startswith("$$") and s.endswith("$$"): |
| 30 | + s = s[2:-2] |
| 31 | + elif s.startswith("$") and s.endswith("$"): |
| 32 | + s = s[1:-1] |
| 33 | + if s.startswith(r"\(") and s.endswith(r"\)"): |
| 34 | + s = s[2:-2] |
| 35 | + elif s.startswith(r"\[") and s.endswith(r"\]"): |
| 36 | + s = s[2:-2] |
| 37 | + return s.strip() |
18 | 38 |
|
19 | 39 | def _normalize(self, s: str) -> str: |
20 | | - for uni, ascii_minus in self.minus_map.items(): # ADDED |
21 | | - s = s.replace(uni, ascii_minus) # ADDED |
22 | | - return s.strip() # ADDED |
| 40 | + for bad, good in self.minus_map.items(): |
| 41 | + s = s.replace(bad, good) |
| 42 | + return s.strip() |
23 | 43 |
|
| 44 | + # ──────────────────────────── preprocessing ──────────────────────────── |
24 | 45 | def preprocess_answer(self, answer: str, hard: bool): |
25 | | - answer = self._normalize(answer) # ADDED |
| 46 | + answer = self._normalize(answer) |
26 | 47 | answer = answer[:-1] if answer.endswith(".") else answer |
27 | 48 | if not hard: |
28 | 49 | return re.findall(r"-?\d+(?:[.,]\d+)?", answer) |
29 | | - return answer.lower().replace("**", "^").split(";") |
| 50 | + return [ |
| 51 | + self._strip_delims(part).lower().replace("**", "^").strip() |
| 52 | + for part in answer.split(";") |
| 53 | + ] |
30 | 54 |
|
31 | | - def __call__(self, answer: str, predict: str) -> bool: |
32 | | - if not answer or not predict: |
| 55 | + # ──────────────────────────── helpers ──────────────────────────── |
| 56 | + def _compare_numeric(self, a: str, b: str) -> bool: |
| 57 | + """Абс- и относительная погрешность""" |
| 58 | + try: |
| 59 | + fa, fb = float(a.replace(",", ".")), float(b.replace(",", ".")) |
| 60 | + except ValueError: |
33 | 61 | return False |
34 | | - answer = self._normalize(answer) # ADDED |
35 | | - predict = self._normalize(predict) # ADDED |
36 | | - |
37 | | - if self._compare_fraction(answer, predict): |
38 | | - return True |
39 | | - |
40 | | - if self.num_pattern.match(answer) and self.num_pattern.match(predict): |
41 | | - return self.simple_check(predict, answer) or self.latex_equivalent( |
42 | | - predict, answer |
43 | | - ) |
44 | | - |
45 | | - return self.latex_equivalent(predict, answer) |
46 | | - |
47 | | - def simple_check(self, predict: str, answer: str) -> bool: |
48 | | - p = self.preprocess_answer(predict, False) |
49 | | - a = self.preprocess_answer(answer, False) |
50 | | - return "".join(a).replace(",", ".") == "".join(p).replace(",", ".") |
| 62 | + diff = abs(fa - fb) |
| 63 | + return diff <= self.EPS or diff / (abs(fb) or 1) <= self.EPS |
51 | 64 |
|
52 | 65 | def _compare_fraction(self, s1: str, s2: str) -> bool: |
53 | | - def to_frac(s): |
54 | | - s = s.strip() |
55 | | - m = self.frac_pattern.fullmatch(s) |
56 | | - if m: |
| 66 | + def to_frac(s: str): |
| 67 | + s = self._strip_delims(s) |
| 68 | + if self.frac_pattern.fullmatch(s): |
57 | 69 | num, den = map(int, s.split("/")) |
58 | 70 | return Fraction(num, den) |
59 | | - m2 = self.latex_frac_pattern.fullmatch(s) |
60 | | - if m2: |
61 | | - num, den = map(int, m2.groups()) |
| 71 | + m = self.latex_frac_pattern.fullmatch(s) |
| 72 | + if m: |
| 73 | + num, den = map(int, m.groups()) |
62 | 74 | return Fraction(num, den) |
63 | 75 | return None |
64 | 76 |
|
65 | | - f1 = to_frac(s1) |
66 | | - f2 = to_frac(s2) |
| 77 | + f1, f2 = to_frac(s1), to_frac(s2) |
67 | 78 | if f1 is not None and f2 is not None: |
68 | 79 | return abs(float(f1) - float(f2)) <= self.EPS |
69 | 80 | return False |
70 | 81 |
|
71 | | - def latex_equivalent(self, latex1: str, latex2: str) -> bool: |
72 | | - parts1 = self.preprocess_answer(latex1, True) |
73 | | - parts2 = self.preprocess_answer(latex2, True) |
74 | | - if len(parts1) != len(parts2): |
75 | | - return False |
76 | | - |
77 | | - for a, b in zip(parts1, parts2): |
| 82 | + def _to_expr(self, s: str): |
| 83 | + """Пытаемся превратить строку в sympy-выражение максимально надёжно""" |
| 84 | + # 1) голое число |
| 85 | + try: |
| 86 | + return sympy.Float(s.replace(",", ".")) |
| 87 | + except Exception: |
| 88 | + pass |
| 89 | + # 2) обычная «python-математика» |
| 90 | + try: |
| 91 | + return parse_expr(s.replace("^", "**"), evaluate=True) |
| 92 | + except Exception: |
| 93 | + pass |
| 94 | + # 3) LaTeX (если библиотека доступна) |
| 95 | + if _HAS_PARSE_LATEX: |
78 | 96 | try: |
79 | | - e1 = parse_latex(a) |
80 | | - e2 = parse_latex(b) |
81 | | - diff = sympy.simplify(abs(e1 - e2)) |
82 | | - try: |
83 | | - diff_rel = sympy.simplify(abs(e1 - e2) / abs(e2)) |
84 | | - diff = min(diff, diff_rel) |
85 | | - except Exception: |
86 | | - pass |
87 | | - if diff > self.EPS: |
88 | | - return False |
| 97 | + return parse_latex(self._strip_delims(s)) |
89 | 98 | except Exception: |
| 99 | + pass |
| 100 | + return None |
| 101 | + |
| 102 | + def _expr_equal(self, e1, e2) -> bool: |
| 103 | + diff = sympy.simplify(e1 - e2) |
| 104 | + if diff == 0: |
| 105 | + return True # точное равенство |
| 106 | + if diff.free_symbols: # осталось x, y … — считаем неравными |
| 107 | + return False |
| 108 | + try: |
| 109 | + return abs(float(diff)) <= self.EPS |
| 110 | + except Exception: |
| 111 | + return False |
| 112 | + |
| 113 | + # ──────────────────────────── core ──────────────────────────── |
| 114 | + def latex_equivalent(self, s1: str, s2: str) -> bool: |
| 115 | + p1, p2 = self.preprocess_answer(s1, True), self.preprocess_answer(s2, True) |
| 116 | + if len(p1) != len(p2): |
| 117 | + return False |
| 118 | + |
| 119 | + for a, b in zip(p1, p2): |
| 120 | + # быстрое сравнение чисел |
| 121 | + if self._compare_numeric(a, b): |
| 122 | + continue |
| 123 | + # пытаемся превратить в выражения |
| 124 | + e1, e2 = self._to_expr(a), self._to_expr(b) |
| 125 | + if e1 is None or e2 is None: |
90 | 126 | if a != b: |
91 | 127 | return False |
92 | | - |
| 128 | + continue |
| 129 | + if not self._expr_equal(e1, e2): |
| 130 | + return False |
93 | 131 | return True |
| 132 | + |
| 133 | + # ──────────────────────────── public API ──────────────────────────── |
| 134 | + def __call__( |
| 135 | + self, answer: str, predict: str |
| 136 | + ) -> bool: # order: (правильный ответ, предикт) |
| 137 | + if not answer or not predict: |
| 138 | + return False |
| 139 | + answer, predict = self._normalize(answer), self._normalize(predict) |
| 140 | + |
| 141 | + if self._compare_fraction(answer, predict): |
| 142 | + return True |
| 143 | + |
| 144 | + if self.num_pattern.match(answer) and self.num_pattern.match(predict): |
| 145 | + if self._compare_numeric(answer, predict): |
| 146 | + return True |
| 147 | + # если равенство «почти» не прокатило – проверяем как выражения |
| 148 | + return self.latex_equivalent(predict, answer) |
| 149 | + |
| 150 | + return self.latex_equivalent(predict, answer) |
0 commit comments