2121# SOFTWARE.
2222
2323import re
24- from dataclasses import dataclass
24+ from dataclasses import dataclass , field
2525from functools import lru_cache
2626from itertools import groupby
27- from typing import Literal , Sequence
27+ from typing import Any , Literal , Sequence
2828
2929import sympy
3030from sympy import Basic , MatrixBase , Number
3939from lighteval .utils .timeout import timeout
4040
4141
42+ @requires_latex2sympy2_extended
43+ def latex_normalization_config_default_factory ():
44+ from latex2sympy2_extended .latex2sympy2 import NormalizationConfig
45+
46+ return NormalizationConfig (
47+ basic_latex = True ,
48+ units = True ,
49+ malformed_operators = True ,
50+ nits = True ,
51+ boxed = True ,
52+ equations = True ,
53+ )
54+
55+
4256@dataclass (frozen = True )
4357class LatexExtractionConfig :
4458 """Config for extracting latex from the prediction.
4559
4660 Attributes:
4761 try_extract_without_anchor (bool): Whether to try extracting latex without requiring specific anchors like "answer:" or "final answer is"
48- enforce_boxed_match (bool): Whether to also consider extracting from plain \b oxed{...} expressions
62+ boxed_match_priority (int): Priority of the boxed match regex (-1 never, 0 first, 55 after final answer: anchor, etc...)
63+ normalization_config (latex2sympy2_extended.latex2sympy2.NormalizationConfig): Normalization config to use for latex extraction
4964 """
5065
5166 try_extract_without_anchor : bool = True
52- enforce_boxed_match : bool = True
67+ boxed_match_priority : int = 55
68+ normalization_config : Any = field (default_factory = latex_normalization_config_default_factory )
5369
5470
5571@dataclass (frozen = True )
@@ -187,9 +203,8 @@ def lazy_latex_regex(latex_config: LatexExtractionConfig, language: Language) ->
187203 if latex_config .try_extract_without_anchor :
188204 regexes .append ((latex_re , 300 ))
189205
190- # This ensures that boxed is matched right after the final answer xxxx
191- if latex_config .enforce_boxed_match :
192- regexes .append ((latex_boxed , 55 ))
206+ if latex_config .boxed_match_priority >= 0 :
207+ regexes .append ((latex_boxed , latex_config .boxed_match_priority ))
193208
194209 return [(re .compile (pattern , re .DOTALL ), priority ) for pattern , priority in regexes ]
195210
@@ -387,6 +402,7 @@ def extract_target_from_pred(
387402 pred : str ,
388403 target_res : list [tuple [list [tuple [re .Pattern [str ], int ]], ExtractionTarget ]],
389404 fallback_mode : Literal ["no_fallback" , "first_match" ] = "no_fallback" ,
405+ extraction_mode : Literal ["first_match" , "any_match" ] = "any_match" ,
390406):
391407 """Extracts targets from a prediction string using regex patterns.
392408 Returns first sucesffuly extracted match.
@@ -397,6 +413,9 @@ def extract_target_from_pred(
397413 fallback_mode (Literal["no_fallback", "first_match"], optional): How to handle extraction failures. Defaults to "no_fallback".
398414 - "no_fallback": Return only successfully parsed match
399415 - "first_match": Additionaly Include the first string match no matter how parsing finished
416+ extraction_mode (Literal["first_match", "any_match"], optional): How to handle extraction failures. Defaults to "any_match".
417+ - "first_match": Only tries to extract the first match
418+ - "any_match": Tries to extract any match
400419
401420 Returns:
402421 list: List of extracted predictions, with first fallbac string appended if fallback_mode is "first_match"
@@ -410,6 +429,7 @@ def extract_target_from_pred(
410429 for target_patterns , target_type in target_res
411430 for pattern , priority in target_patterns
412431 ]
432+ match_found = False
413433
414434 # Group patterns by priority using itertools.groupby
415435 for _ , patterns_group in groupby (sorted (all_patterns , key = lambda x : x [2 ]), key = lambda x : x [2 ]):
@@ -426,6 +446,7 @@ def extract_target_from_pred(
426446 # Try to extract from each match, starting from rightmost
427447 for match , _ , _ , target_type in matches_with_pos :
428448 extracted_match , str_fallback = extract_match (match , target_type )
449+ match_found = True
429450
430451 if str_fallback :
431452 fallbacks .append (str_fallback )
@@ -434,8 +455,11 @@ def extract_target_from_pred(
434455 extracted_predictions .append (extracted_match )
435456 break
436457
458+ if extraction_mode == "first_match" :
459+ break
460+
437461 # If we found something and we're in first_match mode, stop processing other priorities
438- if extracted_predictions :
462+ if extracted_predictions or ( match_found and extraction_mode == "first_match" ) :
439463 break
440464
441465 if fallback_mode == "first_match" and fallbacks :
0 commit comments