Skip to content

Commit bfd9748

Browse files
committed
Refactor answer extraction logic to handle multiple matches in response text
1 parent 58ac0b6 commit bfd9748

File tree

2 files changed

+83
-93
lines changed

2 files changed

+83
-93
lines changed

src/equality_checker.py

Lines changed: 77 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -1,103 +1,93 @@
1-
import re
1+
import numpy as np
2+
import pandas as pd
3+
import json
24
import sympy
3-
from typing import List, Union
45
from sympy.parsing.latex import parse_latex
6+
import os
7+
import re
8+
from fractions import Fraction
59

610

711
class DoomSlayer:
8-
"""
9-
Класс для проверки эквивалентности математических выражений и ответов.
10-
11-
Используется для оценки правильности ответов моделей, сравнивая их с эталонными ответами.
12-
Поддерживает различные форматы ответов, включая числовые и LaTeX-выражения.
13-
"""
14-
15-
def __init__(self, EPS: float = 1e-2):
16-
"""
17-
Инициализирует проверщик математических выражений.
18-
19-
Args:
20-
EPS: Допустимая погрешность при сравнении числовых значений
21-
"""
12+
def __init__(self, EPS=1e-2):
2213
self.EPS = EPS
23-
24-
def preprocess_answer(self, answer: str, hard: bool) -> Union[List[str], List]:
25-
"""
26-
Предобрабатывает ответ для последующего сравнения.
27-
28-
Args:
29-
answer: Строка с ответом
30-
hard: Флаг режима проверки (True для сложной проверки, False для простой)
31-
32-
Returns:
33-
Предобработанная строка или список строк
34-
"""
14+
self.num_pattern = re.compile(r"-?\d+(?:[.,]\d+)?$")
15+
self.frac_pattern = re.compile(r"-?\d+\s*/\s*\d+$")
16+
self.latex_frac_pattern = re.compile(r"\\frac\{(-?\d+)\}\{(\d+)\}")
17+
self.minus_map = {"\u2212": "-", "\u2013": "-", "\u2014": "-"} # ADDED
18+
19+
def _normalize(self, s: str) -> str:
20+
for uni, ascii_minus in self.minus_map.items(): # ADDED
21+
s = s.replace(uni, ascii_minus) # ADDED
22+
return s.strip() # ADDED
23+
24+
def preprocess_answer(self, answer: str, hard: bool):
25+
answer = self._normalize(answer) # ADDED
26+
answer = answer[:-1] if answer.endswith(".") else answer
3527
if not hard:
36-
return re.findall("[0-9.]+", answer)
37-
answer = answer.lower().replace("**", "^").split(";")
38-
return answer
28+
return re.findall(r"-?\d+(?:[.,]\d+)?", answer)
29+
return answer.lower().replace("**", "^").split(";")
30+
31+
def __call__(self, answer: str, predict: str) -> bool:
32+
if not answer or not predict:
33+
return False
34+
answer = self._normalize(answer) # ADDED
35+
predict = self._normalize(predict) # ADDED
3936

40-
def __call__(self, predict: str, answer: str) -> bool:
41-
"""
42-
Проверяет эквивалентность предсказанного и правильного ответов.
37+
if self._compare_fraction(answer, predict):
38+
return True
4339

44-
Args:
45-
predict: Предсказанный ответ
46-
answer: Правильный ответ
40+
if self.num_pattern.match(answer) and self.num_pattern.match(predict):
41+
return self.simple_check(predict, answer) or self.latex_equivalent(
42+
predict, answer
43+
)
4744

48-
Returns:
49-
True если ответы эквивалентны, иначе False
50-
"""
51-
if not re.match("[0-9., ]+", answer) or not re.match("[0-9,. ]+", predict):
52-
return self.latex_equivalent(predict, answer)
53-
if (
54-
re.match("[0-9., ]+", answer)[0] == answer
55-
and re.match("[0-9,. ]+", predict)[0] == predict
56-
):
57-
return self.simple_check(predict, answer)
5845
return self.latex_equivalent(predict, answer)
5946

