11import torch
2+ from latex2sympy2_extended import NormalizationConfig
3+ from math_verify import ExprExtractionConfig , LatexExtractionConfig , parse , verify
24
35from .reward_utils import extract_boxed_solution , extract_solution , validate_response_structure
46
# Result codes returned by verify_math_representation.
# The ground-truth answer yielded no parsed expression.
CANNOT_PARSE_GT_ANSWER = -1
# The model's answer yielded no parsed expression.
CANNOT_PARSE_PREDICTION = -2
# Both answers parsed and were verified as equivalent.
SUCCESS = 1
# The answers did not match, the completion was unusable, or parsing/verification raised.
MATCHING_FAIL = 0
11+
12+
def verify_math_representation(completion, gt_answer):
    """
    Verify if the completion is a valid math representation of the gt_answer.

    Both inputs are wrapped in ``\\boxed{...}`` when they do not already start
    with it, parsed with ``parse`` and then compared with ``verify``.

    Args:
        completion: The model's extracted answer.
        gt_answer: The reference answer; must be a non-empty string.

    Returns:
        SUCCESS if both answers parse and ``verify`` accepts them,
        CANNOT_PARSE_GT_ANSWER if the reference yields no parsed expression,
        CANNOT_PARSE_PREDICTION if the completion yields no parsed expression,
        MATCHING_FAIL if the completion is empty or not a string, if the
        answers do not match, or if parsing/verification raises.

    Raises:
        ValueError: If gt_answer is not a non-empty string.
    """
    # Validate before wrapping: wrapping always produces a non-empty string,
    # and calling .startswith on a non-string would raise AttributeError
    # instead of the intended ValueError / MATCHING_FAIL.
    if not isinstance(gt_answer, str) or len(gt_answer) == 0:
        raise ValueError("gt_answer should be a string, please verify your training data.")
    if not isinstance(completion, str) or len(completion) == 0:
        return MATCHING_FAIL

    if not completion.startswith("\\boxed{"):
        completion = "\\boxed{" + completion + "}"
    if not gt_answer.startswith("\\boxed{"):
        gt_answer = "\\boxed{" + gt_answer + "}"

    target = (
        ExprExtractionConfig(),
        LatexExtractionConfig(
            normalization_config=NormalizationConfig(
                nits=False,
                malformed_operators=False,
                basic_latex=True,
                boxed="all",
                units=True,
            ),
            boxed_match_priority=0,
        ),
    )
    try:
        parsed_gt_answer = parse(gt_answer, extraction_config=target)
        if len(parsed_gt_answer) == 0:
            return CANNOT_PARSE_GT_ANSWER
        parsed_completion = parse(completion, extraction_config=target)
        if len(parsed_completion) == 0:
            return CANNOT_PARSE_PREDICTION
        if verify(parsed_gt_answer, parsed_completion):
            return SUCCESS
        return MATCHING_FAIL
    except Exception:
        # Best effort: any failure inside the parser/verifier counts as a
        # mismatch rather than aborting the reward computation.
        return MATCHING_FAIL
51+
52+
def verify_model_answer(decoded_final_answer, gt_answer, ans_acc, acc_score, reward):
    """
    Score a model answer against the ground truth and update the running totals.

    The answer is first checked with verify_math_representation. If that does
    not succeed, a plain string comparison is tried after stripping both
    strings and dropping spaces, braces and commas.

    Args:
        decoded_final_answer: The model's extracted answer string.
        gt_answer: The reference answer string.
        ans_acc: Running count of correct answers; incremented on a match.
        acc_score: Score granted for a correct answer.
        reward: Running reward; increased on a match.

    Returns:
        The updated ``(reward, ans_acc)`` pair.
    """
    verdict = verify_math_representation(decoded_final_answer, gt_answer)
    if verdict == SUCCESS:
        ans_acc += 1
        reward += acc_score
        return reward, ans_acc

    # sometimes for answers that's not a (valid) math expression, math_verify will fail,
    # so fall back to comparing the texts with spaces, braces and commas removed
    dropped = str.maketrans("", "", " {},")
    predicted_text = decoded_final_answer.strip().translate(dropped)
    reference_text = gt_answer.strip().translate(dropped)
    if predicted_text != reference_text:
        return reward, ans_acc

    ans_acc += 1
    if verdict == CANNOT_PARSE_PREDICTION:
        # not a valid latex math representation, but the answer is correct, receive half of the score
        reward += acc_score / 2
    else:
        reward += acc_score
    return reward, ans_acc
