1+ from __future__ import annotations
2+
13import re
24from fractions import Fraction
5+ from typing import Sequence
6+
37import sympy
48from sympy .parsing .sympy_parser import parse_expr
59
6- # Если вдруг antlr4-python3-runtime не установлен, то не используем его
10+ # Если вдруг antlr4-python3-runtime не установлен, игнорируем latex-парсер
711try :
812 from sympy .parsing .latex import parse_latex # type: ignore
913
1014 _HAS_PARSE_LATEX = True
11- except Exception :
15+ except Exception : # pragma: no cover
1216 _HAS_PARSE_LATEX = False
1317
1418
1519class DoomSlayer :
20+ """
21+ Проверка эквивалентности ответов (строка-к-строке).
22+
23+ Главная идея ─ сначала пытаемся сравнить как числа,
24+ затем как дроби, затем как символьные выражения (sympy/LaTeX).
25+ """
26+
1627 def __init__ (self , EPS : float = 1e-2 ) -> None :
1728 self .EPS = EPS
18- self .num_pattern = re .compile (r"-?\d+(?:[.,]\d+)?$" )
29+
30+ # ────── основные шаблоны ──────
31+ self .num_pattern = re .compile (r"-?\d+(?:[.,]\d+)?(?:[eE]-?\d+)?$" )
1932 self .frac_pattern = re .compile (r"-?\d+\s*/\s*\d+$" )
20- # допускаем обрамления \( … \), $ … $, $$ … $$
33+ # допускаем обрамление \( … \) , $ … $ , $$ … $$
2134 self .latex_frac_pattern = re .compile (
2235 r"(?:\\\(|\$\$?)?\s*\\frac\{(-?\d+)\}\{(\d+)\}\s*(?:\\\)|\$\$?)?"
2336 )
37+ # различные варианты "минуса"
2438 self .minus_map = {"\u2212 " : "-" , "\u2013 " : "-" , "\u2014 " : "-" }
2539
2640 # ──────────────────────────── service ────────────────────────────
2741 def _strip_delims (self , s : str ) -> str :
42+ """Убираем окружающие $ … $, $$ … $$, \( … \), \[ … \]"""
2843 s = s .strip ()
2944 if s .startswith ("$$" ) and s .endswith ("$$" ):
3045 s = s [2 :- 2 ]
@@ -37,33 +52,44 @@ def _strip_delims(self, s: str) -> str:
3752 return s .strip ()
3853
3954 def _normalize (self , s : str ) -> str :
55+ """Замена длинных/узких минусов, обрезка пробелов"""
4056 for bad , good in self .minus_map .items ():
4157 s = s .replace (bad , good )
4258 return s .strip ()
4359
4460 # ──────────────────────────── preprocessing ────────────────────────────
45- def preprocess_answer (self , answer : str , hard : bool ):
61+ def preprocess_answer (self , answer : str , hard : bool ) -> list [str ]:
62+ """
63+ * soft-режим (hard = False) нужен только для _compare_numeric —
64+ вытягиваем **все** числа в строке;
65+ * hard-режим - для символьного сравнения —
66+ разбиваем строку по «;» на отдельные выражения,
67+ убираем внешние $$, \( \) и приводим **^ → \*\*** .
68+ """
4669 answer = self ._normalize (answer )
4770 answer = answer [:- 1 ] if answer .endswith ("." ) else answer
71+
4872 if not hard :
49- return re .findall (r"-?\d+(?:[.,]\d+)?" , answer )
73+ return re .findall (r"-?\d+(?:[.,]\d+)?(?:[eE]-?\d+)?" , answer )
74+
5075 return [
5176 self ._strip_delims (part ).lower ().replace ("**" , "^" ).strip ()
5277 for part in answer .split (";" )
5378 ]
5479
5580 # ──────────────────────────── helpers ────────────────────────────
5681 def _compare_numeric (self , a : str , b : str ) -> bool :
57- """Абс- и относительная погрешность"""
82+ """Абсолютная и относительная погрешность для чисел/научной нотации """
5883 try :
5984 fa , fb = float (a .replace ("," , "." )), float (b .replace ("," , "." ))
6085 except ValueError :
6186 return False
6287 diff = abs (fa - fb )
63- return diff <= self .EPS or diff / (abs (fb ) or 1 ) <= self .EPS
88+ return diff <= self .EPS or diff / (abs (fb ) or 1.0 ) <= self .EPS
6489
6590 def _compare_fraction (self , s1 : str , s2 : str ) -> bool :
66- def to_frac (s : str ):
91+ """Сравнение обыкновенных дробей (в т. ч. LaTeX)"""
92+ def to_frac (s : str ) -> Fraction | None :
6793 s = self ._strip_delims (s )
6894 if self .frac_pattern .fullmatch (s ):
6995 num , den = map (int , s .split ("/" ))
@@ -80,71 +106,144 @@ def to_frac(s: str):
80106 return False
81107
82108 def _to_expr (self , s : str ):
83- """Пытаемся превратить строку в sympy-выражение максимально надёжно"""
84- # 1) голое число
109+ """
110+ Превращаем строку в sympy-объект (Float | Expr | Tuple | Matrix …).
111+ Пытаемся по очереди:
112+ 1) число → sympy.Float
113+ 2) python-математика («2^3» → «2**3») → parse_expr
114+ 3) LaTeX (если доступен parse_latex)
115+ """
116+ s_clean = s .replace ("," , "." ) # «1,2» → «1.2»
85117 try :
86- return sympy .Float (s . replace ( "," , "." ) )
118+ return sympy .Float (s_clean )
87119 except Exception :
88120 pass
89- # 2) обычная «python-математика»
121+
90122 try :
91- return parse_expr (s .replace ("^" , "**" ), evaluate = True )
123+ return parse_expr (s_clean .replace ("^" , "**" ), evaluate = True )
92124 except Exception :
93125 pass
94- # 3) LaTeX (если библиотека доступна)
126+
95127 if _HAS_PARSE_LATEX :
96128 try :
97- return parse_latex (self ._strip_delims (s ))
129+ return parse_latex (self ._strip_delims (s_clean ))
98130 except Exception :
99131 pass
100- return None
132+
133+ return None # ничего не получилось
134+
135+ # ──────────────────────────── expression equality ─────────────────────────
136+ def _iterable_equal (
137+ self ,
138+ seq1 : Sequence ,
139+ seq2 : Sequence ,
140+ ) -> bool :
141+ """Рекурсивное покомпонентное сравнение кортежей / списков / матриц"""
142+ if len (seq1 ) != len (seq2 ):
143+ return False
144+ return all (self ._expr_equal (a , b ) for a , b in zip (seq1 , seq2 ))
101145
102146 def _expr_equal (self , e1 , e2 ) -> bool :
103- diff = sympy .simplify (e1 - e2 )
147+ """
148+ Надёжное сравнение любых sympy-объектов:
149+ * скаляры (Float, Integer, Symbol …),
150+ * кортежи (sympy.Tuple / обычный tuple),
151+ * матрицы (sympy.MatrixBase).
152+ """
153+ # ─── кортежи / списки ───
154+ if isinstance (e1 , (tuple , sympy .Tuple )) or isinstance (
155+ e2 , (tuple , sympy .Tuple )
156+ ):
157+ if not isinstance (e1 , (tuple , sympy .Tuple )) or not isinstance (
158+ e2 , (tuple , sympy .Tuple )
159+ ):
160+ return False
161+ return self ._iterable_equal (e1 , e2 )
162+
163+ # ─── матрицы ───
164+ from sympy .matrices .matrices import MatrixBase # локальный импорт → быстрее старт
165+ if isinstance (e1 , MatrixBase ) or isinstance (e2 , MatrixBase ): # pragma: no branch
166+ if not (isinstance (e1 , MatrixBase ) and isinstance (e2 , MatrixBase )):
167+ return False
168+ if e1 .shape != e2 .shape :
169+ return False
170+ return self ._iterable_equal (tuple (e1 ), tuple (e2 ))
171+
172+ # ─── обычные выражения ───
173+ try :
174+ diff = sympy .simplify (e1 - e2 )
175+ except TypeError :
176+ # например, «tuple - Float» — значит несопоставимые типы
177+ return False
178+
104179 if diff == 0 :
105180 return True # точное равенство
106- if diff .free_symbols : # осталось x, y … — считаем неравными
181+
182+ if diff .free_symbols : # остались x, y … → не удалось упростить
107183 return False
184+
108185 try :
109186 return abs (float (diff )) <= self .EPS
110- except Exception :
187+ except Exception : # pragma: no cover
111188 return False
112189
113190 # ──────────────────────────── core ────────────────────────────
114191 def latex_equivalent (self , s1 : str , s2 : str ) -> bool :
115- p1 , p2 = self .preprocess_answer (s1 , True ), self .preprocess_answer (s2 , True )
192+ """
193+ «Тяжёлое» сравнение:
194+ * разбиваем ответы по «;» (мультиответ),
195+ * сравниваем соответствующие пары.
196+ Порядок элементов **важен**.
197+ """
198+ p1 = self .preprocess_answer (s1 , hard = True )
199+ p2 = self .preprocess_answer (s2 , hard = True )
116200 if len (p1 ) != len (p2 ):
117201 return False
118202
119203 for a , b in zip (p1 , p2 ):
120- # быстрое сравнение чисел
204+ # быстрое числовое сравнение
121205 if self ._compare_numeric (a , b ):
122206 continue
123- # пытаемся превратить в выражения
207+
208+ # пробуем превратить в sympy-выражения
124209 e1 , e2 = self ._to_expr (a ), self ._to_expr (b )
125210 if e1 is None or e2 is None :
211+ # ничего не разобрали → сравниваем строково
126212 if a != b :
127213 return False
128214 continue
215+
129216 if not self ._expr_equal (e1 , e2 ):
130217 return False
218+
131219 return True
132220
133221 # ──────────────────────────── public API ────────────────────────────
134- def __call__ (
135- self , answer : str , predict : str
136- ) -> bool : # order: (правильный ответ, предикт)
222+ def __call__ (self , answer : str , predict : str ) -> bool :
223+ """
224+ Возможные случаи:
225+ * обыкновенные дроби «3/4»
226+ * числа/научная нотация «1e-3»
227+ * символика / LaTeX / python-выражения / мультиответ
228+ Порядок «правильный ответ, предсказание» сохраняётся
229+ для симметрии с тестами.
230+ """
137231 if not answer or not predict :
138232 return False
139- answer , predict = self ._normalize (answer ), self ._normalize (predict )
140233
234+ answer = self ._normalize (answer )
235+ predict = self ._normalize (predict )
236+
237+ # 1) дроби
141238 if self ._compare_fraction (answer , predict ):
142239 return True
143240
144- if self .num_pattern .match (answer ) and self .num_pattern .match (predict ):
241+ # 2) обе строки выглядят как «простое число»
242+ if self .num_pattern .fullmatch (answer ) and self .num_pattern .fullmatch (predict ):
145243 if self ._compare_numeric (answer , predict ):
146244 return True
147- # если равенство «почти» не прокатило – проверяем как выражения
245+ # если почти- равенство не прошло, проверяем как выражения
148246 return self .latex_equivalent (predict , answer )
149247
248+ # 3) общий тяжёлый случай
150249 return self .latex_equivalent (predict , answer )
0 commit comments