|
1 | | -import re |
| 1 | +import numpy as np |
| 2 | +import pandas as pd |
| 3 | +import json |
2 | 4 | import sympy |
3 | | -from typing import List, Union |
4 | 5 | from sympy.parsing.latex import parse_latex |
| 6 | +import os |
| 7 | +import re |
| 8 | +from fractions import Fraction |
5 | 9 |
|
6 | 10 |
|
7 | 11 | class DoomSlayer: |
8 | | - """ |
9 | | - Класс для проверки эквивалентности математических выражений и ответов. |
10 | | -
|
11 | | - Используется для оценки правильности ответов моделей, сравнивая их с эталонными ответами. |
12 | | - Поддерживает различные форматы ответов, включая числовые и LaTeX-выражения. |
13 | | - """ |
14 | | - |
15 | | - def __init__(self, EPS: float = 1e-2): |
16 | | - """ |
17 | | - Инициализирует проверщик математических выражений. |
18 | | -
|
19 | | - Args: |
20 | | - EPS: Допустимая погрешность при сравнении числовых значений |
21 | | - """ |
| 12 | + def __init__(self, EPS=1e-2): |
22 | 13 | self.EPS = EPS |
23 | | - |
24 | | - def preprocess_answer(self, answer: str, hard: bool) -> Union[List[str], List]: |
25 | | - """ |
26 | | - Предобрабатывает ответ для последующего сравнения. |
27 | | -
|
28 | | - Args: |
29 | | - answer: Строка с ответом |
30 | | - hard: Флаг режима проверки (True для сложной проверки, False для простой) |
31 | | -
|
32 | | - Returns: |
33 | | - Предобработанная строка или список строк |
34 | | - """ |
| 14 | + self.num_pattern = re.compile(r"-?\d+(?:[.,]\d+)?$") |
| 15 | + self.frac_pattern = re.compile(r"-?\d+\s*/\s*\d+$") |
| 16 | + self.latex_frac_pattern = re.compile(r"\\frac\{(-?\d+)\}\{(\d+)\}") |
| 17 | + self.minus_map = {"\u2212": "-", "\u2013": "-", "\u2014": "-"} # ADDED |
| 18 | + |
| 19 | + def _normalize(self, s: str) -> str: |
| 20 | + for uni, ascii_minus in self.minus_map.items(): # ADDED |
| 21 | + s = s.replace(uni, ascii_minus) # ADDED |
| 22 | + return s.strip() # ADDED |
| 23 | + |
| 24 | + def preprocess_answer(self, answer: str, hard: bool): |
| 25 | + answer = self._normalize(answer) # ADDED |
| 26 | + answer = answer[:-1] if answer.endswith(".") else answer |
35 | 27 | if not hard: |
36 | | - return re.findall("[0-9.]+", answer) |
37 | | - answer = answer.lower().replace("**", "^").split(";") |
38 | | - return answer |
| 28 | + return re.findall(r"-?\d+(?:[.,]\d+)?", answer) |
| 29 | + return answer.lower().replace("**", "^").split(";") |
| 30 | + |
| 31 | + def __call__(self, answer: str, predict: str) -> bool: |
| 32 | + if not answer or not predict: |
| 33 | + return False |
| 34 | + answer = self._normalize(answer) # ADDED |
| 35 | + predict = self._normalize(predict) # ADDED |
39 | 36 |
|
40 | | - def __call__(self, predict: str, answer: str) -> bool: |
41 | | - """ |
42 | | - Проверяет эквивалентность предсказанного и правильного ответов. |
| 37 | + if self._compare_fraction(answer, predict): |
| 38 | + return True |
43 | 39 |
|
44 | | - Args: |
45 | | - predict: Предсказанный ответ |
46 | | - answer: Правильный ответ |
| 40 | + if self.num_pattern.match(answer) and self.num_pattern.match(predict): |
| 41 | + return self.simple_check(predict, answer) or self.latex_equivalent( |
| 42 | + predict, answer |
| 43 | + ) |
47 | 44 |
|
48 | | - Returns: |
49 | | - True если ответы эквивалентны, иначе False |
50 | | - """ |
51 | | - if not re.match("[0-9., ]+", answer) or not re.match("[0-9,. ]+", predict): |
52 | | - return self.latex_equivalent(predict, answer) |
53 | | - if ( |
54 | | - re.match("[0-9., ]+", answer)[0] == answer |
55 | | - and re.match("[0-9,. ]+", predict)[0] == predict |
56 | | - ): |
57 | | - return self.simple_check(predict, answer) |
58 | 45 | return self.latex_equivalent(predict, answer) |
59 | 46 |
|
60 | 47 | def simple_check(self, predict: str, answer: str) -> bool: |
61 | | - """ |
62 | | - Выполняет простую проверку числовых ответов. |
63 | | -
|
64 | | - Args: |
65 | | - predict: Предсказанный ответ |
66 | | - answer: Правильный ответ |
67 | | -
|
68 | | - Returns: |
69 | | - True если ответы эквивалентны, иначе False |
70 | | - """ |
71 | | - predict = self.preprocess_answer(predict, False) |
72 | | - answer = self.preprocess_answer(answer, False) |
73 | | - return "".join(answer).replace(",", ".") == "".join(predict).replace(",", ".") |
74 | | - |
75 | | - def latex_equivalent(self, latex_formula1: str, latex_formula2: str) -> bool: |
76 | | - """ |
77 | | - Сравнивает две формулы в формате LaTeX через Sympy. |
78 | | -
|
79 | | - Args: |
80 | | - latex_formula1: Первая формула в формате LaTeX |
81 | | - latex_formula2: Вторая формула в формате LaTeX |
82 | | -
|
83 | | - Returns: |
84 | | - True если формулы математически эквивалентны, иначе False |
85 | | - """ |
86 | | - latex_formula1 = self.preprocess_answer(latex_formula1, True) |
87 | | - latex_formula2 = self.preprocess_answer(latex_formula2, True) |
88 | | - |
89 | | - results = [True for _ in range(len(latex_formula1))] |
90 | | - for i in range(len(latex_formula1)): |
| 48 | + p = self.preprocess_answer(predict, False) |
| 49 | + a = self.preprocess_answer(answer, False) |
| 50 | + return "".join(a).replace(",", ".") == "".join(p).replace(",", ".") |
| 51 | + |
| 52 | + def _compare_fraction(self, s1: str, s2: str) -> bool: |
| 53 | + def to_frac(s): |
| 54 | + s = s.strip() |
| 55 | + m = self.frac_pattern.fullmatch(s) |
| 56 | + if m: |
| 57 | + num, den = map(int, s.split("/")) |
| 58 | + return Fraction(num, den) |
| 59 | + m2 = self.latex_frac_pattern.fullmatch(s) |
| 60 | + if m2: |
| 61 | + num, den = map(int, m2.groups()) |
| 62 | + return Fraction(num, den) |
| 63 | + return None |
| 64 | + |
| 65 | + f1 = to_frac(s1) |
| 66 | + f2 = to_frac(s2) |
| 67 | + if f1 is not None and f2 is not None: |
| 68 | + return abs(float(f1) - float(f2)) <= self.EPS |
| 69 | + return False |
| 70 | + |
| 71 | + def latex_equivalent(self, latex1: str, latex2: str) -> bool: |
| 72 | + parts1 = self.preprocess_answer(latex1, True) |
| 73 | + parts2 = self.preprocess_answer(latex2, True) |
| 74 | + if len(parts1) != len(parts2): |
| 75 | + return False |
| 76 | + |
| 77 | + for a, b in zip(parts1, parts2): |
91 | 78 | try: |
92 | | - expr1 = parse_latex(latex_formula1[i]) |
93 | | - expr2 = parse_latex(latex_formula2[i]) |
94 | | - diff = sympy.simplify(expr1 - expr2) |
95 | | - results[i] = diff <= self.EPS |
96 | | - except Exception: |
| 79 | + e1 = parse_latex(a) |
| 80 | + e2 = parse_latex(b) |
| 81 | + diff = sympy.simplify(abs(e1 - e2)) |
97 | 82 | try: |
98 | | - if latex_formula1[i] == latex_formula2[i]: |
99 | | - continue |
100 | | - except: |
| 83 | + diff_rel = sympy.simplify(abs(e1 - e2) / abs(e2)) |
| 84 | + diff = min(diff, diff_rel) |
| 85 | + except Exception: |
101 | 86 | pass |
102 | | - results[i] = False |
103 | | - return all(results) |
| 87 | + if diff > self.EPS: |
| 88 | + return False |
| 89 | + except Exception: |
| 90 | + if a != b: |
| 91 | + return False |
| 92 | + |
| 93 | + return True |
0 commit comments