74+
575
676def math_reward_fn (input_ids , gt_answer , response_idx , ** kwargs ):
777 tokenizer = kwargs ["tokenizer" ]
@@ -14,15 +84,18 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
1484 s , e = response_idx [0 ], response_idx [1 ]
1585
1686 length_reward = 0.0
17- if soft_over_length_punishment :
18- max_length = kwargs .get ("max_length" , 1024 * 4 )
19- cache_length = kwargs .get ("cache_length" , 512 )
20- res_length = e .item () - s .item () + 1
21- if max_length - cache_length < res_length < max_length :
22- length_reward = ((max_length - cache_length ) - res_length ) / cache_length * acc_score
87+ res_length = e .item () - s .item () + 1
88+ if not eval_mode :
89+ max_new_tokens = kwargs ["max_new_tokens" ]
90+ else :
91+ max_new_tokens = - 1 # for eval mode, we don't need to check the length
92+ if not eval_mode and soft_over_length_punishment :
93+ cache_length = kwargs ["cache_length" ]
94+ if max_new_tokens - cache_length < res_length < max_new_tokens :
95+ length_reward = ((max_new_tokens - cache_length ) - res_length ) / cache_length * acc_score
2396
2497 if gt_answer is None :
25- return reward
98+ raise ValueError ( "no gt_answer is provided, please check your training dataset." )
2699
27100 decoded_final_answer = tokenizer .decode (input_ids [s : e + 1 ], skip_special_tokens = True )
28101 gt_answer = tokenizer .decode (gt_answer .squeeze (0 ), skip_special_tokens = True )
@@ -35,15 +108,15 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
35108 format_acc += 1
36109
37110 # Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
38- if (
39- format_valid
40- and final_answer is not None
41- and gt_answer .strip ().replace (" " , "" ).lower () == final_answer .strip ().replace (" " , "" ).lower ()
42- ):
43- ans_acc += 1
44- reward += acc_score
111+ if final_answer is not None :
112+ if eval_mode or format_valid :
113+ reward , ans_acc = verify_model_answer (final_answer , gt_answer , ans_acc , acc_score , reward )
114+ if not eval_mode :
115+ reward = reward + length_reward
45116
46- reward = reward + length_reward
117+ # Check if the sequence is over length
118+ if not eval_mode and res_length >= max_new_tokens :
119+ reward *= 0.0
47120
48121 if not eval_mode :
49122 return torch .tensor ([reward , format_acc , ans_acc ]).to (input_ids .device )
@@ -56,6 +129,8 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
56129 "parsed" : final_answer ,
57130 "format_valid" : format_acc .item (),
58131 "ans_valid" : ans_acc .item (),
132+ "response_length" : res_length ,
133+ "reward" : reward .item (),
59134 }
60135
61136
@@ -71,31 +146,45 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
71146 s , e = response_idx [0 ], response_idx [1 ]
72147
73148 length_reward = 0.0
74- if soft_over_length_punishment :
75- max_length = kwargs .get ("max_length" , 1024 * 4 )
76- cache_length = kwargs .get ("cache_length" , 512 )
77- res_length = e .item () - s .item () + 1
78- if max_length - cache_length < res_length < max_length :
79- length_reward = ((max_length - cache_length ) - res_length ) / cache_length * acc_score
149+ res_length = e .item () - s .item () + 1
150+ if not eval_mode :
151+ max_new_tokens = kwargs ["max_new_tokens" ]
152+ else :
153+ max_new_tokens = - 1 # for eval mode, we don't need to check the length
154+ if not eval_mode and soft_over_length_punishment :
155+ cache_length = kwargs ["cache_length" ]
156+ if max_new_tokens - cache_length < res_length < max_new_tokens :
157+ length_reward = ((max_new_tokens - cache_length ) - res_length ) / cache_length * acc_score
80158
81159 if gt_answer is None :
82- return torch . tensor ([ reward , format_acc , ans_acc ]). to ( input_ids . device )
160+ raise ValueError ( "no gt_answer is provided, please check your training dataset." )
83161
84162 decoded_final_answer = tokenizer .decode (input_ids [s : e + 1 ], skip_special_tokens = True )
163+
85164 gt_answer = tokenizer .decode (gt_answer .squeeze (0 ), skip_special_tokens = True )
86165 final_answer = extract_boxed_solution (decoded_final_answer )
87166 format_valid = final_answer is not None
167+ if "tags" in kwargs and kwargs ["tags" ]:
168+ tags = kwargs ["tags" ]
169+ format_valid = format_valid and all (
170+ [decoded_final_answer .count (tags [tag ]["text" ]) == tags [tag ]["num_occur" ] for tag in tags ]
171+ )
88172 # Check format accuracy
89173 if format_valid :
90174 format_acc += 1
91175 reward += format_score
92176
93177 # Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
94- if format_valid and final_answer is not None and gt_answer .strip ().lower () == final_answer .strip ().lower ():
95- ans_acc += 1
96- reward += acc_score
178+ if final_answer is not None :
179+ if eval_mode or format_valid :
180+ reward , ans_acc = verify_model_answer (final_answer , gt_answer , ans_acc , acc_score , reward )
181+ if not eval_mode :
182+ reward = reward + length_reward
183+
184+ # Check if the sequence is over length
185+ if not eval_mode and res_length >= max_new_tokens :
186+ reward *= 0.0
97187
98- reward = reward + length_reward
99188 if not eval_mode :
100189 return torch .tensor ([reward , format_acc , ans_acc ]).to (input_ids .device )
101190 else :
@@ -107,4 +196,6 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
107196 "parsed" : final_answer ,
108197 "format_valid" : format_acc .item (),
109198 "ans_valid" : ans_acc .item (),
199+ "response_length" : res_length ,
200+ "reward" : reward .item (),
110201 }
0 commit comments