@@ -17,62 +17,67 @@ def __init__(self, tolerance: float = 1e-6, partial_credit: float = 0.1):
1717 self .tolerance = tolerance
1818 self .partial_credit = partial_credit
1919
20- def _to_float (self , text ) -> Optional [float ]:
21- """Safely parse a string into a float, or return None if invalid."""
22- if text is None :
23- return None
24- try :
25- return float (str (text ).strip ())
26- except (ValueError , TypeError ):
27- return None
28-
29- def _extract_number (self , text : str ) -> Optional [float ]:
30- """Try to extract a numeric answer from text."""
31- number_pattern = r"([+-]?\d+(?:\.\d+)?(?:e[+-]?\d+)?)"
32- patterns = [
33- r"####\s*" + number_pattern ,
34- r"(?:the\s+)?answer\s+is\s*" + number_pattern ,
35- r"(?:answer:|result:)\s*" + number_pattern ,
36- r"\$" + number_pattern , # currency
37- number_pattern , # fallback
38- r"=\s*" + number_pattern + r"\s*(?:\.|$)" ,
39- r"\b" + number_pattern + r"\s*(?:\.|$)" ,
40- ]
41- text = text .lower ().strip ()
42- for pattern in patterns :
43- matches = re .findall (pattern , text )
44- if matches :
45- return self ._to_float (matches [- 1 ])
46- return None
47-
4820 def __call__ (self , prompt : str , response : str , target : str ) -> float :
4921 """Compute math correctness reward."""
50- # Parse expected
51- expected_answer = self ._to_float (target )
22+ target_number = self ._to_float (target )
23+ if target_number is None :
24+ return 0.0
5225
53- # Parse response
54- model_answer = self . _extract_number ( response )
26+ # Look for answer in <answer></answer> tags
27+ answer_match = re . search ( r"<answer>(.*?)</answer>" , response , re . DOTALL )
5528
56- # Scoring
57- if expected_answer is None or model_answer is None :
58- return self .partial_credit # Partial credit for attempting
29+ if answer_match :
30+ model_answer = self ._to_float (answer_match .group (1 ).strip ())
31+ if (
32+ model_answer is not None
33+ and abs (target_number - model_answer ) < self .tolerance
34+ ):
35+ return 1.0 # Correct answer
5936
60- if abs (expected_answer - model_answer ) < self .tolerance :
61- return 1.0 # Correct answer
62- return 0.0 # Incorrect answer
37+ # Check for partial credit: target number appears elsewhere in response
38+ response_without_answer_tags = re .sub (
39+ r"<answer>.*?</answer>" , "" , response , flags = re .DOTALL
40+ )
41+ # Convert to int if it's a whole number to avoid "117.0" vs "117" mismatch
42+ target_str = (
43+ str (int (target_number ))
44+ if target_number .is_integer ()
45+ else str (target_number )
46+ )
47+ if target_str in response_without_answer_tags :
48+ return self .partial_credit
49+
50+ return 0.0 # No match
51+
52+ def _to_float (self , text : str ) -> float | None :
53+ """Convert text to float, return None if invalid."""
54+ try :
55+ # Remove common non-numeric characters like $, commas, etc.
56+ cleaned_text = re .sub (r"[$,\s]" , "" , text .strip ())
57+ return float (cleaned_text )
58+ except (ValueError , AttributeError ):
59+ return None
6360
6461
6562class ThinkingReward (Reward ):
6663 """Reward class for evaluating use of <think> tags in reasoning."""
6764
68- def __init__ (self , reward_value : float = 0.5 ):
69- self .reward_value = reward_value
65+ def __init__ (self , partial_reward : float = 0.2 , full_reward : float = 1.0 ):
66+ self .partial_reward = partial_reward
67+ self .full_reward = full_reward
68+ self ._THINK_BLOCK_RE = re .compile (
69+ r"<\s*think\s*>(.*?)<\s*/\s*think\s*>" , re .IGNORECASE | re .DOTALL
70+ )
71+ self ._THINK_TAG_ATTEMPT_RE = re .compile (r"<\s*/?\s*think\s*>" , re .IGNORECASE )
7072
71- def __call__ (
72- self , prompt : str , response : str , target : Optional [str ] = None
73- ) -> float :
74- """Check if response contains <think>...</think> tags."""
75- resp = response .lower ()
76- if "<think>" in resp and "</think>" in resp :
77- return self .reward_value
73+ def __call__ (self , prompt : str , response : str , target : str | None = None ) -> float :
74+ matches = self ._THINK_BLOCK_RE .findall (response or "" )
75+ has_well_formed = any (len (re .sub (r"\s+" , "" , m )) >= 1 for m in matches )
76+ has_attempt = bool (self ._THINK_TAG_ATTEMPT_RE .search (response or "" )) or bool (
77+ matches
78+ )
79+ if has_well_formed :
80+ return self .full_reward
81+ elif has_attempt :
82+ return self .partial_reward
7883 return 0.0
0 commit comments