Skip to content

Commit 39fa8ca

Browse files
committed
Refactor equality_checker.py for improved readability and functionality
1 parent bfd9748 commit 39fa8ca

File tree

3 files changed

+118
-61
lines changed

3 files changed

+118
-61
lines changed

requirements.txt

922 Bytes
Binary file not shown.

src/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import jinja2
99
import numpy as np
1010
import requests
11-
from tqdm import tqdm
11+
from tqdm.auto import tqdm
1212
from concurrent.futures import ThreadPoolExecutor
1313

1414
from .types import EvalResult, SingleEvalResult

src/equality_checker.py

Lines changed: 117 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,93 +1,150 @@
1-
import numpy as np
2-
import pandas as pd
3-
import json
4-
import sympy
5-
from sympy.parsing.latex import parse_latex
6-
import os
71
import re
82
from fractions import Fraction
3+
import sympy
4+
from sympy.parsing.sympy_parser import parse_expr
5+
6+
# If antlr4-python3-runtime is not installed, the LaTeX parser import fails; in that case we simply do not use it
7+
try:
8+
from sympy.parsing.latex import parse_latex # type: ignore
9+
10+
_HAS_PARSE_LATEX = True
11+
except Exception:
12+
_HAS_PARSE_LATEX = False
913

1014

