1+ from dataflow .core import TextFilter
2+ import numpy as np
3+ from dataflow .utils .registry import PROCESSOR_REGISTRY
4+ #from math_verify import parse, verify, LatexExtractionConfig
5+ import pandas as pd
6+ from tqdm import tqdm
7+ import logging
8+ import re
9+ from word2number import w2n
10+
11+ # Helper Class for String Processing
12+ class StringProcessor :
13+ """
14+ A class that encapsulates various string processing functions for mathematical expressions.
15+ """
16+
17+ @staticmethod
18+ def _fix_fracs (string ):
19+ """
20+ Fixes fraction expressions in the string, ensuring they are properly formatted as \f rac{a}{b}.
21+ """
22+ substrs = string .split ("\\ frac" )
23+ new_str = substrs [0 ]
24+ if len (substrs ) > 1 :
25+ for substr in substrs [1 :]:
26+ new_str += "\\ frac"
27+ if len (substr ) > 0 and substr [0 ] == "{" :
28+ new_str += substr
29+ else :
30+ if len (substr ) >= 2 :
31+ a , b = substr [0 ], substr [1 ]
32+ if b != "{" :
33+ new_str += f"{{{ a } }}{{{ b } }}{ substr [2 :]} " if len (substr ) > 2 else f"{{{ a } }}{{{ b } }}"
34+ else :
35+ new_str += f"{{{ a } }}{ b } { substr [2 :]} " if len (substr ) > 2 else f"{{{ a } }}{ b } "
36+ else :
37+ return string
38+ return new_str
39+
40+ @staticmethod
41+ def _fix_a_slash_b (string ):
42+ """
43+ Fixes cases where a fraction is represented as a simple division (e.g., a/b) and converts it to \f rac{a}{b}.
44+ """
45+ if len (string .split ("/" )) != 2 :
46+ return string
47+ a , b = string .split ("/" )
48+ try :
49+ a , b = int (a ) if "sqrt" not in a else a , int (b ) if "sqrt" not in b else b
50+ assert string == f"{ a } /{ b } "
51+ return f"\\ frac{{{ a } }}{{{ b } }}"
52+ except :
53+ return string
54+
55+ @staticmethod
56+ def _fix_sqrt (string ):
57+ """
58+ Ensures that square root expressions are properly formatted as \sqrt{...}.
59+ """
60+ return re .sub (r"\\sqrt(\w+)" , r"\\sqrt{\1}" , string )
61+
62+ @staticmethod
63+ def convert_word_number (text : str ) -> str :
64+ """
65+ Converts a word representation of a number to a digit.
66+ """
67+ try :
68+ return str (w2n .word_to_num (text ))
69+ except :
70+ return text
71+
72+
73+ # Unit Text Class to Manage Unit Texts
74+ class UnitTextManager :
75+ """
76+ A class that encapsulates unit text management to remove unwanted unit terms from strings.
77+ """
78+
79+ def __init__ (self ):
80+ """
81+ Initializes the unit texts and their plural forms.
82+ """
83+ self .unit_texts = [
84+ "east" , "degree" , "mph" , "kmph" , "ft" , "m sqaure" , "m east" , "sq m" , "deg" , "mile" , "q ." , "monkey" , "prime" ,
85+ "ratio" , "profit of rs" , "rd" , "o" , "gm" , "p . m" , "lb" , "tile" , "per" , "dm" , "lt" , "gain" , "ab" , "way" , "west" ,
86+ "a ." , "b ." , "c ." , "d ." , "e ." , "f ." , "g ." , "h ." , "t" , "a" , "h" , "no change" , "men" , "soldier" , "pie" , "bc" ,
87+ "excess" , "st" , "inches" , "noon" , "percent" , "by" , "gal" , "kmh" , "c" , "acre" , "rise" , "a . m" , "th" , "π r 2" , "sq" ,
88+ "mark" , "l" , "toy" , "coin" , "sq . m" , "gallon" , "° f" , "profit" , "minw" , "yr" , "women" , "feet" , "am" , "pm" , "hr" ,
89+ "cu cm" , "square" , "v â € ™" , "are" , "rupee" , "rounds" , "cubic" , "cc" , "mtr" , "s" , "ohm" , "number" , "kmph" , "day" ,
90+ "hour" , "minute" , "min" , "second" , "man" , "woman" , "sec" , "cube" , "mt" , "sq inch" , "mp" , "∏ cm ³" , "hectare" ,
91+ "more" , "sec" , "unit" , "cu . m" , "cm 2" , "rs ." , "rs" , "kg" , "g" , "month" , "km" , "m" , "cm" , "mm" , "apple" , "liter" ,
92+ "loss" , "yard" , "pure" , "year" , "increase" , "decrease" , "d" , "less" , "Surface" , "litre" , "pi sq m" , "s ." , "metre" ,
93+ "meter" , "inch" ,
94+ ]
95+ self .unit_texts .extend ([t + "s" for t in self .unit_texts ])
96+
97+ def clean_units (self , string : str ):
98+ """
99+ Cleans the string by removing unit terms from it.
100+ """
101+ for unit_text in self .unit_texts :
102+ string = re .sub (r"(^|\W)" + unit_text + r"($|\W)" , r"\1\2" , string )
103+ return string
104+
105+
106+ # Main String Processing Class
107+ class StringCleaner :
108+ """
109+ A class responsible for cleaning and formatting strings in mathematical expressions.
110+ """
111+
112+ def __init__ (self , unit_manager : UnitTextManager ):
113+ """
114+ Initializes the StringCleaner class with a unit manager.
115+ """
116+ self .unit_manager = unit_manager
117+
118+ def strip_string (self , string , skip_unit = False ):
119+ """
120+ Strips unwanted characters and units from the string.
121+ """
122+ string = str (string ).strip ().replace ("\n " , "" ).rstrip ("." ).replace ("\\ !" , "" )
123+ string = re .sub (r"\\begin\{array\}\{.*?\}" , r"\\begin{pmatrix}" , string )
124+ string = re .sub (r"\\end\{array\}" , r"\\end{pmatrix}" , string ).replace ("bmatrix" , "pmatrix" )
125+ string = string .replace ("tfrac" , "frac" ).replace ("dfrac" , "frac" ).replace ("\\ neq" , "\\ ne" ).replace ("\\ leq" , "\\ le" ).replace ("\\ geq" , "\\ ge" )
126+ string = string .replace ("\\ left" , "" ).replace ("\\ right" , "" ).replace ("\\ {" , "{" ).replace ("\\ }" , "}" )
127+
128+ # Clean unit texts if needed
129+ if not skip_unit :
130+ string = self .unit_manager .clean_units (string )
131+
132+ string = string .replace ("^{\\ circ}" , "" ).replace ("^\\ circ" , "" ).replace ("\\ $" , "" ).replace ("$" , "" ).replace ("\\ (" , "" ).replace ("\\ )" , "" )
133+ string = StringProcessor .convert_word_number (string )
134+ string = re .sub (r"\\text\{(.*?)\}" , r"\1" , string )
135+
136+ for key in ["x=" , "y=" , "z=" , "x\\ in" , "y\\ in" , "z\\ in" , "x\\ to" , "y\\ to" , "z\\ to" ]:
137+ string = string .replace (key , "" )
138+
139+ string = string .replace ("\\ emptyset" , r"{}" ).replace ("(-\\ infty,\\ infty)" , "\\ mathbb{R}" )
140+ string = string .replace ("%" , "" ).replace (" ." , " 0." ).replace ("{." , "{0." )
141+
142+ return string
143+
144+
145+ # Core Answer Extraction Logic Class
146+ class AnswerExtractor :
147+ """
148+ A class responsible for extracting the final answer from a prediction string.
149+ """
150+
151+ def __init__ (self , string_cleaner : StringCleaner ):
152+ """
153+ Initializes the AnswerExtractor class with a string cleaner.
154+ """
155+ self .string_cleaner = string_cleaner
156+
157+ def extract_answer (self , pred_str , data_name , use_last_number = True ):
158+ """
159+ Extracts the final answer from the prediction string, processing various formats.
160+ """
161+ pred_str = pred_str .replace ("\u043a \u0438 " , "" )
162+
163+ # Handle special cases based on data_name or pattern
164+ if "final answer is $" in pred_str and "$. I hope" in pred_str :
165+ pred = pred_str .split ("final answer is $" , 1 )[1 ].split ("$. I hope" , 1 )[0 ].strip ()
166+ elif "boxed" in pred_str :
167+ pred = self ._extract_boxed_answer (pred_str )
168+ elif "he answer is" in pred_str :
169+ pred = pred_str .split ("he answer is" )[- 1 ].strip ()
170+ else :
171+ pred = self ._get_last_number_answer (pred_str , use_last_number )
172+
173+ pred = self .string_cleaner .strip_string (pred , skip_unit = data_name in ["carp_en" , "minerva_math" ])
174+ return pred
175+
176+ def _extract_boxed_answer (self , pred_str ):
177+ """
178+ Extracts answers enclosed in 'boxed' notation.
179+ """
180+ ans = pred_str .split ("boxed" )[- 1 ]
181+ if ans .startswith ("{" ):
182+ return self ._extract_bracketed_answer (ans )
183+ else :
184+ return ans .split ("$" )[0 ].strip ()
185+
186+ def _extract_bracketed_answer (self , ans ):
187+ """
188+ Handles answers that are enclosed within brackets.
189+ """
190+ stack = 1
191+ result = ""
192+ for c in ans [1 :]:
193+ if c == "{" :
194+ stack += 1
195+ result += c
196+ elif c == "}" :
197+ stack -= 1
198+ if stack == 0 :
199+ break
200+ result += c
201+ else :
202+ result += c
203+ return result
204+
205+ def _get_last_number_answer (self , pred_str , use_last_number ):
206+ """
207+ Extracts the last number from the string if use_last_number is True.
208+ """
209+ if use_last_number :
210+ pattern = "-?\d*\.?\d+"
211+ pred = re .findall (pattern , pred_str .replace ("," , "" ))
212+ return pred [- 1 ] if pred else ""
213+ return ""
214+
215+
216+ @PROCESSOR_REGISTRY .register ()
217+ class AnswerGroundTruthFilter (TextFilter ):
218+ def __init__ (self , args_dict : dict ):
219+ super ().__init__ (args_dict )
220+ self .filter_name = 'AnswerGroundTruthFilter'
221+ unit_manager = UnitTextManager ()
222+ string_cleaner = StringCleaner (unit_manager )
223+ self .answer_extractor = AnswerExtractor (string_cleaner )
224+
225+ def filter_func (self , dataset ):
226+ indexes = np .zeros (len (dataset )).astype (int )
227+ for i in range (len (dataset )):
228+ final_answer = self .answer_extractor .extract_answer (dataset [i ]['answer' ], dataset [i ].get ('data_name' , None ))
229+ if 'ground_truth_answer' in dataset [i ] and final_answer == dataset [i ]['ground_truth_answer' ]:
230+ indexes [i ] = 1
231+ return indexes
0 commit comments