Skip to content

Commit 6f25a50

Browse files
committed
Add retry option for incomplete model evaluations and improve error handling in OaiSampler
1 parent 39fa8ca commit 6f25a50

File tree

4 files changed

+206
-65
lines changed

4 files changed

+206
-65
lines changed

runner.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,11 @@ def main() -> None:
5555
default="all",
5656
help="Выбор датасета для оценки: all (все), russianmath, physics (по умолчанию: all)",
5757
)
58+
parser.add_argument(
59+
"--retry-incomplete",
60+
action="store_true",
61+
help="Перезапустить оценку моделей с неполными результатами",
62+
)
5863
args = parser.parse_args()
5964

6065
# Загружаем конфиг
@@ -73,7 +78,7 @@ def main() -> None:
7378

7479

7580
# Создаем и инициализируем лидерборд
76-
leaderboard = Leaderboard(args.config, max_workers=args.max_workers)
81+
leaderboard = Leaderboard(args.config, max_workers=args.max_workers, retry_incomplete=args.retry_incomplete)
7782

7883
# Определяем системные промпты для каждой модели из конфига
7984
system_prompts: Dict[str, Optional[str]] = {

src/equality_checker.py

Lines changed: 128 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,45 @@
1+
from __future__ import annotations
2+
13
import re
24
from fractions import Fraction
5+
from typing import Sequence
6+
37
import sympy
48
from sympy.parsing.sympy_parser import parse_expr
59

6-
# Если вдруг antlr4-python3-runtime не установлен, то не используем его
10+
# Если вдруг antlr4-python3-runtime не установлен, ­игнорируем latex-парсер
711
try:
812
from sympy.parsing.latex import parse_latex # type: ignore
913

1014
_HAS_PARSE_LATEX = True
11-
except Exception:
15+
except Exception: # pragma: no cover
1216
_HAS_PARSE_LATEX = False
1317

1418

1519
class DoomSlayer:
20+
"""
21+
Проверка эквивалентности ответов (строка-к-строке).
22+
23+
Главная идея ─ сначала пытаемся сравнить как числа,
24+
затем как дроби, затем как символьные выражения (sympy/LaTeX).
25+
"""
26+
1627
def __init__(self, EPS: float = 1e-2) -> None:
1728
self.EPS = EPS
18-
self.num_pattern = re.compile(r"-?\d+(?:[.,]\d+)?$")
29+
30+
# ────── основные шаблоны ──────
31+
self.num_pattern = re.compile(r"-?\d+(?:[.,]\d+)?(?:[eE]-?\d+)?$")
1932
self.frac_pattern = re.compile(r"-?\d+\s*/\s*\d+$")
20-
# допускаем обрамления \( … \), $ … $, $$ … $$
33+
# допускаем обрамление \( … \) , $ … $ , $$ … $$
2134
self.latex_frac_pattern = re.compile(
2235
r"(?:\\\(|\$\$?)?\s*\\frac\{(-?\d+)\}\{(\d+)\}\s*(?:\\\)|\$\$?)?"
2336
)
37+
# различные варианты "минуса"
2438
self.minus_map = {"\u2212": "-", "\u2013": "-", "\u2014": "-"}
2539

2640
# ──────────────────────────── service ────────────────────────────
2741
def _strip_delims(self, s: str) -> str:
42+
"""Убираем окружающие $ … $, $$ … $$, \( … \), \[ … \]"""
2843
s = s.strip()
2944
if s.startswith("$$") and s.endswith("$$"):
3045
s = s[2:-2]
@@ -37,33 +52,44 @@ def _strip_delims(self, s: str) -> str:
3752
return s.strip()
3853

3954
def _normalize(self, s: str) -> str:
55+
"""Замена длинных/узких минусов, обрезка пробелов"""
4056
for bad, good in self.minus_map.items():
4157
s = s.replace(bad, good)
4258
return s.strip()
4359

4460
# ──────────────────────────── preprocessing ────────────────────────────
45-
def preprocess_answer(self, answer: str, hard: bool):
61+
def preprocess_answer(self, answer: str, hard: bool) -> list[str]:
62+
"""
63+
* soft-режим (hard = False) нужен только для _compare_numeric —
64+
вытягиваем **все** числа в строке;
65+
* hard-режим - для символьного сравнения —
66+
разбиваем строку по «;» на отдельные выражения,
67+
убираем внешние $$, \( \) и приводим **^ → \*\*** .
68+
"""
4669
answer = self._normalize(answer)
4770
answer = answer[:-1] if answer.endswith(".") else answer
71+
4872
if not hard:
49-
return re.findall(r"-?\d+(?:[.,]\d+)?", answer)
73+
return re.findall(r"-?\d+(?:[.,]\d+)?(?:[eE]-?\d+)?", answer)
74+
5075
return [
5176
self._strip_delims(part).lower().replace("**", "^").strip()
5277
for part in answer.split(";")
5378
]
5479

5580
# ──────────────────────────── helpers ────────────────────────────
5681
def _compare_numeric(self, a: str, b: str) -> bool:
57-
"""Абс- и относительная погрешность"""
82+
"""Абсолютная и относительная погрешность для чисел/научной нотации"""
5883
try:
5984
fa, fb = float(a.replace(",", ".")), float(b.replace(",", "."))
6085
except ValueError:
6186
return False
6287
diff = abs(fa - fb)
63-
return diff <= self.EPS or diff / (abs(fb) or 1) <= self.EPS
88+
return diff <= self.EPS or diff / (abs(fb) or 1.0) <= self.EPS
6489

6590
def _compare_fraction(self, s1: str, s2: str) -> bool:
66-
def to_frac(s: str):
91+
"""Сравнение обыкновенных дробей (в т. ч. LaTeX)"""
92+
def to_frac(s: str) -> Fraction | None:
6793
s = self._strip_delims(s)
6894
if self.frac_pattern.fullmatch(s):
6995
num, den = map(int, s.split("/"))
@@ -80,71 +106,144 @@ def to_frac(s: str):
80106
return False
81107

82108
def _to_expr(self, s: str):
83-
"""Пытаемся превратить строку в sympy-выражение максимально надёжно"""
84-
# 1) голое число
109+
"""
110+
Превращаем строку в sympy-объект (Float | Expr | Tuple | Matrix …).
111+
Пытаемся по очереди:
112+
1) число → sympy.Float
113+
2) python-математика («2^3» → «2**3») → parse_expr
114+
3) LaTeX (если доступен parse_latex)
115+
"""
116+
s_clean = s.replace(",", ".") # «1,2» → «1.2»
85117
try:
86-
return sympy.Float(s.replace(",", "."))
118+
return sympy.Float(s_clean)
87119
except Exception:
88120
pass
89-
# 2) обычная «python-математика»
121+
90122
try:
91-
return parse_expr(s.replace("^", "**"), evaluate=True)
123+
return parse_expr(s_clean.replace("^", "**"), evaluate=True)
92124
except Exception:
93125
pass
94-
# 3) LaTeX (если библиотека доступна)
126+
95127
if _HAS_PARSE_LATEX:
96128
try:
97-
return parse_latex(self._strip_delims(s))
129+
return parse_latex(self._strip_delims(s_clean))
98130
except Exception:
99131
pass
100-
return None
132+
133+
return None # ничего не получилось
134+
135+
# ──────────────────────────── expression equality ─────────────────────────
136+
def _iterable_equal(
137+
self,
138+
seq1: Sequence,
139+
seq2: Sequence,
140+
) -> bool:
141+
"""Рекурсивное покомпонентное сравнение кортежей / списков / матриц"""
142+
if len(seq1) != len(seq2):
143+
return False
144+
return all(self._expr_equal(a, b) for a, b in zip(seq1, seq2))
101145

102146
def _expr_equal(self, e1, e2) -> bool:
103-
diff = sympy.simplify(e1 - e2)
147+
"""
148+
Надёжное сравнение любых sympy-объектов:
149+
* скаляры (Float, Integer, Symbol …),
150+
* кортежи (sympy.Tuple / обычный tuple),
151+
* матрицы (sympy.MatrixBase).
152+
"""
153+
# ─── кортежи / списки ───
154+
if isinstance(e1, (tuple, sympy.Tuple)) or isinstance(
155+
e2, (tuple, sympy.Tuple)
156+
):
157+
if not isinstance(e1, (tuple, sympy.Tuple)) or not isinstance(
158+
e2, (tuple, sympy.Tuple)
159+
):
160+
return False
161+
return self._iterable_equal(e1, e2)
162+
163+
# ─── матрицы ───
164+
from sympy.matrices.matrices import MatrixBase # локальный импорт → быстрее старт
165+
if isinstance(e1, MatrixBase) or isinstance(e2, MatrixBase): # pragma: no branch
166+
if not (isinstance(e1, MatrixBase) and isinstance(e2, MatrixBase)):
167+
return False
168+
if e1.shape != e2.shape:
169+
return False
170+
return self._iterable_equal(tuple(e1), tuple(e2))
171+
172+
# ─── обычные выражения ───
173+
try:
174+
diff = sympy.simplify(e1 - e2)
175+
except TypeError:
176+
# например, «tuple - Float» — значит несопоставимые типы
177+
return False
178+
104179
if diff == 0:
105180
return True # точное равенство
106-
if diff.free_symbols: # осталось x, y … — считаем неравными
181+
182+
if diff.free_symbols: # остались x, y … → не удалось упростить
107183
return False
184+
108185
try:
109186
return abs(float(diff)) <= self.EPS
110-
except Exception:
187+
except Exception: # pragma: no cover
111188
return False
112189

113190
# ──────────────────────────── core ────────────────────────────
114191
def latex_equivalent(self, s1: str, s2: str) -> bool:
115-
p1, p2 = self.preprocess_answer(s1, True), self.preprocess_answer(s2, True)
192+
"""
193+
«Тяжёлое» сравнение:
194+
* разбиваем ответы по «;» (мультиответ),
195+
* сравниваем соответствующие пары.
196+
Порядок элементов **важен**.
197+
"""
198+
p1 = self.preprocess_answer(s1, hard=True)
199+
p2 = self.preprocess_answer(s2, hard=True)
116200
if len(p1) != len(p2):
117201
return False
118202

119203
for a, b in zip(p1, p2):
120-
# быстрое сравнение чисел
204+
# быстрое числовое сравнение
121205
if self._compare_numeric(a, b):
122206
continue
123-
# пытаемся превратить в выражения
207+
208+
# пробуем превратить в sympy-выражения
124209
e1, e2 = self._to_expr(a), self._to_expr(b)
125210
if e1 is None or e2 is None:
211+
# ничего не разобрали → сравниваем строково
126212
if a != b:
127213
return False
128214
continue
215+
129216
if not self._expr_equal(e1, e2):
130217
return False
218+
131219
return True
132220

133221
# ──────────────────────────── public API ────────────────────────────
134-
def __call__(
135-
self, answer: str, predict: str
136-
) -> bool: # order: (правильный ответ, предикт)
222+
def __call__(self, answer: str, predict: str) -> bool:
223+
"""
224+
Возможные случаи:
225+
* обыкновенные дроби «3/4»
226+
* числа/научная нотация «1e-3»
227+
* символика / LaTeX / python-выражения / мультиответ
228+
Порядок «правильный ответ, предсказание» сохраняётся
229+
для симметрии с тестами.
230+
"""
137231
if not answer or not predict:
138232
return False
139-
answer, predict = self._normalize(answer), self._normalize(predict)
140233

234+
answer = self._normalize(answer)
235+
predict = self._normalize(predict)
236+
237+
# 1) дроби
141238
if self._compare_fraction(answer, predict):
142239
return True
143240

144-
if self.num_pattern.match(answer) and self.num_pattern.match(predict):
241+
# 2) обе строки выглядят как «простое число»
242+
if self.num_pattern.fullmatch(answer) and self.num_pattern.fullmatch(predict):
145243
if self._compare_numeric(answer, predict):
146244
return True
147-
# если равенство «почти» не прокатило – проверяем как выражения
245+
# если почти-равенство не прошло, проверяем как выражения
148246
return self.latex_equivalent(predict, answer)
149247

248+
# 3) общий тяжёлый случай
150249
return self.latex_equivalent(predict, answer)

0 commit comments

Comments
 (0)