
Commit 76008bd

add n responses support
- Add script for RTC eval on arena hard auto
- Add ability to evaluate pass@n for AIME bench
- Return n samples from proxy when n is set
1 parent d5b468c commit 76008bd

File tree

5 files changed: +501 -96 lines changed

optillm.py

Lines changed: 61 additions & 14 deletions
@@ -306,6 +306,52 @@ async def run_approach(approach):
     responses, tokens = zip(*results)
     return list(responses), sum(tokens)
 
+def execute_n_times(n: int, approaches, operation: str, system_prompt: str, initial_query: str, client: Any, model: str) -> Tuple[Union[str, List[str]], int]:
+    """
+    Execute the pipeline n times and return n responses.
+
+    Args:
+        n (int): Number of times to run the pipeline
+        approaches (list): List of approaches to execute
+        operation (str): Operation type ('SINGLE', 'AND', or 'OR')
+        system_prompt (str): System prompt
+        initial_query (str): Initial query
+        client: OpenAI client instance
+        model (str): Model identifier
+
+    Returns:
+        Tuple[Union[str, List[str]], int]: List of responses and total token count
+    """
+    responses = []
+    total_tokens = 0
+
+    for _ in range(n):
+        if operation == 'SINGLE':
+            response, tokens = execute_single_approach(approaches[0], system_prompt, initial_query, client, model)
+        elif operation == 'AND':
+            response, tokens = execute_combined_approaches(approaches, system_prompt, initial_query, client, model)
+        elif operation == 'OR':
+            loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(loop)
+            response, tokens = loop.run_until_complete(execute_parallel_approaches(approaches, system_prompt, initial_query, client, model))
+            loop.close()
+        else:
+            raise ValueError(f"Unknown operation: {operation}")
+
+        # If response is already a list (from OR operation), extend responses
+        # Otherwise append the single response
+        if isinstance(response, list):
+            responses.extend(response)
+        else:
+            responses.append(response)
+        total_tokens += tokens
+
+    # If n=1 and we got a single response, return it as is
+    # Otherwise return the list of responses
+    if n == 1 and len(responses) == 1:
+        return responses[0], total_tokens
+    return responses, total_tokens
+
 def generate_streaming_response(final_response, model):
     # Yield the final response
     if isinstance(final_response, list):
@@ -393,11 +439,12 @@ def proxy():
     stream = data.get('stream', False)
     messages = data.get('messages', [])
     model = data.get('model', server_config['model'])
+    n = data.get('n', server_config['n'])  # Get n value from request or config
 
     optillm_approach = data.get('optillm_approach', server_config['approach'])
     logger.debug(data)
     server_config['mcts_depth'] = data.get('mcts_depth', server_config['mcts_depth'])
-    server_config['mcts_exploration' ] = data.get('mcts_exploration', server_config['mcts_exploration'])
+    server_config['mcts_exploration'] = data.get('mcts_exploration', server_config['mcts_exploration'])
     server_config['mcts_simulations'] = data.get('mcts_simulations', server_config['mcts_simulations'])
 
     system_prompt, initial_query, message_optillm_approach = parse_conversation(messages)
@@ -428,26 +475,26 @@ def proxy():
         contains_none = any(approach == 'none' for approach in approaches)
 
         if operation == 'SINGLE' and approaches[0] == 'none':
-            # For none approach, return the response directly
-            result, _ = execute_single_approach(approaches[0], system_prompt, initial_query, client, model)
+            # For none approach with n>1, make n separate calls
+            if n > 1:
+                responses = []
+                completion_tokens = 0
+                for _ in range(n):
+                    result, tokens = execute_single_approach(approaches[0], system_prompt, initial_query, client, model)
+                    responses.append(result)
+                    completion_tokens += tokens
+                result = responses
+            else:
+                result, completion_tokens = execute_single_approach(approaches[0], system_prompt, initial_query, client, model)
             logger.debug(f'Direct proxy response: {result}')
             return jsonify(result), 200
 
         elif operation == 'AND' or operation == 'OR':
             if contains_none:
                 raise ValueError("'none' approach cannot be combined with other approaches")
 
-            # Handle non-none approaches
-            if operation == 'SINGLE':
-                response, completion_tokens = execute_single_approach(approaches[0], system_prompt, initial_query, client, model)
-            elif operation == 'AND':
-                response, completion_tokens = execute_combined_approaches(approaches, system_prompt, initial_query, client, model)
-            elif operation == 'OR':
-                loop = asyncio.new_event_loop()
-                asyncio.set_event_loop(loop)
-                response, completion_tokens = loop.run_until_complete(execute_parallel_approaches(approaches, system_prompt, initial_query, client, model))
-            else:
-                raise ValueError(f"Unknown operation: {operation}")
+            # Handle non-none approaches with n attempts
+            response, completion_tokens = execute_n_times(n, approaches, operation, system_prompt, initial_query, client, model)
 
     except Exception as e:
         logger.error(f"Error processing request: {str(e)}")

scripts/eval_aime_benchmark.py

Lines changed: 75 additions & 66 deletions
@@ -4,10 +4,8 @@
 import logging
 import re
 import time
-
 from typing import List, Dict, Tuple, Optional
 from datetime import datetime
-
 from openai import OpenAI
 from datasets import load_dataset
 from tqdm import tqdm
@@ -17,7 +15,7 @@
 logger = logging.getLogger(__name__)
 
 # Initialize OpenAI client
-client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"), base_url="http://localhost:8000/v1")
+client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"), base_url="http://localhost:8888/v1")
 
 SYSTEM_PROMPT = '''You are solving AIME (American Invitational Mathematics Examination) problems.
 
@@ -48,50 +46,30 @@ def extract_answer(response: str) -> Optional[int]:
     """
     Extract the numerical answer from a math solution response.
     Handles various formats of boxed answers and falls back to last number if needed.
-
-    Args:
-        response (str): The complete response text from the model
-
-    Returns:
-        Optional[int]: The extracted answer as an integer, or None if no valid answer found
     """
     if not response:
         return None
 
-    # Clean the response: normalize whitespace and handle potential Unicode
+    # Clean the response
     response = ' '.join(response.split())
 
-    # List of regex patterns to try, in order of preference
     patterns = [
-        # $n=\boxed{X}$ format
         r'\$n=\\boxed{(\d+)}\$',
-
-        # LaTeX display style answer: \[\boxed{X}\] or \[\boxed{X}.\]
         r'\\\[\\boxed{(\d+)}\\\]',
         r'\\\[\\boxed{(\d+)}\.\\\]',
-
-        # Inline LaTeX \boxed{X}
         r'\\boxed{(\d+)}',
-
-        # Common variations
         r'\$\\boxed{(\d+)}\$',
         r'boxed{(\d+)}',
-
-        # Less strict patterns
         r'\\boxed\s*{\s*(\d+)\s*}',
         r'\bboxed\s*{\s*(\d+)\s*}',
-
-        # Plain text answer indicators
         r'final answer is[^\d]*(\d+)',
         r'answer is[^\d]*(\d+)',
         r'answer:[^\d]*(\d+)',
         r'= ?(\d+)$'
     ]
 
-    # Try each pattern in order
     for pattern in patterns:
         matches = re.finditer(pattern, response, re.IGNORECASE)
-        # Get the last match for this pattern (in case there are multiple)
         last_match = None
         for match in matches:
             last_match = match
@@ -102,47 +80,70 @@ def extract_answer(response: str) -> Optional[int]:
             except (ValueError, IndexError):
                 continue
 
-    # Fallback: Extract all numbers and take the last one
-    # This is our last resort, assuming the answer typically comes last
     numbers = re.findall(r'(\d+)', response)
     if numbers:
         try:
-            # Convert to int and return the last number found
             return int(numbers[-1])
         except ValueError:
             pass
 
-    # If all methods fail, return None
     return None
 
 def get_llm_response(problem: str, model: str) -> str:
     """
     Get response from the LLM for a given problem.
     """
     try:
-        response = client.chat.completions.create(
+        response = client.with_options(timeout=1000.0).chat.completions.create(
             model=model,
             messages=[
-                # {"role": "system", "content": SYSTEM_PROMPT},
                 {"role": "user", "content": SYSTEM_PROMPT + problem}
             ],
             max_tokens=8192,
-            # extra_body={
-            #     "decoding": "entropy_decoding",
-            # }
         )
         return response.choices[0].message.content.strip()
     except Exception as e:
         logger.error(f"Error getting LLM response: {e}")
         return ""
 
-def evaluate_response(predicted_answer: Optional[int], correct_answer: int) -> bool:
+def make_n_attempts(problem: str, model: str, n: int) -> List[Dict]:
+    """
+    Make n attempts to solve a problem and return all responses and predictions.
+
+    Args:
+        problem (str): The problem text
+        model (str): The model identifier
+        n (int): Number of attempts to make
+
+    Returns:
+        List[Dict]: List of dictionaries containing response and predicted answer for each attempt
+    """
+    attempts = []
+    for i in range(n):
+        response = get_llm_response(problem, model)
+        predicted_answer = extract_answer(response)
+        attempts.append({
+            "attempt_number": i + 1,
+            "response": response,
+            "predicted_answer": predicted_answer
+        })
+    return attempts
+
+def evaluate_pass_at_n(attempts: List[Dict], correct_answer: int) -> Tuple[bool, Optional[int]]:
     """
-    Evaluate if the predicted answer matches the correct answer.
+    Evaluate if any of the n attempts got the correct answer.
+
+    Args:
+        attempts (List[Dict]): List of attempt results
+        correct_answer (int): The correct answer
+
+    Returns:
+        Tuple[bool, Optional[int]]: (whether any attempt was correct, first correct attempt number)
     """
-    if predicted_answer is None:
-        return False
-    return predicted_answer == correct_answer
+    for attempt in attempts:
+        if attempt["predicted_answer"] == correct_answer:
+            return True, attempt["attempt_number"]
+    return False, None
 
 def load_existing_results(filename: str) -> List[Dict]:
     """Load existing results from file if it exists."""
@@ -165,76 +166,84 @@ def get_last_processed_index(results: List[Dict]) -> int:
         return -1
     return max(int(r.get('index', -1)) for r in results)
 
-def analyze_results(results: List[Dict]):
-    """Analyze and print summary statistics of the results."""
+def analyze_results(results: List[Dict], n: int):
+    """
+    Analyze and print summary statistics of the results.
+
+    Args:
+        results (List[Dict]): List of evaluation results
+        n (int): Number of attempts per problem
+    """
     total = len(results)
     correct = sum(1 for r in results if r['is_correct'])
     accuracy = correct / total if total > 0 else 0
 
     print("\n=== Results Summary ===")
+    print(f"Evaluation mode: pass@{n}")
     print(f"Total problems: {total}")
     print(f"Correct answers: {correct}")
     print(f"Accuracy: {accuracy:.2%}")
 
-    # Print incorrect problems for analysis
-    print("\n=== Incorrect Answers ===")
+    # Calculate attempt statistics
+    successful_attempts = [r['first_correct_attempt'] for r in results if r['is_correct']]
+    if successful_attempts:
+        avg_attempts = sum(successful_attempts) / len(successful_attempts)
+        print(f"\nFor correct solutions:")
+        print(f"Average attempts needed: {avg_attempts:.2f}")
+        print(f"Attempt distribution:")
+        for i in range(1, n + 1):
+            count = sum(1 for x in successful_attempts if x == i)
+            print(f"  Attempt {i}: {count} problems")
+
+    print("\n=== Incorrect Problems ===")
     for r in results:
         if not r['is_correct']:
             print(f"Problem {r['index']}:")
             print(f"Expected: {r['correct_answer']}")
-            print(f"Predicted: {r['predicted_answer']}")
+            print("Predicted answers across attempts:", [
+                attempt['predicted_answer'] for attempt in r['attempts']
+            ])
             print("---")
 
-def main(model: str):
+def main(model: str, n_attempts: int):
     """Main evaluation function."""
-    # Create results directory if it doesn't exist
     os.makedirs("results", exist_ok=True)
 
-    # Setup results file
-    results_file = f"evaluation_results_{model.replace('/', '_')}.json"
+    # Include n_attempts in filename to keep separate results for different n values
+    results_file = f"evaluation_results_{model.replace('/', '_')}_pass_at_{n_attempts}.json"
 
-    # Load dataset
     dataset = load_2024_dataset()
-
-    # Load existing results
     existing_results = load_existing_results(results_file)
     last_processed_index = get_last_processed_index(existing_results)
 
-    # Process problems
    for idx, item in enumerate(tqdm(dataset, desc="Evaluating problems")):
         if idx <= last_processed_index:
             continue
 
         problem_text = item['problem']
         correct_answer = int(item['answer'])
 
-        # Get model response
-        response = get_llm_response(problem_text, model)
-        logger.debug(f"Response: {response}")
-        predicted_answer = extract_answer(response)
-        is_correct = evaluate_response(predicted_answer, correct_answer)
+        # Make n attempts for each problem
+        attempts = make_n_attempts(problem_text, model, n_attempts)
+        is_correct, first_correct = evaluate_pass_at_n(attempts, correct_answer)
 
-        # Save result
         result = {
             "index": idx,
             "problem": problem_text,
-            "model_response": response,
-            "predicted_answer": predicted_answer,
+            "attempts": attempts,
             "correct_answer": correct_answer,
-            "is_correct": is_correct
+            "is_correct": is_correct,
+            "first_correct_attempt": first_correct
         }
         save_result(results_file, result)
-
-        # Optional: Add delay between requests if needed
-        time.sleep(300)
 
-    # Analyze results
     final_results = load_existing_results(results_file)
-    analyze_results(final_results)
+    analyze_results(final_results, n_attempts)
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Evaluate LLM performance on AIME 2024 problems")
     parser.add_argument("--model", type=str, required=True, help="OpenAI model to use (e.g., gpt-4, gpt-3.5-turbo)")
+    parser.add_argument("--n", type=int, default=1, help="Number of attempts per problem (for pass@n evaluation)")
     args = parser.parse_args()
 
-    main(args.model)
+    main(args.model, args.n)
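
With the new --n flag the script evaluates pass@n, for example: python scripts/eval_aime_benchmark.py --model gpt-4o-mini --n 4 (the model name is only a placeholder; the client points at the local proxy on port 8888). A minimal sketch of the pass@n bookkeeping, reusing the evaluate_pass_at_n logic from the diff above on made-up attempt data:

# Sketch: pass@n check on a toy attempts list; the attempt dicts mirror the
# shape produced by make_n_attempts(), and the numbers below are made up.
from typing import Dict, List, Optional, Tuple

def evaluate_pass_at_n(attempts: List[Dict], correct_answer: int) -> Tuple[bool, Optional[int]]:
    """Return (any attempt correct, number of the first correct attempt)."""
    for attempt in attempts:
        if attempt["predicted_answer"] == correct_answer:
            return True, attempt["attempt_number"]
    return False, None

attempts = [
    {"attempt_number": 1, "response": "...", "predicted_answer": 113},
    {"attempt_number": 2, "response": "...", "predicted_answer": 204},
    {"attempt_number": 3, "response": "...", "predicted_answer": 204},
]

print(evaluate_pass_at_n(attempts, correct_answer=204))  # -> (True, 2)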
