
Commit 0e54243

Add eval script for AIME 2024
1 parent 97eb708 commit 0e54243

File tree

3 files changed: +249 -3 lines changed

  README.md
  optillm/plugins/executecode_plugin.py
  scripts/eval_aime_benchmark.py

README.md

Lines changed: 2 additions & 2 deletions
@@ -191,8 +191,8 @@ response = client.chat.completions.create(
 - e.g. for llama.cpp, run `python3 optillm.py --base_url http://localhost:8080/v1`

 > [!WARNING]
-> Note that llama-server (and ollama) currently does not support sampling multiple responses from a model, which limits the available approaches to the following:
-> `cot_reflection`, `leap`, `plansearch`, `rstar`, `rto`, `self_consistency`, `re2`, and `z3`. Use the built-in local inference server to use these approaches.
+> Note that the Anthropic API, llama-server (and ollama) currently do not support sampling multiple responses from a model, which limits the available approaches to the following:
+> `cot_reflection`, `leap`, `plansearch`, `rstar`, `rto`, `self_consistency`, `re2`, and `z3`. For models on HuggingFace, you can use the built-in local inference server, as it supports multiple responses.

 ## Implemented techniques
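To make the warning above concrete, here is a minimal sketch of a request that stays within the listed approaches. It assumes optillm is running as a proxy on localhost:8000 in front of llama-server (or ollama), and that the approach is selected by prefixing the model name; the endpoint, API key, and model name below are placeholders.

# Illustrative sketch only; endpoint, API key, and model name are placeholders.
from openai import OpenAI

client = OpenAI(api_key="none", base_url="http://localhost:8000/v1")  # optillm proxy
response = client.chat.completions.create(
    model="re2-my-local-model",  # assumed approach-prefix convention: re2 + underlying model
    messages=[{"role": "user", "content": "What is 6 * 7?"}],
)
print(response.choices[0].message.content)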

optillm/plugins/executecode_plugin.py

Lines changed: 9 additions & 1 deletion
@@ -8,6 +8,14 @@

 SLUG = "executecode"

+EXECUTE_CODE_PROMPT = '''Generate Python code to solve this problem. Put the code in a ```python block. The code:
+1. Should use standard Python libraries (math, itertools, etc.)
+2. Should print the final answer
+3. Should be complete and runnable
+4. Should include example test cases if relevant
+
+The code will be automatically executed when submitted.'''
+
 def extract_python_code(text: str) -> List[str]:
     """Extract Python code blocks from text."""
     # print(f"Extracting code: {text}")
@@ -78,7 +86,7 @@ def run(system_prompt: str, initial_query: str, client, model: str) -> Tuple[str
     else:
         # Get initial response from the model
         messages = [
-            {"role": "system", "content": system_prompt},
+            {"role": "system", "content": system_prompt + EXECUTE_CODE_PROMPT},
             {"role": "user", "content": initial_query}
         ]
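Only the signature of extract_python_code appears in this hunk. For context, a minimal regex-based sketch of such an extractor (illustrative only, not necessarily the plugin's actual implementation) could look like this:

import re
from typing import List

def extract_python_code(text: str) -> List[str]:
    # Return the contents of every ```python ... ``` fenced block in the text.
    return [block.strip() for block in re.findall(r"```python\s*(.*?)```", text, re.DOTALL)]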

scripts/eval_aime_benchmark.py

Lines changed: 238 additions & 0 deletions
@@ -0,0 +1,238 @@
import argparse
import json
import os
import logging
import re
from typing import List, Dict, Tuple, Optional
from datetime import datetime

from openai import OpenAI
from datasets import load_dataset
from tqdm import tqdm

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize OpenAI client
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"), base_url="http://localhost:8000/v1")

SYSTEM_PROMPT = '''You are solving AIME (American Invitational Mathematics Examination) problems.

Important: Always end your solution with the final answer in one of these two formats:

1. \\[
\\boxed{X}.
\\]

2. $n=\\boxed{X}$

where X is your integer answer between 0 and 999.'''

def load_2024_dataset() -> list[dict]:
    """
    Load the dataset of problems.
    Returns:
        list[dict]: The dataset of problems.
    """
    dataset_original = load_dataset("AI-MO/aimo-validation-aime")
    # Filter out problems that are not from 2024
    dataset = dataset_original["train"].filter(lambda example: "2024" in example["url"])
    logging.debug(f"Filtered dataset size: {len(dataset)}.")
    assert len(dataset) == 30, f"Expected 30 problems after filtering by 2024, but found {len(dataset)}"
    return dataset

def extract_answer(response: str) -> Optional[int]:
    """
    Extract the numerical answer from a math solution response.
    Handles various formats of boxed answers and falls back to last number if needed.

    Args:
        response (str): The complete response text from the model

    Returns:
        Optional[int]: The extracted answer as an integer, or None if no valid answer found
    """
    if not response:
        return None

    # Clean the response: normalize whitespace and handle potential Unicode
    response = ' '.join(response.split())

    # List of regex patterns to try, in order of preference
    patterns = [
        # $n=\boxed{X}$ format
        r'\$n=\\boxed{(\d+)}\$',

        # LaTeX display style answer: \[\boxed{X}\] or \[\boxed{X}.\]
        r'\\\[\\boxed{(\d+)}\\\]',
        r'\\\[\\boxed{(\d+)}\.\\\]',

        # Inline LaTeX \boxed{X}
        r'\\boxed{(\d+)}',

        # Common variations
        r'\$\\boxed{(\d+)}\$',
        r'boxed{(\d+)}',

        # Less strict patterns
        r'\\boxed\s*{\s*(\d+)\s*}',
        r'\bboxed\s*{\s*(\d+)\s*}',

        # Plain text answer indicators
        r'final answer is[^\d]*(\d+)',
        r'answer is[^\d]*(\d+)',
        r'answer:[^\d]*(\d+)',
        r'= ?(\d+)$'
    ]

    # Try each pattern in order
    for pattern in patterns:
        matches = re.finditer(pattern, response, re.IGNORECASE)
        # Get the last match for this pattern (in case there are multiple)
        last_match = None
        for match in matches:
            last_match = match

        if last_match:
            try:
                return int(last_match.group(1))
            except (ValueError, IndexError):
                continue

    # Fallback: Extract all numbers and take the last one
    # This is our last resort, assuming the answer typically comes last
    numbers = re.findall(r'(\d+)', response)
    if numbers:
        try:
            # Convert to int and return the last number found
            return int(numbers[-1])
        except ValueError:
            pass

    # If all methods fail, return None
    return None

def get_llm_response(problem: str, model: str) -> str:
    """
    Get response from the LLM for a given problem.
    """
    try:
        response = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": problem}
            ],
            max_tokens=8192,
            # extra_body={
            #     "decoding": "entropy_decoding",
            # }
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        logger.error(f"Error getting LLM response: {e}")
        return ""

def evaluate_response(predicted_answer: Optional[int], correct_answer: int) -> bool:
    """
    Evaluate if the predicted answer matches the correct answer.
    """
    if predicted_answer is None:
        return False
    return predicted_answer == correct_answer

def load_existing_results(filename: str) -> List[Dict]:
    """Load existing results from file if it exists."""
    try:
        with open(filename, 'r') as f:
            return json.load(f)
    except FileNotFoundError:
        return []

def save_result(filename: str, result: Dict):
    """Save a single result to the results file."""
    results = load_existing_results(filename)
    results.append(result)
    with open(filename, 'w') as f:
        json.dump(results, f, indent=2)

def get_last_processed_index(results: List[Dict]) -> int:
    """Get the index of the last processed problem."""
    if not results:
        return -1
    return max(int(r.get('index', -1)) for r in results)

def analyze_results(results: List[Dict]):
    """Analyze and print summary statistics of the results."""
    total = len(results)
    correct = sum(1 for r in results if r['is_correct'])
    accuracy = correct / total if total > 0 else 0

    print("\n=== Results Summary ===")
    print(f"Total problems: {total}")
    print(f"Correct answers: {correct}")
    print(f"Accuracy: {accuracy:.2%}")

    # Print incorrect problems for analysis
    print("\n=== Incorrect Answers ===")
    for r in results:
        if not r['is_correct']:
            print(f"Problem {r['index']}:")
            print(f"Expected: {r['correct_answer']}")
            print(f"Predicted: {r['predicted_answer']}")
            print("---")

def main(model: str):
    """Main evaluation function."""
    # Create results directory if it doesn't exist
    os.makedirs("results", exist_ok=True)

    # Setup results file (stored in the results directory created above)
    results_file = os.path.join("results", f"evaluation_results_{model.replace('/', '_')}.json")

    # Load dataset
    dataset = load_2024_dataset()

    # Load existing results
    existing_results = load_existing_results(results_file)
    last_processed_index = get_last_processed_index(existing_results)

    # Process problems
    for idx, item in enumerate(tqdm(dataset, desc="Evaluating problems")):
        if idx <= last_processed_index:
            continue

        problem_text = item['problem']
        correct_answer = int(item['answer'])

        # Get model response
        response = get_llm_response(problem_text, model)
        logger.debug(f"Response: {response}")
        predicted_answer = extract_answer(response)
        is_correct = evaluate_response(predicted_answer, correct_answer)

        # Save result
        result = {
            "index": idx,
            "problem": problem_text,
            "model_response": response,
            "predicted_answer": predicted_answer,
            "correct_answer": correct_answer,
            "is_correct": is_correct
        }
        save_result(results_file, result)

        # Optional: Add delay between requests if needed
        # time.sleep(5)

    # Analyze results
    final_results = load_existing_results(results_file)
    analyze_results(final_results)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Evaluate LLM performance on AIME 2024 problems")
    parser.add_argument("--model", type=str, required=True, help="OpenAI model to use (e.g., gpt-4, gpt-3.5-turbo)")
    args = parser.parse_args()

    main(args.model)
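The script is invoked as python scripts/eval_aime_benchmark.py --model <model> and expects an OpenAI-compatible endpoint on localhost:8000 (e.g. optillm's local inference server) plus OPENAI_API_KEY in the environment. As a hypothetical sanity check of extract_answer against the two answer formats the system prompt requests, plus the plain-text fallback:

# Illustrative only; importing the module constructs the OpenAI client at import time,
# so OPENAI_API_KEY must be set and scripts/ must be on the import path (e.g. run from scripts/).
from eval_aime_benchmark import extract_answer

assert extract_answer(r"so the count is $n=\boxed{73}$") == 73   # inline $n=\boxed{X}$ format
assert extract_answer(r"Therefore \[ \boxed{204}. \]") == 204    # display-style boxed answer
assert extract_answer("The final answer is 17") == 17            # plain-text fallback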