1115
class DoomSlayer:
    """Callable checker that decides whether a predicted answer equals a reference answer.

    Comparison is attempted in order: exact fractions (plain ``a/b`` or
    LaTeX ``\\frac{a}{b}``), plain decimal numbers with tolerance ``EPS``,
    and finally symbolic comparison of ``;``-separated expressions via SymPy.

    Args:
        EPS: Tolerance used both as an absolute and as a relative error bound.
    """

    def __init__(self, EPS: float = 1e-2) -> None:
        self.EPS = EPS
        # A bare integer/decimal (comma or dot as separator) up to end of string.
        self.num_pattern = re.compile(r"-?\d+(?:[.,]\d+)?$")
        # A plain fraction such as "-3 / 4".
        self.frac_pattern = re.compile(r"-?\d+\s*/\s*\d+$")
        # A LaTeX fraction, optionally wrapped in \( ... \), $ ... $ or $$ ... $$.
        self.latex_frac_pattern = re.compile(
            r"(?:\\\(|\$\$?)?\s*\\frac\{(-?\d+)\}\{(\d+)\}\s*(?:\\\)|\$\$?)?"
        )
        # Unicode minus / en dash / em dash -> ASCII hyphen-minus.
        self.minus_map = {"\u2212": "-", "\u2013": "-", "\u2014": "-"}

    # ──────────────────────────── service ────────────────────────────
    def _strip_delims(self, s: str) -> str:
        """Remove one layer of ``$``/``$$`` and one layer of ``\\(..\\)``/``\\[..\\]``.

        Args:
            s: Raw answer fragment.

        Returns:
            The fragment without surrounding math delimiters and whitespace.
        """
        s = s.strip()
        # Length guards keep a lone "$" / "$$$" from being read as an
        # overlapping opening+closing delimiter pair.
        if len(s) >= 4 and s.startswith("$$") and s.endswith("$$"):
            s = s[2:-2]
        elif len(s) >= 2 and s.startswith("$") and s.endswith("$"):
            s = s[1:-1]
        # Strip again so that "$ \(x\) $" exposes the inner delimiters.
        s = s.strip()
        if s.startswith(r"\(") and s.endswith(r"\)"):
            s = s[2:-2]
        elif s.startswith(r"\[") and s.endswith(r"\]"):
            s = s[2:-2]
        return s.strip()

    def _normalize(self, s: str) -> str:
        """Replace the dash variants listed in ``minus_map`` and trim whitespace."""
        for bad, good in self.minus_map.items():
            s = s.replace(bad, good)
        return s.strip()

    # ──────────────────────────── preprocessing ────────────────────────────
    def preprocess_answer(self, answer: str, hard: bool):
        """Split an answer into comparable pieces.

        Args:
            answer: Raw answer text.
            hard: When false, return every number found in the text; when
                true, return the ``;``-separated parts with math delimiters
                removed, lower-cased and with ``**`` rewritten as ``^``.

        Returns:
            A list of strings.
        """
        answer = self._normalize(answer)
        # Drop a single trailing full stop ("42." -> "42").
        answer = answer[:-1] if answer.endswith(".") else answer
        if not hard:
            return re.findall(r"-?\d+(?:[.,]\d+)?", answer)
        return [
            self._strip_delims(part).lower().replace("**", "^").strip()
            for part in answer.split(";")
        ]

    # ──────────────────────────── helpers ────────────────────────────
    def _compare_numeric(self, a: str, b: str) -> bool:
        """Compare two numeric strings with absolute and relative tolerance.

        The relative error is measured against ``b`` (or 1 when ``b`` is zero),
        so callers should pass the reference value as ``b``.

        Args:
            a: Candidate value; a comma is accepted as decimal separator.
            b: Reference value; a comma is accepted as decimal separator.

        Returns:
            True when both parse as floats and differ by at most ``EPS``
            absolutely or relatively; False otherwise.
        """
        try:
            fa, fb = float(a.replace(",", ".")), float(b.replace(",", "."))
        except ValueError:
            return False
        diff = abs(fa - fb)
        return diff <= self.EPS or diff / (abs(fb) or 1) <= self.EPS

    def _compare_fraction(self, s1: str, s2: str) -> bool:
        """Return True when both strings are fractions whose values differ by at most ``EPS``.

        Returns False when either string is not a recognised fraction or has
        a zero denominator.
        """

        def to_frac(s: str):
            s = self._strip_delims(s)
            if self.frac_pattern.fullmatch(s):
                num, den = map(int, s.split("/"))
            else:
                m = self.latex_frac_pattern.fullmatch(s)
                if not m:
                    return None
                num, den = map(int, m.groups())
            # Fraction(n, 0) raises ZeroDivisionError; treat it as "not a fraction".
            if den == 0:
                return None
            return Fraction(num, den)

        f1, f2 = to_frac(s1), to_frac(s2)
        if f1 is not None and f2 is not None:
            return abs(float(f1) - float(f2)) <= self.EPS
        return False

    def _to_expr(self, s: str):
        """Try to turn a string into a SymPy object; return None if every parser fails.

        Parsers are tried in order: bare number, Python-style math, LaTeX
        (only when the LaTeX parser could be imported).
        """
        # The parsers raise many unrelated exception types, so each attempt is
        # deliberately best-effort and a failure just moves on to the next one.
        # 1) bare number
        try:
            return sympy.Float(s.replace(",", "."))
        except Exception:
            pass
        # 2) plain "python math"
        # NOTE(review): SymPy documents parse_expr as eval-based and unsafe for
        # unsanitized input; if `s` can come from an untrusted source, confirm
        # this is acceptable or sandbox it.
        try:
            return parse_expr(s.replace("^", "**"), evaluate=True)
        except Exception:
            pass
        # 3) LaTeX (if the library is available)
        if _HAS_PARSE_LATEX:
            try:
                return parse_latex(self._strip_delims(s))
            except Exception:
                pass
        return None

    def _expr_equal(self, e1, e2) -> bool:
        """Return True when two SymPy objects are equal exactly or within ``EPS``."""
        try:
            diff = sympy.simplify(e1 - e2)
        except Exception:
            # Objects that do not support subtraction cannot be compared
            # numerically; fall back to structural equality instead of crashing.
            return bool(e1 == e2)
        if diff == 0:
            return True  # exact equality
        if diff.free_symbols:  # symbols remain after simplification: treat as unequal
            return False
        try:
            return abs(float(diff)) <= self.EPS
        except Exception:
            return False

    # ──────────────────────────── core ────────────────────────────
    def latex_equivalent(self, s1: str, s2: str) -> bool:
        """Compare two answers part by part (parts are separated by ``;``).

        Args:
            s1: Candidate answer.
            s2: Reference answer.

        Returns:
            True when both have the same number of parts and every pair of
            parts matches numerically, symbolically, or (when a part cannot
            be parsed) as identical strings.
        """
        p1, p2 = self.preprocess_answer(s1, True), self.preprocess_answer(s2, True)
        if len(p1) != len(p2):
            return False

        for a, b in zip(p1, p2):
            # fast numeric comparison
            if self._compare_numeric(a, b):
                continue
            # otherwise try to compare as expressions
            e1, e2 = self._to_expr(a), self._to_expr(b)
            if e1 is None or e2 is None:
                if a != b:
                    return False
                continue
            if not self._expr_equal(e1, e2):
                return False
        return True

    # ──────────────────────────── public API ────────────────────────────
    def __call__(
        self, answer: str, predict: str
    ) -> bool:  # order: (correct answer, prediction)
        """Return True when ``predict`` is considered equal to ``answer``.

        Args:
            answer: Reference (correct) answer.
            predict: Answer to be checked.

        Returns:
            False when either string is empty; otherwise the result of the
            fraction, numeric and symbolic comparisons described on the class.
        """
        if not answer or not predict:
            return False
        answer, predict = self._normalize(answer), self._normalize(predict)

        if self._compare_fraction(answer, predict):
            return True

        if self.num_pattern.match(answer) and self.num_pattern.match(predict):
            # Pass the reference as the second argument so the relative error
            # is measured against the correct answer, as latex_equivalent does.
            if self._compare_numeric(predict, answer):
                return True

        # Not (almost) equal as plain numbers: compare as expressions.
        return self.latex_equivalent(predict, answer)

0 commit comments

Comments
 (0)