 import io
 import logging
 import re
-from typing import Any, Optional, TypedDict
+from typing import Any, Optional, TypedDict, Union

 import ray
 import torch
+from math_verify import grader
 from math_verify.errors import TimeoutException
 from math_verify.metric import math_metric
 from math_verify.parser import ExprExtractionConfig, LatexExtractionConfig
@@ -69,53 +70,84 @@ def __init__(self) -> None:
         )

     def verify(
-        self, pred_responses: list[str], ground_truths: list[str]
-    ) -> list[float]:
+        self,
+        pred_responses: list[str],
+        ground_truths: list[str],
+        return_extracted_answer: bool = False,
+    ) -> Union[list[float], tuple[list[float], list[str | None]]]:
         """Verify the correctness of the predicted responses against the ground truth.

         Args:
             pred_responses: list[str]. The predicted responses from the LLM.
             ground_truths: list[str]. The ground truth responses.

         Returns:
-            list[float]. The rewards for each predicted response.
+            Union[list[float], tuple[list[float], list[str | None]]].
+                If return_extracted_answer is False, returns only the scores.
+                If return_extracted_answer is True, returns (scores, extracted_answers).
         """
         results = []
+        extracted_answers: list[str | None] = []
+
         for response, ground_truth in zip(pred_responses, ground_truths):
             try:
                 ground_truth_parsable = "\\boxed{" + ground_truth + "}"
                 with _mute_output():
-                    try:
-                        ret_score, _ = self.verify_func(
-                            [ground_truth_parsable], [response]
-                        )
-                    # It's possible to emit a TimeoutException and that wouldn't be caught since
-                    # it actually subclasses from BaseException and math-verify itself does not
-                    # catch it.
-                    except (Exception, TimeoutException):
-                        ret_score = 0.0
+                    ret_score, extracted_answer = self.verify_func(
+                        [ground_truth_parsable], [response]
+                    )

                 results.append(float(ret_score))
-            except Exception:
+
+                if return_extracted_answer:
+                    # Make sure the extracted answer is not None and is a list of two elements
+                    assert extracted_answer is not None
+                    assert len(extracted_answer) == 2
+                    extracted_gold, extracted_prediction = extracted_answer
+                    # Get the extracted answer with the same logic as in the HFVerifyWorker
+                    for pred in extracted_prediction:
+                        if any(grader.verify(gold, pred) for gold in extracted_gold):
+                            extracted_answers.append(pred)
+                            break
+                    else:
+                        # If no match is found, all answers are incorrect, so just use the first prediction
+                        extracted_answers.append(extracted_prediction[0][0])
+
+            # It's possible to emit a TimeoutException and that wouldn't be caught since
+            # it actually subclasses from BaseException and math-verify itself does not
+            # catch it.
+            except (Exception, TimeoutException):
                 results.append(0.0)
-        return results
+                extracted_answers.append(None)
+
+        if return_extracted_answer:
+            return results, extracted_answers
+        else:
+            return results


 @ray.remote  # pragma: no cover
 class MultilingualMultichoiceVerifyWorker:
     def verify(
-        self, pred_responses: list[str], ground_truths: list[str]
-    ) -> list[float]:
+        self,
+        pred_responses: list[str],
+        ground_truths: list[str],
+        return_extracted_answer: bool = False,
+    ) -> Union[list[float], tuple[list[float], list[str | None]]]:
         """Verify the correctness of the predicted responses against the ground truth.

         Args:
             pred_responses: list[str]. The predicted responses from the LLM.
             ground_truths: list[str]. The ground truth responses.

         Returns:
-            list[float]. The rewards for each predicted response.
+            Union[list[float], tuple[list[float], list[str | None]]].
+                If return_extracted_answer is False, returns only the scores.
+                If return_extracted_answer is True, returns (scores, extracted_answers).
         """
         results = []
+        extracted_answers: list[str | None] = []
+
         for response, ground_truth in zip(pred_responses, ground_truths):
             response = answer_parsing.normalize_response(response)
             extracted_answer = None
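A quick usage sketch for the new flag (illustrative only; worker stands for any of the Ray verify actors in this file, already created with .remote(), and the sample strings are made up):

    preds = ["The answer is \\boxed{42}."]
    golds = ["42"]
    scores = ray.get(worker.verify.remote(preds, golds))  # old shape: scores only
    scores, answers = ray.get(
        worker.verify.remote(preds, golds, return_extracted_answer=True)
    )  # new shape: scores plus per-sample extracted answers (None on failure)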
@@ -131,24 +163,36 @@ def verify(
                     break
             score = 1.0 if extracted_answer == ground_truth else 0.0
             results.append(score)
-        return results
+            extracted_answers.append(extracted_answer)
+
+        if return_extracted_answer:
+            return results, extracted_answers
+        else:
+            return results


 @ray.remote  # pragma: no cover
 class EnglishMultichoiceVerifyWorker:
     def verify(
-        self, pred_responses: list[str], ground_truths: list[str]
-    ) -> list[float]:
+        self,
+        pred_responses: list[str],
+        ground_truths: list[str],
+        return_extracted_answer: bool = False,
+    ) -> Union[list[float], tuple[list[float], list[str | None]]]:
         """Verify the correctness of the predicted responses against the ground truth.

         Args:
             pred_responses: list[str]. The predicted responses from the LLM.
             ground_truths: list[str]. The ground truth responses.

         Returns:
-            list[float]. The rewards for each predicted response.
+            Union[list[float], tuple[list[float], list[str | None]]].
+                If return_extracted_answer is False, returns only the scores.
+                If return_extracted_answer is True, returns (scores, extracted_answers).
         """
         results = []
+        extracted_answers: list[str | None] = []
+
         for response, ground_truth in zip(pred_responses, ground_truths):
             ground_truth = answer_parsing.normalize_response(ground_truth)
             response = answer_parsing.normalize_response(response)
@@ -160,11 +204,18 @@ def verify(
                 )
             score = 1.0 if extracted_answer == ground_truth else 0.0
             results.append(score)
-        return results
+            if return_extracted_answer:
+                extracted_answers.append(extracted_answer)
+
+        if return_extracted_answer:
+            return results, extracted_answers
+        else:
+            return results


 class MathEnvironmentMetadata(TypedDict):
     ground_truth: str
+    extracted_answer: str | None


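For reference, a metadata entry now carries the extracted answer next to the ground truth. The values below are invented for illustration; how the new field gets populated is outside this excerpt:

    meta: MathEnvironmentMetadata = {
        "ground_truth": "42",
        "extracted_answer": None,  # later holds the verifier's parsed answer, e.g. "42"
    }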
 @ray.remote(max_restarts=-1, max_task_retries=-1)  # pragma: no cover
@@ -198,12 +249,13 @@ def step(
         self,
         message_log_batch: list[LLMMessageLogType],
         metadata: list[MathEnvironmentMetadata],
+        return_extracted_answer: bool = False,
     ) -> EnvironmentReturn[MathEnvironmentMetadata]:
         """Runs a step in the math environment.

         Args:
             message_log: list[list[dict[str, str]]]. A batch of OpenAI-API-like message logs that represent interactions with the LLM.
-            metadata: list[MathEnvironmentMetadata]. The grader will use the 'ground_truth' key to evaluate correctness.
+            metadata: list[MathEnvironmentMetadata]. The grader will use the 'ground_truth' key to evaluate correctness. The extracted answer will be stored to calculate cons@k.

         Returns:
             EnvironmentReturn: A tuple containing:
@@ -231,18 +283,32 @@ def step(
         )
         chunked_ground_truths = chunk_list_to_workers(ground_truths, self.num_workers)

-        # # Process each chunk in parallel
+        # Process each chunk in parallel
         futures = [
-            self.workers[i].verify.remote(chunk, ground_truth_chunk)
+            self.workers[i].verify.remote(
+                chunk, ground_truth_chunk, return_extracted_answer
+            )
             for i, (chunk, ground_truth_chunk) in enumerate(
                 zip(chunked_assistant_response_batch, chunked_ground_truths)
             )
         ]

-        results = ray.get(futures)
+        worker_results = ray.get(futures)
+
+        # Flatten the results and extract both scores and answers
+        results = []
+        extracted_answers: list[str | None] | None = (
+            [] if return_extracted_answer else None
+        )
+
+        for worker_result in worker_results:
+            if return_extracted_answer:
+                worker_scores, worker_answers = worker_result
+                results.extend(worker_scores)
+                extracted_answers.extend(worker_answers)
+            else:
+                results.extend(worker_result)
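+        # Workers are gathered in submission order, so the flattened results stay
+        # aligned with the original batch (assuming chunk_list_to_workers preserves order).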

-        # flatten the results
-        results = [item for sublist in results for item in sublist]
         observations = [
             {
                 "role": "environment",
@@ -256,7 +322,6 @@ def step(
         # create a tensor of rewards and done flags
         rewards = torch.tensor(results).cpu()
         done = torch.ones_like(rewards).cpu()
-
         next_stop_strings = [None] * len(message_log_batch)

         return EnvironmentReturn(
@@ -265,6 +330,7 @@ def step(
             next_stop_strings=next_stop_strings,
             rewards=rewards,
             terminateds=done,
+            answers=extracted_answers,
         )

     def global_post_process_and_metrics(
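Because step() now surfaces per-sample extracted answers, downstream code can build consensus-style metrics such as the cons@k mentioned in the docstring. A minimal, illustrative sketch (the real metric is computed elsewhere; the majority-vote definition and the plain string comparison below are simplifying assumptions):

    from collections import Counter

    def cons_at_k(extracted: list[str | None], ground_truth: str) -> float:
        # Majority vote over the k extracted answers sampled for one prompt.
        votes = Counter(a for a in extracted if a is not None)
        if not votes:
            return 0.0
        consensus, _ = votes.most_common(1)[0]
        # String equality for brevity; the actual grading goes through math-verify.
        return 1.0 if consensus == ground_truth else 0.0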