Skip to content

Commit 0a87351

Browse files
authored
Merge pull request #86 from statxc/feat/benchmark-harness
feat: add end-to-end benchmark harness for running generation and eva…
2 parents b88c28d + 9c15950 commit 0a87351

File tree

4 files changed

+917
-0
lines changed

4 files changed

+917
-0
lines changed

configs/eval/benchmark.yaml

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Benchmark harness defaults
2+
# Override via CLI flags or custom config file
3+
4+
vlm:
5+
provider: gemini
6+
model: gemini-2.0-flash
7+
8+
image:
9+
provider: google_imagen
10+
model: gemini-3-pro-image-preview
11+
12+
pipeline:
13+
refinement_iterations: 3
14+
optimize_inputs: false
15+
auto_refine: false
16+
17+
output:
18+
dir: outputs
19+
format: png
20+
save_iterations: true
21+
save_prompts: false

paperbanana/cli.py

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1124,6 +1124,161 @@ async def _run():
11241124
)
11251125

11261126

1127+
@app.command()
def benchmark(
    config: Optional[str] = typer.Option(None, "--config", help="Path to config YAML file"),
    output_dir: Optional[str] = typer.Option(
        None, "--output-dir", "-o", help="Output directory for benchmark run"
    ),
    vlm_provider: Optional[str] = typer.Option(None, "--vlm-provider", help="VLM provider"),
    vlm_model: Optional[str] = typer.Option(None, "--vlm-model", help="VLM model name"),
    image_provider: Optional[str] = typer.Option(
        None, "--image-provider", help="Image gen provider"
    ),
    image_model: Optional[str] = typer.Option(None, "--image-model", help="Image gen model name"),
    iterations: Optional[int] = typer.Option(
        None, "--iterations", "-n", help="Refinement iterations per entry"
    ),
    auto: bool = typer.Option(False, "--auto", help="Loop until critic satisfied per entry"),
    optimize: bool = typer.Option(False, "--optimize", help="Preprocess inputs per entry"),
    category: Optional[str] = typer.Option(
        None, "--category", help="Only run entries in this category"
    ),
    ids: Optional[str] = typer.Option(
        None, "--ids", help="Comma-separated entry IDs to run (e.g., 2601.03570v1,2601.05110v1)"
    ),
    limit: Optional[int] = typer.Option(None, "--limit", help="Max number of entries to process"),
    eval_only: Optional[str] = typer.Option(
        None,
        "--eval-only",
        help="Skip generation; evaluate existing images from this directory",
    ),
    image_format: str = typer.Option(
        "png", "--format", "-f", help="Output image format (png, jpeg, webp)"
    ),
    seed: Optional[int] = typer.Option(None, "--seed", help="Random seed for reproducibility"),
    verbose: bool = typer.Option(False, "--verbose", "-v", help="Show detailed progress"),
):
    """Run generation + evaluation across PaperBananaBench entries."""
    # Guard clause: bail out before any setup work if the format is unsupported.
    if image_format not in {"png", "jpeg", "webp"}:
        console.print(f"[red]Error: Format must be png, jpeg, or webp. Got: {image_format}[/red]")
        raise typer.Exit(1)

    configure_logging(verbose=verbose)

    from dotenv import load_dotenv

    load_dotenv()

    # Assemble config overrides from CLI flags; only user-supplied values are kept,
    # so YAML/env defaults win wherever a flag was omitted.
    cli_overrides: dict = {"output_format": image_format}
    for key, value in (
        ("vlm_provider", vlm_provider),
        ("vlm_model", vlm_model),
        ("image_provider", image_provider),
        ("image_model", image_model),
        ("output_dir", output_dir),
    ):
        if value:
            cli_overrides[key] = value
    # Integer options use explicit None checks so 0 is a valid value.
    if iterations is not None:
        cli_overrides["refinement_iterations"] = iterations
    if seed is not None:
        cli_overrides["seed"] = seed
    if auto:
        cli_overrides["auto_refine"] = True
    if optimize:
        cli_overrides["optimize_inputs"] = True

    # CLI flags take precedence over the YAML config when both are provided.
    settings = Settings.from_yaml(config, **cli_overrides) if config else Settings(**cli_overrides)

    from paperbanana.evaluation.benchmark import BenchmarkRunner

    runner = BenchmarkRunner(settings)

    # Resolve entry filters and load the benchmark dataset.
    selected_ids = [part.strip() for part in ids.split(",") if part.strip()] if ids else None
    try:
        entries = runner.load_entries(category=category, ids=selected_ids, limit=limit)
    except ValueError as e:
        console.print(f"[red]Error: {e}[/red]")
        raise typer.Exit(1)

    if not entries:
        console.print("[red]Error: No entries match the given filters.[/red]")
        raise typer.Exit(1)

    run_mode = "eval-only" if eval_only else "generate + evaluate"
    console.print(
        Panel.fit(
            f"[bold]PaperBanana[/bold] — Benchmark\n\n"
            f"Entries: {len(entries)}\n"
            f"Mode: {run_mode}\n"
            f"VLM: {settings.vlm_provider} / {settings.effective_vlm_model}\n"
            f"Image: {settings.image_provider} / {settings.effective_image_model}",
            border_style="magenta",
        )
    )
    console.print()

    run_output_dir = Path(output_dir) if output_dir else None

    async def _execute():
        return await runner.run(entries, output_dir=run_output_dir, eval_only_dir=eval_only)

    report = asyncio.run(_execute())
    agg = report.summary

    if not agg:
        console.print("[yellow]No entries were successfully evaluated.[/yellow]")
        return

    # Aggregate results panel.
    console.print(
        Panel.fit(
            "[bold]Benchmark Summary[/bold]\n\n"
            f"Evaluated: {agg.get('evaluated', 0)}\n"
            f"Model wins: {agg.get('model_wins', 0)} "
            f"Human wins: {agg.get('human_wins', 0)} "
            f"Ties: {agg.get('ties', 0)}\n"
            f"Model win rate: {agg.get('model_win_rate', 0)}%\n"
            f"Mean overall score: {agg.get('mean_overall_score', 0)}/100\n"
            f"Mean generation time: {agg.get('mean_generation_seconds', 0)}s\n\n"
            f"Completed: {report.completed} "
            f"Failed: {report.failed} "
            f"Total: {report.total_seconds}s",
            border_style="cyan",
        )
    )

    # Per-dimension breakdown of mean scores.
    per_dim = agg.get("dimension_means", {})
    if per_dim:
        console.print("\n[bold]Per-dimension scores:[/bold]")
        for dim, score in per_dim.items():
            console.print(f" {dim.capitalize():14s} {score}/100")

    # Per-category breakdown of counts, win rates, and mean scores.
    per_category = agg.get("category_breakdown", {})
    if per_category:
        console.print("\n[bold]Per-category breakdown:[/bold]")
        for cat, cat_stats in per_category.items():
            console.print(
                f" {cat:30s} n={cat_stats['count']:3d} "
                f"win_rate={cat_stats['model_win_rate']:5.1f}% "
                f"mean={cat_stats['mean_score']:.1f}"
            )

    # Point the user at the persisted JSON report; fall back to a path derived
    # from the run timestamp when the runner did not record a run directory.
    report_dir = (
        Path(report.run_dir)
        if report.run_dir
        else Path(settings.output_dir) / report.created_at.replace(":", "")
    )
    console.print(f"\nReport: [bold]{report_dir / 'benchmark_report.json'}[/bold]")
1280+
1281+
11271282
# ── Data subcommands ──────────────────────────────────────────────
11281283

11291284

0 commit comments

Comments
 (0)