
Commit 572a0c5

Merge pull request #98 from codelion/add-aime-eval-script
Add aime eval script
2 parents 97eb708 + 1b1b268 commit 572a0c5

File tree: 3 files changed (+251 / -3 lines)

- README.md
- optillm/plugins/executecode_plugin.py
- scripts/eval_aime_benchmark.py

README.md

Lines changed: 2 additions & 2 deletions
@@ -191,8 +191,8 @@ response = client.chat.completions.create(
 - e.g. for llama.cpp, run `python3 optillm.py --base_url http://localhost:8080/v1`
 
 > [!WARNING]
-> Note that llama-server (and ollama) currently does not support sampling multiple responses from a model, which limits the available approaches to the following:
-> `cot_reflection`, `leap`, `plansearch`, `rstar`, `rto`, `self_consistency`, `re2`, and `z3`. Use the built-in local inference server to use these approaches.
+> Note that the Anthropic API, llama-server (and ollama) currently does not support sampling multiple responses from a model, which limits the available approaches to the following:
+> `cot_reflection`, `leap`, `plansearch`, `rstar`, `rto`, `self_consistency`, `re2`, and `z3`. For models on HuggingFace, you can use the built-in local inference server as it supports multiple responses.
 
 ## Implemented techniques
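For orientation, the hunk's section context (`response = client.chat.completions.create(`) refers to the README's proxy-usage example. A minimal call through the optillm proxy with one of the approaches listed in the warning might look like the sketch below; the `leap-` model-name prefix, the example model name, and the localhost:8000 port are assumptions based on the surrounding README and the eval script added in this commit, not part of this diff.

import os
from openai import OpenAI

# Point the standard OpenAI client at the local optillm proxy (port is an assumption).
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"), base_url="http://localhost:8000/v1")

# Approach selection via a model-name prefix is an assumed convention here; `leap` is one
# of the approaches the warning lists as usable without multi-response sampling.
response = client.chat.completions.create(
    model="leap-gpt-4o-mini",
    messages=[{"role": "user", "content": "What is 17 * 23?"}],
)
print(response.choices[0].message.content)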

optillm/plugins/executecode_plugin.py

Lines changed: 9 additions & 1 deletion
@@ -8,6 +8,14 @@
 
 SLUG = "executecode"
 
+EXECUTE_CODE_PROMPT = '''Generate Python code to solve this problem. Put the code in a ```python block. The code:
+1. Should use standard Python libraries (math, itertools, etc.)
+2. Should print the final answer
+3. Should be complete and runnable
+4. Should include example test cases if relevant
+
+The code will be automatically executed when submitted.'''
+
 def extract_python_code(text: str) -> List[str]:
     """Extract Python code blocks from text."""
     # print(f"Extracting code: {text}")
@@ -78,7 +86,7 @@ def run(system_prompt: str, initial_query: str, client, model: str) -> Tuple[str
     else:
         # Get initial response from the model
         messages = [
-            {"role": "system", "content": system_prompt},
+            {"role": "system", "content": system_prompt + EXECUTE_CODE_PROMPT},
             {"role": "user", "content": initial_query}
         ]
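The context lines above show only the signature and docstring of `extract_python_code`; its body is unchanged by this commit. As a rough illustration of what an extractor with that signature typically does (a hypothetical sketch, not the plugin's actual implementation), it would pull every ```python fenced block out of the model's reply so the plugin can run it:

import re
from typing import List

def extract_python_code_sketch(text: str) -> List[str]:
    # Hypothetical sketch only; the plugin's real extract_python_code may differ in details.
    # Non-greedy capture of everything between a ```python opening fence and the next ``` fence.
    pattern = r"```python\s*(.*?)```"
    return [block.strip() for block in re.findall(pattern, text, flags=re.DOTALL)]

Appending EXECUTE_CODE_PROMPT to the system prompt (the second hunk) is what steers the model toward emitting such a block in the first place.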

scripts/eval_aime_benchmark.py

Lines changed: 240 additions & 0 deletions
@@ -0,0 +1,240 @@
+import argparse
+import json
+import os
+import logging
+import re
+import time
+
+from typing import List, Dict, Tuple, Optional
+from datetime import datetime
+
+from openai import OpenAI
+from datasets import load_dataset
+from tqdm import tqdm
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+# Initialize OpenAI client
+client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"), base_url="http://localhost:8000/v1")
+
+SYSTEM_PROMPT = '''You are solving AIME (American Invitational Mathematics Examination) problems.
+
+Important: Always end your solution with the final answer in one of these two formats:
+
+1. \\[
+\\boxed{X}.
+\\]
+
+2. $n=\\boxed{X}$
+
+where X is your integer answer between 0 and 999.'''
+
+def load_2024_dataset() -> list[dict]:
+    """
+    Load the dataset of problems.
+    Returns:
+        list[dict]: The dataset of problems.
+    """
+    dataset_original = load_dataset("AI-MO/aimo-validation-aime")
+    # Filter out problems that are not from 2024
+    dataset = dataset_original["train"].filter(lambda example: "2024" in example["url"])
+    logging.debug(f"Filtered dataset size: {len(dataset)}.")
+    assert len(dataset) == 30, f"Expected 30 problems after filtering by 2024, but found {len(dataset)}"
+    return dataset
+
+def extract_answer(response: str) -> Optional[int]:
+    """
+    Extract the numerical answer from a math solution response.
+    Handles various formats of boxed answers and falls back to last number if needed.
+
+    Args:
+        response (str): The complete response text from the model
+
+    Returns:
+        Optional[int]: The extracted answer as an integer, or None if no valid answer found
+    """
+    if not response:
+        return None
+
+    # Clean the response: normalize whitespace and handle potential Unicode
+    response = ' '.join(response.split())
+
+    # List of regex patterns to try, in order of preference
+    patterns = [
+        # $n=\boxed{X}$ format
+        r'\$n=\\boxed{(\d+)}\$',
+
+        # LaTeX display style answer: \[\boxed{X}\] or \[\boxed{X}.\]
+        r'\\\[\\boxed{(\d+)}\\\]',
+        r'\\\[\\boxed{(\d+)}\.\\\]',
+
+        # Inline LaTeX \boxed{X}
+        r'\\boxed{(\d+)}',
+
+        # Common variations
+        r'\$\\boxed{(\d+)}\$',
+        r'boxed{(\d+)}',
+
+        # Less strict patterns
+        r'\\boxed\s*{\s*(\d+)\s*}',
+        r'\bboxed\s*{\s*(\d+)\s*}',
+
+        # Plain text answer indicators
+        r'final answer is[^\d]*(\d+)',
+        r'answer is[^\d]*(\d+)',
+        r'answer:[^\d]*(\d+)',
+        r'= ?(\d+)$'
+    ]
+
+    # Try each pattern in order
+    for pattern in patterns:
+        matches = re.finditer(pattern, response, re.IGNORECASE)
+        # Get the last match for this pattern (in case there are multiple)
+        last_match = None
+        for match in matches:
+            last_match = match
+
+        if last_match:
+            try:
+                return int(last_match.group(1))
+            except (ValueError, IndexError):
+                continue
+
+    # Fallback: Extract all numbers and take the last one
+    # This is our last resort, assuming the answer typically comes last
+    numbers = re.findall(r'(\d+)', response)
+    if numbers:
+        try:
+            # Convert to int and return the last number found
+            return int(numbers[-1])
+        except ValueError:
+            pass
+
+    # If all methods fail, return None
+    return None
+
+def get_llm_response(problem: str, model: str) -> str:
+    """
+    Get response from the LLM for a given problem.
+    """
+    try:
+        response = client.chat.completions.create(
+            model=model,
+            messages=[
+                # {"role": "system", "content": SYSTEM_PROMPT},
+                {"role": "user", "content": SYSTEM_PROMPT + problem}
+            ],
+            max_tokens=8192,
+            # extra_body={
+            #     "decoding": "entropy_decoding",
+            # }
+        )
+        return response.choices[0].message.content.strip()
+    except Exception as e:
+        logger.error(f"Error getting LLM response: {e}")
+        return ""
+
+def evaluate_response(predicted_answer: Optional[int], correct_answer: int) -> bool:
+    """
+    Evaluate if the predicted answer matches the correct answer.
+    """
+    if predicted_answer is None:
+        return False
+    return predicted_answer == correct_answer
+
+def load_existing_results(filename: str) -> List[Dict]:
+    """Load existing results from file if it exists."""
+    try:
+        with open(filename, 'r') as f:
+            return json.load(f)
+    except FileNotFoundError:
+        return []
+
+def save_result(filename: str, result: Dict):
+    """Save a single result to the results file."""
+    results = load_existing_results(filename)
+    results.append(result)
+    with open(filename, 'w') as f:
+        json.dump(results, f, indent=2)
+
+def get_last_processed_index(results: List[Dict]) -> int:
+    """Get the index of the last processed problem."""
+    if not results:
+        return -1
+    return max(int(r.get('index', -1)) for r in results)
+
+def analyze_results(results: List[Dict]):
+    """Analyze and print summary statistics of the results."""
+    total = len(results)
+    correct = sum(1 for r in results if r['is_correct'])
+    accuracy = correct / total if total > 0 else 0
+
+    print("\n=== Results Summary ===")
+    print(f"Total problems: {total}")
+    print(f"Correct answers: {correct}")
+    print(f"Accuracy: {accuracy:.2%}")
+
+    # Print incorrect problems for analysis
+    print("\n=== Incorrect Answers ===")
+    for r in results:
+        if not r['is_correct']:
+            print(f"Problem {r['index']}:")
+            print(f"Expected: {r['correct_answer']}")
+            print(f"Predicted: {r['predicted_answer']}")
+            print("---")
+
+def main(model: str):
+    """Main evaluation function."""
+    # Create results directory if it doesn't exist
+    os.makedirs("results", exist_ok=True)
+
+    # Setup results file
+    results_file = f"evaluation_results_{model.replace('/', '_')}.json"
+
+    # Load dataset
+    dataset = load_2024_dataset()
+
+    # Load existing results
+    existing_results = load_existing_results(results_file)
+    last_processed_index = get_last_processed_index(existing_results)
+
+    # Process problems
+    for idx, item in enumerate(tqdm(dataset, desc="Evaluating problems")):
+        if idx <= last_processed_index:
+            continue
+
+        problem_text = item['problem']
+        correct_answer = int(item['answer'])
+
+        # Get model response
+        response = get_llm_response(problem_text, model)
+        logger.debug(f"Response: {response}")
+        predicted_answer = extract_answer(response)
+        is_correct = evaluate_response(predicted_answer, correct_answer)
+
+        # Save result
+        result = {
+            "index": idx,
+            "problem": problem_text,
+            "model_response": response,
+            "predicted_answer": predicted_answer,
+            "correct_answer": correct_answer,
+            "is_correct": is_correct
+        }
+        save_result(results_file, result)
+
+        # Optional: Add delay between requests if needed
+        time.sleep(300)
+
+    # Analyze results
+    final_results = load_existing_results(results_file)
+    analyze_results(final_results)
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Evaluate LLM performance on AIME 2024 problems")
+    parser.add_argument("--model", type=str, required=True, help="OpenAI model to use (e.g., gpt-4, gpt-3.5-turbo)")
+    args = parser.parse_args()
+
+    main(args.model)
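Assuming the optillm proxy is running on localhost:8000 and OPENAI_API_KEY is set, the script is run as `python scripts/eval_aime_benchmark.py --model <model-name>`; results accumulate in `evaluation_results_<model>.json`, and re-running resumes after the last processed index. Note the hard-coded `time.sleep(300)` between requests: with 30 problems, a full run spends roughly 2.5 hours in delays alone, so lower it if the target endpoint has no rate limits. A quick sanity check of `extract_answer` against the formats requested by SYSTEM_PROMPT (a hypothetical snippet, not part of the commit) could look like:

# Hypothetical check, not part of the commit. Run from the repository root.
# Importing the module requires the openai, datasets, and tqdm packages plus an
# OPENAI_API_KEY in the environment, because the OpenAI client is built at import time.
import sys
sys.path.append("scripts")

from eval_aime_benchmark import extract_answer

samples = {
    r"Therefore $n=\boxed{204}$": 204,   # inline $n=\boxed{X}$ format
    "\\[ \\boxed{73}. \\]": 73,          # display-style boxed answer
    "so the final answer is 123": 123,   # plain-text fallback pattern
}

for text, expected in samples.items():
    assert extract_answer(text) == expected, (text, expected)
print("extract_answer sanity check passed")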
