
Commit 024d173

feat: implement cons@k evaluation (#640)
Signed-off-by: ruit <ruit@nvidia.com>
Signed-off-by: Rui Tian <ruit@cw-dfw-cs-001-login-02.cm.cluster>
1 parent 3395bd8 commit 024d173

File tree: 12 files changed, +399 -57 lines changed


docs/guides/eval.md

Lines changed: 1 addition & 1 deletion

@@ -81,7 +81,7 @@ When you complete the evaluation, you will receive a summary similar to the following
 model_name='Qwen2.5-Math-1.5B-Instruct' dataset_name='aime2024'
 max_new_tokens=2048 temperature=0.0 top_p=1.0 top_k=-1

-metric='pass@k' pass_k_value=1 num_tests_per_prompt=1
+metric=pass@1 num_tests_per_prompt=1

 score=0.1000 (3.0/30)
 ============================================================
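
Note: with num_tests_per_prompt=1, the pass@1 number in this summary is simply the average per-prompt correctness. A minimal sketch of that arithmetic (illustrative only; the variable names are not from the repository):

# Illustrative only: how a summary line like "score=0.1000 (3.0/30)" arises.
# per_prompt_scores is a hypothetical list of 0/1 correctness values,
# one entry per prompt, when num_tests_per_prompt=1.
per_prompt_scores = [1.0, 1.0, 1.0] + [0.0] * 27  # 3 correct out of 30 prompts

score = sum(per_prompt_scores) / len(per_prompt_scores)
print(f"score={score:.4f} ({sum(per_prompt_scores)}/{len(per_prompt_scores)})")
# -> score=0.1000 (3.0/30)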

examples/configs/evals/eval.yaml

Lines changed: 2 additions & 2 deletions

@@ -1,9 +1,9 @@
 # Evaluation Configuration
 eval:
-  metric: "pass@k"
+  metric: "pass@k" # pass@k and cons@k are supported
   num_tests_per_prompt: 1 # every prompt will be tested num_tests_per_prompt times and use the average score as the final score
   seed: 42
-  pass_k_value: 1
+  k_value: 1
   save_path: null # Path to save evaluation results and configuration of the evaluation. Set to null to disable saving. Example: "results/eval_output" or "/path/to/evaluation_results"

 generation:
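
Note: for a cons@k run, the renamed k_value key pairs with the metric name and a larger num_tests_per_prompt so that several answers are sampled per prompt and majority-voted. A hypothetical sketch of such a config (only the key names come from this commit; the concrete values are illustrative):

# Hypothetical cons@k configuration sketch; values are illustrative.
eval:
  metric: "cons@k"          # vote over the sampled answers per prompt
  num_tests_per_prompt: 8   # sample several responses per prompt
  seed: 42
  k_value: 8                # the k used by cons@k
  save_path: null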

nemo_rl/environments/code_environment.py

Lines changed: 7 additions & 0 deletions

@@ -206,6 +206,7 @@ def step(
         self,
         message_log_batch: List[LLMMessageLogType],
         metadata_batch: List[CodeEnvMetadata],
+        return_extracted_answer: bool = False,
     ) -> EnvironmentReturn:
         """Process a batch of code execution steps."""
         message_batch = [ml[-1]["content"] for ml in message_log_batch]
@@ -240,12 +241,18 @@ def step(

         next_stop_strings = [["</code>"]] * len(message_log_batch)

+        assert return_extracted_answer == False, (
+            "return_extracted_answer is not supported in CodeEnvironment. Please set it to False."
+        )
+        extracted_answers = None
+
         return EnvironmentReturn(
             observations=observations,
             metadata=new_metadata_batch,
             next_stop_strings=next_stop_strings,
             rewards=rewards_tensor,
             terminateds=terminated_tensor,
+            answers=extracted_answers,
         )

     def shutdown(self):
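
Note: only verifier-backed environments can hand back an extracted answer; execution-style environments like this one (and the retriever further down) guard against the flag instead and always report answers=None. A minimal sketch of that guard pattern, with illustrative names not taken from the repository:

# Illustrative sketch of the guard used by environments that cannot extract
# a final answer: reject return_extracted_answer=True and report answers=None.
def step_without_answer_support(rewards, terminateds, observations, metadata,
                                return_extracted_answer: bool = False):
    assert not return_extracted_answer, (
        "answer extraction is only meaningful for verifier-backed environments"
    )
    return {
        "observations": observations,
        "metadata": metadata,
        "rewards": rewards,
        "terminateds": terminateds,
        "answers": None,  # cons@k cannot be computed from this environment
    }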

nemo_rl/environments/games/sliding_puzzle.py

Lines changed: 9 additions & 2 deletions

@@ -272,6 +272,7 @@ def process_turn(
         bool,
         Optional[list[str]],
         Optional[SlidingPuzzleMetadata],
+        Optional[list[str]],
     ]:
         """Processes a single turn for the sliding puzzle task."""
         game_state = metadata["game_state"]
@@ -297,6 +298,7 @@ def process_turn(
                 is_terminated,
                 None,
                 next_metadata,
+                None,
             )

         # Get last assistant message and parse action
@@ -328,13 +330,15 @@ def process_turn(

         if is_terminated:
             next_metadata = None  # Clear metadata on termination
-
+        # answers save the extracted answer, only assigned in the verify function
+        next_answers = None
         return (
             {"role": "environment", "content": next_observation_content + "\n"},
             turn_reward,
             is_terminated,
             next_stop_strings,
             next_metadata,
+            next_answers,
         )


@@ -365,13 +369,15 @@ def step(
         terminateds = []
         all_stop_strings = []
         all_next_metadata = []
+        all_answers = []

-        for obs, rew, term, stops, meta in results:
+        for obs, rew, term, stops, meta, answ in results:
             observations.append(obs)
             rewards.append(rew)
             terminateds.append(term)
             all_stop_strings.append(stops)
             all_next_metadata.append(meta)
+            all_answers.append(answ)

         rewards_tensor = torch.tensor(rewards, dtype=torch.float32)
         terminated_tensor = torch.tensor(terminateds, dtype=torch.bool)
@@ -382,6 +388,7 @@ def step(
             next_stop_strings=all_stop_strings,
             rewards=rewards_tensor,
             terminateds=terminated_tensor,
+            answers=all_answers,
         )

     def shutdown(self):

nemo_rl/environments/interfaces.py

Lines changed: 2 additions & 0 deletions

@@ -38,13 +38,15 @@ class EnvironmentReturn(NamedTuple, Generic[MetadataT]):
             similar. This field lets you control this per turn.
         rewards: the rewards for this turn.
         terminateds: whether the episode ended this turn.
+        answers: the answers for this turn.
     """

     observations: list[dict[str, str]]
     metadata: list[MetadataT]
     next_stop_strings: list[list[str] | None] | list[None]
     rewards: Tensor
     terminateds: Tensor
+    answers: list[str | None] | None


 class EnvironmentInterface(abc.ABC, Generic[MetadataT]):
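
Note: taken together, the tuple now carries an optional per-sample answer list next to the rewards. A minimal construction sketch assuming only the fields shown above (the concrete values are made up):

# Minimal sketch of building the extended EnvironmentReturn; values are made up.
import torch

from nemo_rl.environments.interfaces import EnvironmentReturn

ret = EnvironmentReturn(
    observations=[{"role": "environment", "content": "Environment: correct"}],
    metadata=[{"ground_truth": "42", "extracted_answer": "42"}],
    next_stop_strings=[None],
    rewards=torch.tensor([1.0]),
    terminateds=torch.tensor([True]),
    answers=["42"],  # None for environments that do not extract answers
)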

nemo_rl/environments/math_environment.py

Lines changed: 96 additions & 30 deletions

@@ -15,10 +15,11 @@
 import io
 import logging
 import re
-from typing import Any, Optional, TypedDict
+from typing import Any, Optional, TypedDict, Union

 import ray
 import torch
+from math_verify import grader
 from math_verify.errors import TimeoutException
 from math_verify.metric import math_metric
 from math_verify.parser import ExprExtractionConfig, LatexExtractionConfig
@@ -69,53 +70,84 @@ def __init__(self) -> None:
         )

     def verify(
-        self, pred_responses: list[str], ground_truths: list[str]
-    ) -> list[float]:
+        self,
+        pred_responses: list[str],
+        ground_truths: list[str],
+        return_extracted_answer: bool = False,
+    ) -> Union[list[float], tuple[list[float], list[str | None]]]:
         """Verify the correctness of the predicted responses against the ground truth.

         Args:
             pred_responses: list[str]. The predicted responses from the LLM.
             ground_truths: list[str]. The ground truth responses.

         Returns:
-            list[float]. The rewards for each predicted response.
+            Union[list[float], tuple[list[float], list[str | None]]].
+            If return_extracted_answer is False, returns only the scores.
+            If return_extracted_answer is True, returns (scores, extracted_answers).
         """
         results = []
+        extracted_answers: list[str | None] = []
+
         for response, ground_truth in zip(pred_responses, ground_truths):
             try:
                 ground_truth_parsable = "\\boxed{" + ground_truth + "}"
                 with _mute_output():
-                    try:
-                        ret_score, _ = self.verify_func(
-                            [ground_truth_parsable], [response]
-                        )
-                    # It's possible to emit a TimeoutException and that wouldn't be caught since
-                    # it actually subclasses from BaseException and math-verify itself does not
-                    # to catch it.
-                    except (Exception, TimeoutException):
-                        ret_score = 0.0
+                    ret_score, extracted_answer = self.verify_func(
+                        [ground_truth_parsable], [response]
+                    )

                 results.append(float(ret_score))
-            except Exception:
+
+                if return_extracted_answer:
+                    # Make sure the extracted answer is not None and is a list of two elements
+                    assert extracted_answer is not None
+                    assert len(extracted_answer) == 2
+                    extracted_gold, extracted_prediction = extracted_answer
+                    # Get the extracted answer with the same logic as in the HFVerifyWorker
+                    for pred in extracted_prediction:
+                        if any(grader.verify(gold, pred) for gold in extracted_gold):
+                            extracted_answers.append(pred)
+                            break
+                    else:
+                        # If no match is found, means all answers are incorrect, just use the first prediction
+                        extracted_answers.append(extracted_prediction[0][0])
+
+            # It's possible to emit a TimeoutException and that wouldn't be caught since
+            # it actually subclasses from BaseException and math-verify itself does not
+            # to catch it.
+            except (Exception, TimeoutException):
                 results.append(0.0)
-        return results
+                extracted_answers.append(None)
+
+        if return_extracted_answer:
+            return results, extracted_answers
+        else:
+            return results


 @ray.remote  # pragma: no cover
 class MultilingualMultichoiceVerifyWorker:
     def verify(
-        self, pred_responses: list[str], ground_truths: list[str]
-    ) -> list[float]:
+        self,
+        pred_responses: list[str],
+        ground_truths: list[str],
+        return_extracted_answer: bool = False,
+    ) -> Union[list[float], tuple[list[float], list[str | None]]]:
         """Verify the correctness of the predicted responses against the ground truth.

         Args:
             pred_responses: list[str]. The predicted responses from the LLM.
             ground_truths: list[str]. The ground truth responses.

         Returns:
-            list[float]. The rewards for each predicted response.
+            Union[list[float], tuple[list[float], list[str | None]]].
+            If return_extracted_answer is False, returns only the scores.
+            If return_extracted_answer is True, returns (scores, extracted_answers).
         """
         results = []
+        extracted_answers: list[str | None] = []
+
         for response, ground_truth in zip(pred_responses, ground_truths):
             response = answer_parsing.normalize_response(response)
             extracted_answer = None
@@ -131,24 +163,36 @@ def verify(
                     break
             score = 1.0 if extracted_answer == ground_truth else 0.0
             results.append(score)
-        return results
+            extracted_answers.append(extracted_answer)
+
+        if return_extracted_answer:
+            return results, extracted_answers
+        else:
+            return results


 @ray.remote  # pragma: no cover
 class EnglishMultichoiceVerifyWorker:
     def verify(
-        self, pred_responses: list[str], ground_truths: list[str]
-    ) -> list[float]:
+        self,
+        pred_responses: list[str],
+        ground_truths: list[str],
+        return_extracted_answer: bool = False,
+    ) -> Union[list[float], tuple[list[float], list[str | None]]]:
         """Verify the correctness of the predicted responses against the ground truth.

         Args:
             pred_responses: list[str]. The predicted responses from the LLM.
             ground_truths: list[str]. The ground truth responses.

         Returns:
-            list[float]. The rewards for each predicted response.
+            Union[list[float], tuple[list[float], list[str | None]]].
+            If return_extracted_answer is False, returns only the scores.
+            If return_extracted_answer is True, returns (scores, extracted_answers).
         """
         results = []
+        extracted_answers: list[str | None] = []
+
         for response, ground_truth in zip(pred_responses, ground_truths):
             ground_truth = answer_parsing.normalize_response(ground_truth)
             response = answer_parsing.normalize_response(response)
@@ -160,11 +204,18 @@ def verify(
             )
             score = 1.0 if extracted_answer == ground_truth else 0.0
             results.append(score)
-        return results
+            if return_extracted_answer:
+                extracted_answers.append(extracted_answer)
+
+        if return_extracted_answer:
+            return results, extracted_answers
+        else:
+            return results


 class MathEnvironmentMetadata(TypedDict):
     ground_truth: str
+    extracted_answer: str | None


 @ray.remote(max_restarts=-1, max_task_retries=-1)  # pragma: no cover
@@ -198,12 +249,13 @@ def step(
         self,
         message_log_batch: list[LLMMessageLogType],
         metadata: list[MathEnvironmentMetadata],
+        return_extracted_answer: bool = False,
     ) -> EnvironmentReturn[MathEnvironmentMetadata]:
         """Runs a step in the math environment.

         Args:
             message_log: list[list[dict[str, str]]]. A batch of OpenAI-API-like message logs that represent interactions with the LLM.
-            metadata: list[MathEnvironmentMetadata]. The grader will use the 'ground_truth' key to evaluate correctness.
+            metadata: list[MathEnvironmentMetadata]. The grader will use the 'ground_truth' key to evaluate correctness. The extracted answer will be stored to calculate cons@k.

         Returns:
             EnvironmentReturn: A tuple containing:
@@ -231,18 +283,32 @@
         )
         chunked_ground_truths = chunk_list_to_workers(ground_truths, self.num_workers)

-        # # Process each chunk in parallel
+        # Process each chunk in parallel
         futures = [
-            self.workers[i].verify.remote(chunk, ground_truth_chunk)
+            self.workers[i].verify.remote(
+                chunk, ground_truth_chunk, return_extracted_answer
+            )
             for i, (chunk, ground_truth_chunk) in enumerate(
                 zip(chunked_assistant_response_batch, chunked_ground_truths)
             )
         ]

-        results = ray.get(futures)
+        worker_results = ray.get(futures)
+
+        # Flatten the results and extract both scores and answers
+        results = []
+        extracted_answers: list[str | None] | None = (
+            [] if return_extracted_answer else None
+        )
+
+        for worker_result in worker_results:
+            if return_extracted_answer:
+                worker_scores, worker_answers = worker_result
+                results.extend(worker_scores)
+                extracted_answers.extend(worker_answers)
+            else:
+                results.extend(worker_result)

-        # flatten the results
-        results = [item for sublist in results for item in sublist]
         observations = [
             {
                 "role": "environment",
@@ -256,7 +322,6 @@
         # create a tensor of rewards and done flags
         rewards = torch.tensor(results).cpu()
         done = torch.ones_like(rewards).cpu()
-
         next_stop_strings = [None] * len(message_log_batch)

         return EnvironmentReturn(
@@ -265,6 +330,7 @@
             next_stop_strings=next_stop_strings,
             rewards=rewards,
             terminateds=done,
+            answers=extracted_answers,
         )

     def global_post_process_and_metrics(
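
Note: these per-sample extracted answers are what cons@k votes over; the consensus computation itself lives in the evaluation code elsewhere in this commit and is not shown in this excerpt. A minimal majority-vote sketch of the idea, with illustrative names:

# Illustrative cons@k sketch (not the code added by this commit): take the
# k extracted answers sampled for one prompt, majority-vote on them, and
# score the prompt by whether the consensus answer is correct.
from collections import Counter


def cons_at_k(extracted_answers: list[str | None],
              answer_is_correct: list[float]) -> float:
    """extracted_answers[i] and answer_is_correct[i] describe the i-th sample."""
    votes = Counter(a for a in extracted_answers if a is not None)
    if not votes:
        return 0.0
    consensus, _ = votes.most_common(1)[0]
    # Score the consensus by the correctness of any sample that produced it.
    for ans, correct in zip(extracted_answers, answer_is_correct):
        if ans == consensus:
            return correct
    return 0.0


# Example: 4 samples per prompt, "42" extracted twice, "41" and "40" once each.
print(cons_at_k(["42", "41", "42", "40"], [1.0, 0.0, 1.0, 0.0]))  # -> 1.0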

nemo_rl/environments/tools/retriever.py

Lines changed: 7 additions & 0 deletions

@@ -162,6 +162,7 @@ def step(
         self,
         message_log_batch: List[LLMMessageLogType],
         metadata_batch: List[Dict[str, Any]],
+        return_extracted_answer: bool = False,
     ) -> EnvironmentReturn:
         """Process a batch of retrieval steps."""
         # Extract queries from the last message in each log
@@ -186,12 +187,18 @@ def step(
         terminated_tensor = torch.ones(batch_size, dtype=torch.bool)
         next_stop_strings = [["</retrieve>"]] * batch_size

+        assert return_extracted_answer == False, (
+            "return_extracted_answer is not supported in RAGEnvironment. Please set it to False."
+        )
+        extracted_answers = None
+
         return EnvironmentReturn(
             observations=results,
             metadata=metadata_batch,
             next_stop_strings=next_stop_strings,
             rewards=rewards_tensor,
             terminateds=terminated_tensor,
+            answers=extracted_answers,
         )

     def shutdown(self):