6047
def simple_check(self, predict: str, answer: str) -> bool:
61-
"""
62-
Выполняет простую проверку числовых ответов.
63-
64-
Args:
65-
predict: Предсказанный ответ
66-
answer: Правильный ответ
67-
68-
Returns:
69-
True если ответы эквивалентны, иначе False
70-
"""
71-
predict = self.preprocess_answer(predict, False)
72-
answer = self.preprocess_answer(answer, False)
73-
return "".join(answer).replace(",", ".") == "".join(predict).replace(",", ".")
74-
75-
def latex_equivalent(self, latex_formula1: str, latex_formula2: str) -> bool:
76-
"""
77-
Сравнивает две формулы в формате LaTeX через Sympy.
78-
79-
Args:
80-
latex_formula1: Первая формула в формате LaTeX
81-
latex_formula2: Вторая формула в формате LaTeX
82-
83-
Returns:
84-
True если формулы математически эквивалентны, иначе False
85-
"""
86-
latex_formula1 = self.preprocess_answer(latex_formula1, True)
87-
latex_formula2 = self.preprocess_answer(latex_formula2, True)
88-
89-
results = [True for _ in range(len(latex_formula1))]
90-
for i in range(len(latex_formula1)):
48+
p = self.preprocess_answer(predict, False)
49+
a = self.preprocess_answer(answer, False)
50+
return "".join(a).replace(",", ".") == "".join(p).replace(",", ".")
51+
52+
def _compare_fraction(self, s1: str, s2: str) -> bool:
53+
def to_frac(s):
54+
s = s.strip()
55+
m = self.frac_pattern.fullmatch(s)
56+
if m:
57+
num, den = map(int, s.split("/"))
58+
return Fraction(num, den)
59+
m2 = self.latex_frac_pattern.fullmatch(s)
60+
if m2:
61+
num, den = map(int, m2.groups())
62+
return Fraction(num, den)
63+
return None
64+
65+
f1 = to_frac(s1)
66+
f2 = to_frac(s2)
67+
if f1 is not None and f2 is not None:
68+
return abs(float(f1) - float(f2)) <= self.EPS
69+
return False
70+
71+
def latex_equivalent(self, latex1: str, latex2: str) -> bool:
72+
parts1 = self.preprocess_answer(latex1, True)
73+
parts2 = self.preprocess_answer(latex2, True)
74+
if len(parts1) != len(parts2):
75+
return False
76+
77+
for a, b in zip(parts1, parts2):
9178
try:
92-
expr1 = parse_latex(latex_formula1[i])
93-
expr2 = parse_latex(latex_formula2[i])
94-
diff = sympy.simplify(expr1 - expr2)
95-
results[i] = diff <= self.EPS
96-
except Exception:
79+
e1 = parse_latex(a)
80+
e2 = parse_latex(b)
81+
diff = sympy.simplify(abs(e1 - e2))
9782
try:
98-
if latex_formula1[i] == latex_formula2[i]:
99-
continue
100-
except:
83+
diff_rel = sympy.simplify(abs(e1 - e2) / abs(e2))
84+
diff = min(diff, diff_rel)
85+
except Exception:
10186
pass
102-
results[i] = False
103-
return all(results)
87+
if diff > self.EPS:
88+
return False
89+
except Exception:
90+
if a != b:
91+
return False
92+
93+
return True

src/mat_boy.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,8 @@ def fn(row: Dict[str, str]) -> SingleEvalResult:
9696
response_text, metadata = sampler(prompt_messages, return_metadata=True)
9797

9898
answer_pattern = r"(?:Answer|Ответ):\s*(.+)$"
99-
match = re.search(answer_pattern, response_text, re.MULTILINE)
100-
extracted_answer = match.group(1).strip() if match else None
99+
matches = list(re.finditer(answer_pattern, response_text, re.MULTILINE))
100+
extracted_answer = matches[-1].group(1).strip() if matches else None
101101

102102
if self.debug:
103103
print(f"Extracted answer: {extracted_answer}")
@@ -210,8 +210,8 @@ def fn(row: Dict[str, str]) -> SingleEvalResult:
210210
response_text, metadata = sampler(prompt_messages, return_metadata=True)
211211

212212
answer_pattern = r"(?:Answer|Ответ):\s*(.+)$"
213-
match = re.search(answer_pattern, response_text, re.MULTILINE)
214-
extracted_answer = match.group(1).strip() if match else None
213+
matches = list(re.finditer(answer_pattern, response_text, re.MULTILINE))
214+
extracted_answer = matches[-1].group(1).strip() if matches else None
215215

216216
if self.debug:
217217
print(f"Extracted answer: {extracted_answer}")
@@ -329,8 +329,8 @@ def fn(row: Dict[str, str]) -> SingleEvalResult:
329329
response_text, metadata = sampler(prompt_messages, return_metadata=True)
330330

331331
answer_pattern = r"(?:Answer|Ответ):\s*(.+)$"
332-
match = re.search(answer_pattern, response_text, re.MULTILINE)
333-
extracted_answer = match.group(1).strip() if match else None
332+
matches = list(re.finditer(answer_pattern, response_text, re.MULTILINE))
333+
extracted_answer = matches[-1].group(1).strip() if matches else None
334334

335335
if self.debug:
336336
print(f"Extracted answer: {extracted_answer}")

0 commit comments

Comments
 (0)