
Commit 6baa4bc

Add default test-time compute evaluation modes
Introduces a default test-time compute configuration with pass@1, maj@64, and genselect@64 approaches for standard evaluation. Updates evaluation logic to support multiple runs for pass@1, adjusts report generation to highlight these modes, and refactors main() to use the new defaults when no approaches are specified.
1 parent 791df94 commit 6baa4bc
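
Note on the pass@1 mode: it is not a separate optillm approach. The evaluation maps it to the plain model (approach "none"), repeats the same request num_runs=64 times per example at temperature 0.7, and records the fraction of successful runs. A minimal standalone sketch of that averaging (the actual logic lives inline in evaluate_model() below; run_once here is a hypothetical stand-in for one request-plus-scoring round trip):

    from typing import Callable
    import random

    def pass_at_1(run_once: Callable[[], bool], num_runs: int = 64) -> float:
        """Average success over num_runs independent attempts, as the new pass@1 mode does."""
        run_results = [bool(run_once()) for _ in range(num_runs)]
        # Mirrors the script's: sum(run_results) / len(run_results) if run_results else 0
        return sum(run_results) / len(run_results) if run_results else 0.0

    # Usage with a coin-flip stand-in for a single evaluation run:
    print(pass_at_1(lambda: random.random() < 0.7))  # prints roughly 0.7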

1 file changed: scripts/eval_optillmbench.py (+218, -69 lines)
@@ -67,6 +67,13 @@
     ("majority_voting_9", "Majority Voting with k=9", {"k": 9}),
 ]

+# Default test-time compute configuration for standard evaluation
+DEFAULT_TEST_TIME_COMPUTE = [
+    ("pass@1", "Baseline with 64 runs averaged", {"num_runs": 64}),
+    ("maj@64", "Majority Voting with k=64", {"approach": "majority_voting", "k": 64}),
+    ("genselect@64", "GenSelect with 64 candidates", {"approach": "genselect", "num_candidates": 64})
+]
+
 def load_optillm_bench() -> datasets.Dataset:
     """Load the OptiLLM Bench dataset."""
     try:
@@ -318,67 +325,96 @@ def evaluate_model(
     # Prepare the dataset
     examples = dataset if max_samples is None else dataset.select(range(max_samples))

+    # Check if we need to do multiple runs (for pass@1 calculation)
+    num_runs = approach_extra_body.get("num_runs", 1) if approach_extra_body else 1
+
+    # Handle special approach names
+    actual_approach = approach
+    if approach == "pass@1":
+        actual_approach = "none"
+    elif approach == "maj@64":
+        actual_approach = "majority_voting"
+    elif approach == "genselect@64":
+        actual_approach = "genselect"
+    elif approach_extra_body and "approach" in approach_extra_body:
+        actual_approach = approach_extra_body["approach"]
+
     # Create model name with approach - handle special cases
-    if approach == "none":
+    if actual_approach == "none":
         full_model_name = model
-    elif approach.startswith("thinkdeeper_"):
+    elif actual_approach.startswith("thinkdeeper_"):
         # For thinkdeeper, use base model name (decoding is passed in extra_body)
         full_model_name = model
-    elif approach.startswith("majority_voting_"):
+    elif actual_approach.startswith("majority_voting"):
         # For majority voting, use majority_voting prefix
         full_model_name = f"majority_voting-{model}"
     else:
         # Standard approach prefix
-        full_model_name = f"{approach}-{model}"
+        full_model_name = f"{actual_approach}-{model}"

     for example in tqdm(examples, desc=f"Evaluating {approach}"):
-        try:
-            # Get appropriate prompt for the category
-            prompt = get_prompt_for_category(example['question'], example['category'])
-
-            # Record start time
-            start_time = time.time()
-
-            # Prepare extra_body parameters
-            extra_body = {"spl_learning": False}
-            if approach_extra_body:
-                extra_body.update(approach_extra_body)
-
-            # Make API call
-            response = client.chat.completions.create(
-                model=full_model_name,
-                messages=[
-                    {"role": "system", "content": "You are a helpful AI assistant focused on providing precise answers in the requested format."},
-                    {"role": "user", "content": prompt}
-                ],
-                temperature=0.2,
-                max_tokens=4096,
-                extra_body=extra_body,
-            )
-
-            # Calculate time taken
-            time_taken = time.time() - start_time
+        # For pass@1, we need to run multiple times and calculate average
+        if num_runs > 1:
+            run_results = []
+            total_run_time = 0

-            # Get the response text
-            response_text = response.choices[0].message.content
+            for run_idx in range(num_runs):
+                try:
+                    # Get appropriate prompt for the category
+                    prompt = get_prompt_for_category(example['question'], example['category'])
+
+                    # Record start time
+                    start_time = time.time()
+
+                    # Prepare extra_body parameters (excluding num_runs)
+                    extra_body = {"spl_learning": False}
+                    if approach_extra_body:
+                        extra_body_clean = {k: v for k, v in approach_extra_body.items() if k != "num_runs"}
+                        extra_body.update(extra_body_clean)
+
+                    # Make API call
+                    response = client.chat.completions.create(
+                        model=full_model_name,
+                        messages=[
+                            {"role": "system", "content": "You are a helpful AI assistant focused on providing precise answers in the requested format."},
+                            {"role": "user", "content": prompt}
+                        ],
+                        temperature=0.7, # Higher temperature for pass@k diversity
+                        max_tokens=4096,
+                        extra_body=extra_body,
+                    )
+
+                    # Calculate time taken
+                    time_taken = time.time() - start_time
+                    total_run_time += time_taken
+
+                    # Get the response text
+                    response_text = response.choices[0].message.content
+
+                    # Process the response to remove thinking blocks
+                    processed_response = remove_thinking_blocks(response_text)
+
+                    # Evaluate the processed response
+                    is_correct = evaluate_response(
+                        processed_response,
+                        example['answer'],
+                        example['category'],
+                        example['question']
+                    )
+
+                    run_results.append(is_correct)
+
+                except Exception as e:
+                    logger.error(f"Error in run {run_idx+1} for example {example['id']}: {e}")
+                    run_results.append(False)

-            # Also store the raw response for reference
-            raw_response = response_text
+            # Calculate average success rate for this example
+            success_rate = sum(run_results) / len(run_results) if run_results else 0
+            avg_time = total_run_time / len(run_results) if run_results else 0

-            # Process the response to remove thinking blocks
-            processed_response = remove_thinking_blocks(response_text)
-
-            # Evaluate the processed response
-            is_correct = evaluate_response(
-                processed_response,
-                example['answer'],
-                example['category'],
-                example['question'] # Pass the question for MMLU evaluation
-            )
-
-            # Update metrics
-            metrics["total_correct"] += int(is_correct)
-            metrics["total_time"] += time_taken
+            # Update metrics with average
+            metrics["total_correct"] += success_rate
+            metrics["total_time"] += avg_time
             metrics["samples"] += 1

             # Update category metrics
@@ -388,28 +424,101 @@ def evaluate_model(
                     "total": 0,
                     "time": 0
                 }
-            category_metrics[example['category']]["correct"] += int(is_correct)
+            category_metrics[example['category']]["correct"] += success_rate
             category_metrics[example['category']]["total"] += 1
-            category_metrics[example['category']]["time"] += time_taken
-
-            # Check if thinking blocks were removed
-            has_thinking = '</think>' in raw_response
+            category_metrics[example['category']]["time"] += avg_time

             # Record detailed result
             detailed_results.append({
                 "id": example['id'],
                 "category": example['category'],
-                "correct": is_correct,
-                "time_taken": time_taken,
-                "raw_response": raw_response,
-                "processed_response": processed_response if has_thinking else None,
-                "has_thinking": has_thinking,
+                "correct": success_rate, # Store success rate instead of boolean
+                "num_runs": num_runs,
+                "successes": sum(run_results),
+                "time_taken": avg_time,
                 "ground_truth": example['answer']
             })

-        except Exception as e:
-            logger.error(f"Error processing example {example['id']}: {e}")
-            continue
+        else:
+            # Single run (original logic)
+            try:
+                # Get appropriate prompt for the category
+                prompt = get_prompt_for_category(example['question'], example['category'])
+
+                # Record start time
+                start_time = time.time()
+
+                # Prepare extra_body parameters
+                extra_body = {"spl_learning": False}
+                if approach_extra_body:
+                    extra_body.update(approach_extra_body)
+
+                # Make API call
+                response = client.chat.completions.create(
+                    model=full_model_name,
+                    messages=[
+                        {"role": "system", "content": "You are a helpful AI assistant focused on providing precise answers in the requested format."},
+                        {"role": "user", "content": prompt}
+                    ],
+                    temperature=0.2,
+                    max_tokens=4096,
+                    extra_body=extra_body,
+                )
+
+                # Calculate time taken
+                time_taken = time.time() - start_time
+
+                # Get the response text
+                response_text = response.choices[0].message.content
+
+                # Also store the raw response for reference
+                raw_response = response_text
+
+                # Process the response to remove thinking blocks
+                processed_response = remove_thinking_blocks(response_text)
+
+                # Evaluate the processed response
+                is_correct = evaluate_response(
+                    processed_response,
+                    example['answer'],
+                    example['category'],
+                    example['question'] # Pass the question for MMLU evaluation
+                )
+
+                # Update metrics
+                metrics["total_correct"] += int(is_correct)
+                metrics["total_time"] += time_taken
+                metrics["samples"] += 1
+
+                # Update category metrics
+                if example['category'] not in category_metrics:
+                    category_metrics[example['category']] = {
+                        "correct": 0,
+                        "total": 0,
+                        "time": 0
+                    }
+                category_metrics[example['category']]["correct"] += int(is_correct)
+                category_metrics[example['category']]["total"] += 1
+                category_metrics[example['category']]["time"] += time_taken
+
+                # Check if thinking blocks were removed
+                has_thinking = '</think>' in raw_response
+
+                # Record detailed result
+                detailed_results.append({
+                    "id": example['id'],
+                    "category": example['category'],
+                    "correct": is_correct,
+                    "time_taken": time_taken,
+                    "raw_response": raw_response,
+                    "processed_response": processed_response if has_thinking else None,
+                    "has_thinking": has_thinking,
+                    "ground_truth": example['answer']
+                })
+
+            except Exception as e:
+                logger.error(f"Error processing example {example['id']}: {e}")
+                continue

     # Calculate final metrics
     final_metrics = {
@@ -458,12 +567,27 @@ def generate_report(all_metrics: Dict[str, Dict[str, float]], output_dir: str, i
     """Generate a comprehensive report comparing all approaches."""
     report = []

+    # Check if this is the default test-time compute evaluation
+    is_default_test_time = set(all_metrics.keys()) == {"pass@1", "maj@64", "genselect@64"}
+
     # Header
-    report_title = "OptiLLM Bench Test-Time Compute Evaluation Report" if is_test_time_compute else "OptiLLM Bench Evaluation Report"
+    if is_default_test_time:
+        report_title = "OptiLLM Bench Test-Time Compute Evaluation Report"
+    elif is_test_time_compute:
+        report_title = "OptiLLM Bench Test-Time Compute Scaling Report"
+    else:
+        report_title = "OptiLLM Bench Evaluation Report"
+
     report.append(f"# {report_title}")
     report.append(f"Generated on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")

-    if is_test_time_compute:
+    if is_default_test_time:
+        report.append("## Test-Time Compute Evaluation Results\n")
+        report.append("This report evaluates the potential of test-time compute with:")
+        report.append("- **pass@1**: Baseline averaged over 64 runs (measures consistency)")
+        report.append("- **maj@64**: Majority voting with 64 candidates")
+        report.append("- **genselect@64**: Generative selection with 64 candidates\n")
+    elif is_test_time_compute:
         report.append("This report evaluates test-time compute scaling approaches:")
         report.append("- **Sequential scaling**: ThinkDeeper with varying thinking token budgets")
         report.append("- **Parallel scaling**: Majority voting with varying k values\n")
@@ -505,6 +629,28 @@ def generate_report(all_metrics: Dict[str, Dict[str, float]], output_dir: str, i
     df = pd.DataFrame(rows, columns=headers)
     report.append(df.to_markdown())

+    # Add summary section for default test-time compute
+    if is_default_test_time:
+        report.append("\n## Summary")
+        if "pass@1" in all_metrics and "maj@64" in all_metrics and "genselect@64" in all_metrics:
+            pass1_acc = all_metrics["pass@1"]["accuracy"] * 100
+            maj64_acc = all_metrics["maj@64"]["accuracy"] * 100
+            genselect64_acc = all_metrics["genselect@64"]["accuracy"] * 100
+
+            report.append(f"\n**Key Metrics:**")
+            report.append(f"- **pass@1** (baseline averaged over 64 runs): {pass1_acc:.2f}%")
+            report.append(f"- **maj@64** (majority voting with 64 candidates): {maj64_acc:.2f}%")
+            report.append(f"- **genselect@64** (quality-based selection from 64 candidates): {genselect64_acc:.2f}%")
+
+            # Calculate improvements
+            if pass1_acc > 0:
+                maj_improvement = ((maj64_acc - pass1_acc) / pass1_acc) * 100
+                genselect_improvement = ((genselect64_acc - pass1_acc) / pass1_acc) * 100
+
+                report.append(f"\n**Improvements over pass@1:**")
+                report.append(f"- maj@64: {'+' if maj_improvement > 0 else ''}{maj_improvement:.1f}%")
+                report.append(f"- genselect@64: {'+' if genselect_improvement > 0 else ''}{genselect_improvement:.1f}%")
+
     # Save report
     report_path = f"{output_dir}/evaluation_report.md"
     with open(report_path, "w") as f:
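
The improvement figures added to the Summary section above are relative gains over the pass@1 baseline, i.e. (candidate - baseline) / baseline. With hypothetical accuracies of 40% for pass@1 and 52% for maj@64, the same arithmetic gives:

    # Hypothetical numbers, only to illustrate the formula used in the summary above
    pass1_acc, maj64_acc = 40.0, 52.0
    maj_improvement = ((maj64_acc - pass1_acc) / pass1_acc) * 100
    print(f"maj@64: {'+' if maj_improvement > 0 else ''}{maj_improvement:.1f}%")  # maj@64: +30.0%
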
@@ -555,12 +701,13 @@ def main():
         if args.approaches:
             # Filter test-time compute approaches if specific ones are requested
             approaches_config = [a for a in TEST_TIME_COMPUTE_APPROACHES if a[0] in args.approaches]
+    elif args.approaches:
+        # Specific approaches requested
+        approaches_config = [a for a in APPROACHES if a[0] in args.approaches]
     else:
-        # Use standard approaches
-        if args.approaches:
-            approaches_config = [a for a in APPROACHES if a[0] in args.approaches]
-        else:
-            approaches_config = APPROACHES
+        # Default: Use the default test-time compute configuration
+        approaches_config = DEFAULT_TEST_TIME_COMPUTE
+        logger.info("Using default test-time compute evaluation (pass@1, maj@64, genselect@64)")

     # Store all metrics for final report
     all_metrics = {}
@@ -596,7 +743,9 @@ def main():
             continue

     # Generate final report
-    generate_report(all_metrics, args.output_dir, args.test_time_compute)
+    # Determine if we're using default test-time compute or explicit test-time compute
+    is_test_time = args.test_time_compute or (not args.approaches and approaches_config == DEFAULT_TEST_TIME_COMPUTE)
+    generate_report(all_metrics, args.output_dir, is_test_time)

 if __name__ == "__main__":
     main()
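
Net effect of the main() and generate_report() changes: a run that requests no specific approaches and does not enable the explicit test-time compute mode now evaluates pass@1, maj@64, and genselect@64 and flags the resulting report as a test-time compute report. A tiny standalone check of the new is_test_time expression, with plain variables standing in for the corresponding args fields (the argparse setup is outside this diff):

    # Stand-ins for args.test_time_compute / args.approaches on a default run
    test_time_compute, approaches = False, None

    # Simplified to approach names only; the real entries are (name, description, extra_body) tuples
    DEFAULT_TEST_TIME_COMPUTE = ["pass@1", "maj@64", "genselect@64"]
    approaches_config = DEFAULT_TEST_TIME_COMPUTE  # the new else-branch in main()

    is_test_time = test_time_compute or (not approaches and approaches_config == DEFAULT_TEST_TIME_COMPUTE)
    print(is_test_time)  # True -> generate_report() uses the test-time compute title and summary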
