
Commit f89f89e

add eval script
1 parent 98fd9e4 commit f89f89e

3 files changed: 374 additions, 0 deletions

.gitignore

Lines changed: 2 additions & 0 deletions

@@ -167,3 +167,5 @@ cython_debug/
 
 # VS Code
 .vscode/
+
+scripts/results/

scripts/eval_optillmbench.py

Lines changed: 371 additions & 0 deletions

@@ -0,0 +1,371 @@
#!/usr/bin/env python3
import argparse
import time
import json
import os
from typing import Dict, List, Any, Tuple, Optional
import datasets
from datasets import load_dataset
from openai import OpenAI
import pandas as pd
from tqdm import tqdm
import logging
from datetime import datetime
import re

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Define the approaches to test
# Each approach is (name, description)
APPROACHES = [
    ("none", "Baseline without any optimization"),
    ("leap", "LEAP Approach"),
    ("rto", "Round Trip Optimization"),
    ("cot_reflection", "Chain of Thought with Reflection"),
    ("self_consistency", "Self Consistency Check"),
    ("plansearch", "Planning with Search"),
    ("re2", "ReRead Approach"),
    ("z3", "Z3 Solver for Mathematical Problems"),
    ("coc", "Chain of Code"),
    ("executecode", "Execute Code"),
]

def load_optillm_bench() -> datasets.Dataset:
    """Load the OptILLM Bench dataset."""
    try:
        dataset = load_dataset("codelion/optillmbench")
        return dataset["test"]  # We use the test split for evaluation
    except Exception as e:
        logger.error(f"Error loading dataset: {e}")
        raise

def extract_gsm8k_answer(text: str) -> Optional[float]:
    """Extract numerical answer after ### from GSM8K responses."""
    match = re.search(r'###\s*(-?\d*\.?\d+)', text)
    if match:
        try:
            return float(match.group(1))
        except ValueError:
            return None
    return None

def evaluate_response(response: str, ground_truth: str, category: str) -> bool:
    """
    Evaluate if the response matches the ground truth based on category.

    Args:
        response: Model's response
        ground_truth: Correct answer
        category: Problem category (gsm8k, mmlu_math, boolq, aqua_rat)

    Returns:
        bool: Whether the response is correct
    """
    if not response or not ground_truth:
        return False

    if category == "gsm8k":
        # Extract numerical answers after ### and compare
        response_num = extract_gsm8k_answer(response)
        ground_truth_num = extract_gsm8k_answer(ground_truth)

        if response_num is None or ground_truth_num is None:
            return False

        # Compare with small tolerance for floating point
        return abs(response_num - ground_truth_num) < 1e-6
    else:
        # For mmlu_math, boolq, and aqua_rat, exact match is required
        # Clean up both strings for comparison
        response_clean = response.strip().lower()
        ground_truth_clean = ground_truth.strip().lower()
        return response_clean == ground_truth_clean

def get_prompt_for_category(question: str, category: str) -> str:
    """
    Generate appropriate prompt based on category.
    """
    if category == "gsm8k":
        return (
            f"Solve this math problem step by step. After solving, provide the final "
            f"numerical answer after '### ' (three hash symbols and a space).\n\n"
            f"Question: {question}\n\n"
            f"Show your work, then give the final answer after '### '."
        )
    elif category == "mmlu_math":
        return (
            f"Solve this math problem. Provide only the answer with no explanation.\n\n"
            f"Question: {question}"
        )
    elif category == "boolq":
        return (
            f"Answer this yes/no question with only 'yes' or 'no'.\n\n"
            f"Question: {question}"
        )
    elif category == "aqua_rat":
        return (
            f"Choose the correct answer. Provide only the letter choice with no explanation.\n\n"
            f"Question: {question}"
        )
    else:
        return f"Question: {question}"

def evaluate_model(
    client: OpenAI,
    model: str,
    dataset: datasets.Dataset,
    approach: str,
    max_samples: Optional[int] = None
) -> Tuple[Dict[str, float], List[Dict[str, Any]]]:
    """
    Evaluate a model on the dataset using a specific approach.
    Returns metrics and detailed results.
    """
    metrics = {
        "total_correct": 0,
        "total_time": 0,
        "samples": 0,
    }

    # Initialize category-specific metrics
    category_metrics = {}

    # Detailed results for each example
    detailed_results = []

    # Prepare the dataset
    examples = dataset if max_samples is None else dataset.select(range(max_samples))

    # Create model name with approach
    full_model_name = f"{approach}-{model}" if approach != "none" else model

    for example in tqdm(examples, desc=f"Evaluating {approach}"):
        try:
            # Get appropriate prompt for the category
            prompt = get_prompt_for_category(example['question'], example['category'])

            # Record start time
            start_time = time.time()

            # Make API call
            response = client.chat.completions.create(
                model=full_model_name,
                messages=[
                    {"role": "system", "content": "You are a helpful AI assistant focused on providing precise answers in the requested format."},
                    {"role": "user", "content": prompt}
                ],
                temperature=0.2,
                max_tokens=4096
            )

            # Calculate time taken
            time_taken = time.time() - start_time

            # Get the response text
            response_text = response.choices[0].message.content

            # Evaluate the response
            is_correct = evaluate_response(
                response_text,
                example['answer'],
                example['category']
            )

            # Update metrics
            metrics["total_correct"] += int(is_correct)
            metrics["total_time"] += time_taken
            metrics["samples"] += 1

            # Update category metrics
            if example['category'] not in category_metrics:
                category_metrics[example['category']] = {
                    "correct": 0,
                    "total": 0,
                    "time": 0
                }
            category_metrics[example['category']]["correct"] += int(is_correct)
            category_metrics[example['category']]["total"] += 1
            category_metrics[example['category']]["time"] += time_taken

            # Record detailed result
            detailed_results.append({
                "id": example['id'],
                "category": example['category'],
                "correct": is_correct,
                "time_taken": time_taken,
                "response": response_text,
                "ground_truth": example['answer']
            })

        except Exception as e:
            logger.error(f"Error processing example {example['id']}: {e}")
            continue

    # Calculate final metrics
    final_metrics = {
        "accuracy": metrics["total_correct"] / metrics["samples"] if metrics["samples"] > 0 else 0,
        "average_time": metrics["total_time"] / metrics["samples"] if metrics["samples"] > 0 else 0,
        "total_time": metrics["total_time"],
        "total_samples": metrics["samples"],
    }

    # Add category-specific metrics
    for category, cat_metrics in category_metrics.items():
        final_metrics[f"{category}_accuracy"] = cat_metrics["correct"] / cat_metrics["total"]
        final_metrics[f"{category}_average_time"] = cat_metrics["time"] / cat_metrics["total"]

    return final_metrics, detailed_results

def save_results(metrics: Dict[str, float], detailed_results: List[Dict[str, Any]],
                 model: str, approach: str, output_dir: str):
    """Save evaluation results to files."""
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Create model-specific directory
    model_dir = os.path.join(output_dir, model.replace('/', '_'))
    os.makedirs(model_dir, exist_ok=True)

    base_filename = os.path.join(model_dir, f"{approach}_{timestamp}")

    # Save metrics
    with open(f"{base_filename}_metrics.json", "w") as f:
        json.dump(metrics, f, indent=2)

    # Save detailed results
    with open(f"{base_filename}_detailed.json", "w") as f:
        json.dump(detailed_results, f, indent=2)

    # Create a summary DataFrame for easier analysis
    df = pd.DataFrame(detailed_results)
    df.to_csv(f"{base_filename}_summary.csv", index=False)

    logger.info(f"Results saved to {base_filename}_*")

def generate_report(all_metrics: Dict[str, Dict[str, float]], output_dir: str):
    """Generate a comprehensive report comparing all approaches."""
    report = []

    # Header
    report.append("# OptILLM Bench Evaluation Report")
    report.append(f"Generated on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")

    # Overall Results Table
    report.append("## Overall Results")
    headers = ["Approach", "Accuracy", "Avg Time (s)", "Total Time (s)"]
    rows = []

    for approach, metrics in all_metrics.items():
        rows.append([
            approach,
            f"{metrics['accuracy']*100:.2f}%",
            f"{metrics['average_time']:.2f}",
            f"{metrics['total_time']:.2f}"
        ])

    # Convert to DataFrame for nice formatting
    df = pd.DataFrame(rows, columns=headers)
    report.append(df.to_markdown())

    # Category-wise Results
    report.append("\n## Results by Category")
    categories = ["gsm8k", "mmlu_math", "boolq", "aqua_rat"]

    for category in categories:
        report.append(f"\n### {category.upper()}")
        headers = ["Approach", "Accuracy", "Avg Time (s)"]
        rows = []

        for approach, metrics in all_metrics.items():
            if f"{category}_accuracy" in metrics:
                rows.append([
                    approach,
                    f"{metrics[f'{category}_accuracy']*100:.2f}%",
                    f"{metrics[f'{category}_average_time']:.2f}"
                ])

        df = pd.DataFrame(rows, columns=headers)
        report.append(df.to_markdown())

    # Save report
    report_path = f"{output_dir}/evaluation_report.md"
    with open(report_path, "w") as f:
        f.write("\n\n".join(report))

    logger.info(f"Report saved to {report_path}")

def main():
    parser = argparse.ArgumentParser(description="Evaluate a model on OptILLM Bench")
    parser.add_argument("--model", required=True, help="Model identifier")
    parser.add_argument("--base-url", default="http://localhost:8000/v1",
                        help="Base URL for API endpoint")
    parser.add_argument("--max-samples", type=int, help="Maximum number of samples to evaluate")
    parser.add_argument("--output-dir", default="results",
                        help="Directory to save results")
    parser.add_argument("--approaches", nargs="+",
                        help="Specific approaches to evaluate (default: all)")
    args = parser.parse_args()

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Get API key from environment
    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        raise ValueError("OPENAI_API_KEY environment variable must be set")

    # Initialize OpenAI client
    client = OpenAI(
        api_key=api_key,
        base_url=args.base_url
    )

    # Load dataset
    dataset = load_optillm_bench()

    # Determine which approaches to evaluate
    approaches_to_test = (
        [a[0] for a in APPROACHES if a[0] in args.approaches]
        if args.approaches
        else [a[0] for a in APPROACHES]
    )

    # Store all metrics for final report
    all_metrics = {}

    # Evaluate each approach
    for approach in approaches_to_test:
        logger.info(f"Evaluating approach: {approach}")

        try:
            metrics, detailed_results = evaluate_model(
                client,
                args.model,
                dataset,
                approach,
                args.max_samples
            )

            all_metrics[approach] = metrics

            # Save results for this approach
            save_results(metrics, detailed_results, args.model, approach,
                         args.output_dir)

            logger.info(f"Completed evaluation for {approach}")
            logger.info(f"Accuracy: {metrics['accuracy']*100:.2f}%")
            logger.info(f"Average time per sample: {metrics['average_time']:.2f}s")

        except Exception as e:
            logger.error(f"Error evaluating approach {approach}: {e}")
            continue

    # Generate final report
    generate_report(all_metrics, args.output_dir)

if __name__ == "__main__":
    main()
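For reference, a minimal self-contained sketch of the GSM8K scoring rule implemented above (the regex is the one used in `extract_gsm8k_answer`; the example strings and numbers are illustrative, not taken from the dataset):

import re

def extract_after_hashes(text: str):
    # Same pattern as extract_gsm8k_answer: grab the number following '###'
    match = re.search(r'###\s*(-?\d*\.?\d+)', text)
    return float(match.group(1)) if match else None

model_output = "3 packs of 6 eggs gives 3 * 6 = 18 eggs.\n### 18"
ground_truth = "3 * 6 = 18\n### 18"

pred = extract_after_hashes(model_output)   # 18.0
gold = extract_after_hashes(ground_truth)   # 18.0
print(abs(pred - gold) < 1e-6)              # True -> scored as correct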

scripts/requirements.txt

Lines changed: 1 addition & 0 deletions

@@ -1,3 +1,4 @@
+tabulate
 datasets
 accelerate
 huggingface_hub
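Note: `tabulate` is needed because `generate_report` formats its tables with `pandas.DataFrame.to_markdown()`, which depends on the tabulate package.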
