1
1
import torch
2
+ from latex2sympy2_extended import NormalizationConfig
3
+ from math_verify import ExprExtractionConfig , LatexExtractionConfig , parse , verify
2
4
3
5
from .reward_utils import extract_boxed_solution , extract_solution , validate_response_structure
4
6
7
+ CANNOT_PARSE_GT_ANSWER = - 1
8
+ CANNOT_PARSE_PREDICTION = - 2
9
+ SUCCESS = 1
10
+ MATCHING_FAIL = 0
11
+
12
+
13
+ def verify_math_representation (completion , gt_answer ):
14
+ """
15
+ Verify if the completion is a valid math representation of the gt_answer.
16
+ """
17
+ if not completion .startswith ("\\ boxed{" ):
18
+ completion = "\\ boxed{" + completion + "}"
19
+ if not gt_answer .startswith ("\\ boxed{" ):
20
+ gt_answer = "\\ boxed{" + gt_answer + "}"
21
+ target = (
22
+ ExprExtractionConfig (),
23
+ LatexExtractionConfig (
24
+ normalization_config = NormalizationConfig (
25
+ nits = False ,
26
+ malformed_operators = False ,
27
+ basic_latex = True ,
28
+ boxed = "all" ,
29
+ units = True ,
30
+ ),
31
+ boxed_match_priority = 0 ,
32
+ ),
33
+ )
34
+ if not isinstance (gt_answer , str ) or len (gt_answer ) == 0 :
35
+ raise ValueError ("gt_answer should be a string, please verify your training data." )
36
+ if not isinstance (completion , str ) or len (completion ) == 0 :
37
+ return MATCHING_FAIL
38
+ try :
39
+ parsed_gt_answer = parse (gt_answer , extraction_config = target )
40
+ if len (parsed_gt_answer ) == 0 :
41
+ return CANNOT_PARSE_GT_ANSWER
42
+ parsed_completion = parse (completion , extraction_config = target )
43
+ if len (parsed_completion ) == 0 :
44
+ return CANNOT_PARSE_PREDICTION
45
+ if verify (parsed_gt_answer , parsed_completion ):
46
+ return SUCCESS
47
+ else :
48
+ return MATCHING_FAIL
49
+ except Exception :
50
+ return MATCHING_FAIL
51
+
52
+
53
+ def verify_model_answer (decoded_final_answer , gt_answer , ans_acc , acc_score , reward ):
54
+ math_verify_result = verify_math_representation (decoded_final_answer , gt_answer )
55
+ exact_match_result = (
56
+ SUCCESS
57
+ if decoded_final_answer .strip ().replace (" " , "" ).replace ("{" , "" ).replace ("}" , "" ).replace ("," , "" )
58
+ == gt_answer .strip ().replace (" " , "" ).replace ("{" , "" ).replace ("}" , "" ).replace ("," , "" )
59
+ else MATCHING_FAIL
60
+ )
61
+ if math_verify_result == SUCCESS :
62
+ ans_acc += 1
63
+ reward += acc_score
64
+ elif exact_match_result == SUCCESS :
65
+ # sometimes for answers that's not a (valid) math expression, math_verify will fail
66
+ ans_acc += 1
67
+ if math_verify_result == CANNOT_PARSE_PREDICTION :
68
+ reward += (
69
+ acc_score / 2
70
+ ) # not a valid latex math representation, but the answer is correct, receive half of the score
71
+ else :
72
+ reward += acc_score
73
+ return reward , ans_acc
74
+
5
75
6
76
def math_reward_fn (input_ids , gt_answer , response_idx , ** kwargs ):
7
77
tokenizer = kwargs ["tokenizer" ]
@@ -14,15 +84,18 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
14
84
s , e = response_idx [0 ], response_idx [1 ]
15
85
16
86
length_reward = 0.0
17
- if soft_over_length_punishment :
18
- max_length = kwargs .get ("max_length" , 1024 * 4 )
19
- cache_length = kwargs .get ("cache_length" , 512 )
20
- res_length = e .item () - s .item () + 1
21
- if max_length - cache_length < res_length < max_length :
22
- length_reward = ((max_length - cache_length ) - res_length ) / cache_length * acc_score
87
+ res_length = e .item () - s .item () + 1
88
+ if not eval_mode :
89
+ max_new_tokens = kwargs ["max_new_tokens" ]
90
+ else :
91
+ max_new_tokens = - 1 # for eval mode, we don't need to check the length
92
+ if not eval_mode and soft_over_length_punishment :
93
+ cache_length = kwargs ["cache_length" ]
94
+ if max_new_tokens - cache_length < res_length < max_new_tokens :
95
+ length_reward = ((max_new_tokens - cache_length ) - res_length ) / cache_length * acc_score
23
96
24
97
if gt_answer is None :
25
- return reward
98
+ raise ValueError ( "no gt_answer is provided, please check your training dataset." )
26
99
27
100
decoded_final_answer = tokenizer .decode (input_ids [s : e + 1 ], skip_special_tokens = True )
28
101
gt_answer = tokenizer .decode (gt_answer .squeeze (0 ), skip_special_tokens = True )
@@ -35,15 +108,15 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
35
108
format_acc += 1
36
109
37
110
# Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
38
- if (
39
- format_valid
40
- and final_answer is not None
41
- and gt_answer .strip ().replace (" " , "" ).lower () == final_answer .strip ().replace (" " , "" ).lower ()
42
- ):
43
- ans_acc += 1
44
- reward += acc_score
111
+ if final_answer is not None :
112
+ if eval_mode or format_valid :
113
+ reward , ans_acc = verify_model_answer (final_answer , gt_answer , ans_acc , acc_score , reward )
114
+ if not eval_mode :
115
+ reward = reward + length_reward
45
116
46
- reward = reward + length_reward
117
+ # Check if the sequence is over length
118
+ if not eval_mode and res_length >= max_new_tokens :
119
+ reward *= 0.0
47
120
48
121
if not eval_mode :
49
122
return torch .tensor ([reward , format_acc , ans_acc ]).to (input_ids .device )
@@ -56,6 +129,8 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
56
129
"parsed" : final_answer ,
57
130
"format_valid" : format_acc .item (),
58
131
"ans_valid" : ans_acc .item (),
132
+ "response_length" : res_length ,
133
+ "reward" : reward .item (),
59
134
}
60
135
61
136
@@ -71,31 +146,45 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
71
146
s , e = response_idx [0 ], response_idx [1 ]
72
147
73
148
length_reward = 0.0
74
- if soft_over_length_punishment :
75
- max_length = kwargs .get ("max_length" , 1024 * 4 )
76
- cache_length = kwargs .get ("cache_length" , 512 )
77
- res_length = e .item () - s .item () + 1
78
- if max_length - cache_length < res_length < max_length :
79
- length_reward = ((max_length - cache_length ) - res_length ) / cache_length * acc_score
149
+ res_length = e .item () - s .item () + 1
150
+ if not eval_mode :
151
+ max_new_tokens = kwargs ["max_new_tokens" ]
152
+ else :
153
+ max_new_tokens = - 1 # for eval mode, we don't need to check the length
154
+ if not eval_mode and soft_over_length_punishment :
155
+ cache_length = kwargs ["cache_length" ]
156
+ if max_new_tokens - cache_length < res_length < max_new_tokens :
157
+ length_reward = ((max_new_tokens - cache_length ) - res_length ) / cache_length * acc_score
80
158
81
159
if gt_answer is None :
82
- return torch . tensor ([ reward , format_acc , ans_acc ]). to ( input_ids . device )
160
+ raise ValueError ( "no gt_answer is provided, please check your training dataset." )
83
161
84
162
decoded_final_answer = tokenizer .decode (input_ids [s : e + 1 ], skip_special_tokens = True )
163
+
85
164
gt_answer = tokenizer .decode (gt_answer .squeeze (0 ), skip_special_tokens = True )
86
165
final_answer = extract_boxed_solution (decoded_final_answer )
87
166
format_valid = final_answer is not None
167
+ if "tags" in kwargs and kwargs ["tags" ]:
168
+ tags = kwargs ["tags" ]
169
+ format_valid = format_valid and all (
170
+ [decoded_final_answer .count (tags [tag ]["text" ]) == tags [tag ]["num_occur" ] for tag in tags ]
171
+ )
88
172
# Check format accuracy
89
173
if format_valid :
90
174
format_acc += 1
91
175
reward += format_score
92
176
93
177
# Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
94
- if format_valid and final_answer is not None and gt_answer .strip ().lower () == final_answer .strip ().lower ():
95
- ans_acc += 1
96
- reward += acc_score
178
+ if final_answer is not None :
179
+ if eval_mode or format_valid :
180
+ reward , ans_acc = verify_model_answer (final_answer , gt_answer , ans_acc , acc_score , reward )
181
+ if not eval_mode :
182
+ reward = reward + length_reward
183
+
184
+ # Check if the sequence is over length
185
+ if not eval_mode and res_length >= max_new_tokens :
186
+ reward *= 0.0
97
187
98
- reward = reward + length_reward
99
188
if not eval_mode :
100
189
return torch .tensor ([reward , format_acc , ans_acc ]).to (input_ids .device )
101
190
else :
@@ -107,4 +196,6 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
107
196
"parsed" : final_answer ,
108
197
"format_valid" : format_acc .item (),
109
198
"ans_valid" : ans_acc .item (),
199
+ "response_length" : res_length ,
200
+ "reward" : reward .item (),
110
201
}
0 commit comments