Skip to content

Commit 6f114b9

Browse files
committed
Update eval_aime_benchmark.py
- Handle multiple response choices as multiple attempts.
1 parent 5f669ab commit 6f114b9

File tree

1 file changed

+44
-9
lines changed

1 file changed

+44
-9
lines changed

scripts/eval_aime_benchmark.py

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import logging
55
import re
66
import time
7-
from typing import List, Dict, Tuple, Optional
7+
from typing import List, Dict, Tuple, Optional, Union
88
from datetime import datetime
99
from openai import OpenAI
1010
from datasets import load_dataset
@@ -89,9 +89,17 @@ def extract_answer(response: str) -> Optional[int]:
8989

9090
return None
9191

92-
def get_llm_response(problem: str, model: str) -> str:
92+
def get_llm_response(problem: str, model: str) -> Union[str, List[Dict]]:
9393
"""
9494
Get response from the LLM for a given problem.
95+
If multiple choices are returned, formats them as attempt dictionaries.
96+
97+
Args:
98+
problem (str): The problem text
99+
model (str): The model identifier
100+
101+
Returns:
102+
Union[str, List[Dict]]: Either a string response or list of attempt dictionaries
95103
"""
96104
try:
97105
response = client.with_options(timeout=1000.0).chat.completions.create(
@@ -101,7 +109,23 @@ def get_llm_response(problem: str, model: str) -> str:
101109
],
102110
max_tokens=8192,
103111
)
112+
113+
# If there's more than one choice, format as attempts
114+
if len(response.choices) > 1:
115+
attempts = []
116+
for i, choice in enumerate(response.choices):
117+
response_text = choice.message.content.strip()
118+
predicted_answer = extract_answer(response_text)
119+
attempts.append({
120+
"attempt_number": i + 1,
121+
"response": response_text,
122+
"predicted_answer": predicted_answer
123+
})
124+
return attempts
125+
126+
# If single choice, return as before
104127
return response.choices[0].message.content.strip()
128+
105129
except Exception as e:
106130
logger.error(f"Error getting LLM response: {e}")
107131
return ""
@@ -119,14 +143,25 @@ def make_n_attempts(problem: str, model: str, n: int) -> List[Dict]:
119143
List[Dict]: List of dictionaries containing response and predicted answer for each attempt
120144
"""
121145
attempts = []
122-
for i in range(n):
146+
remaining_attempts = n
147+
148+
while remaining_attempts > 0:
123149
response = get_llm_response(problem, model)
124-
predicted_answer = extract_answer(response)
125-
attempts.append({
126-
"attempt_number": i + 1,
127-
"response": response,
128-
"predicted_answer": predicted_answer
129-
})
150+
151+
# If response is already formatted as attempts
152+
if isinstance(response, list):
153+
attempts.extend(response[:remaining_attempts]) # Only take what we need
154+
remaining_attempts -= len(response)
155+
else:
156+
# Process single response as before
157+
predicted_answer = extract_answer(response)
158+
attempts.append({
159+
"attempt_number": len(attempts) + 1,
160+
"response": response,
161+
"predicted_answer": predicted_answer
162+
})
163+
remaining_attempts -= 1
164+
130165
return attempts
131166

132167
def evaluate_pass_at_n(attempts: List[Dict], correct_answer: int) -> Tuple[bool, Optional[int]]:

0 commit comments

Comments (0)