Skip to content

Commit 90fb009

Browse files
authored
Merge pull request #80 from claytonlin1110/feat/batch-generation-manifest
feat: add batch generation from YAML/JSON manifest
2 parents 22deffc + 6319ac7 commit 90fb009

File tree

4 files changed

+309
-0
lines changed

4 files changed

+309
-0
lines changed

README.md

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ An agentic framework for generating publication-quality academic diagrams and st
3939
- Input optimization layer for better generation quality
4040
- Auto-refine mode and run continuation with user feedback
4141
- CLI, Python API, and MCP server for IDE integration
42+
- **Batch generation** from a manifest file (YAML/JSON) for multiple diagrams in one run
4243
- Claude Code skills for `/generate-diagram`, `/generate-plot`, and `/evaluate-diagram`
4344

4445
<p align="center">
@@ -205,6 +206,39 @@ paperbanana plot \
205206
| `--output` | `-o` | Output image path |
206207
| `--iterations` | `-n` | Refinement iterations (default: 3) |
207208

209+
### `paperbanana batch` -- Batch Generation
210+
211+
Generate multiple methodology diagrams from a single manifest file (YAML or JSON). Each item runs the full pipeline; outputs are written under `outputs/batch_<id>/run_<id>/` and a `batch_report.json` summarizes all runs.
212+
213+
```bash
214+
paperbanana batch --manifest examples/batch_manifest.yaml --optimize
215+
```
216+
217+
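Manifest format (YAML or JSON): either an object with an `items` list or a top-level list of items. Each item requires `input` and `caption`; `id` is optional and defaults to `item_<n>`: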
218+
219+
```yaml
220+
items:
221+
- input: path/to/method1.txt
222+
caption: "Overview of our encoder-decoder"
223+
id: fig1
224+
- input: method2.txt
225+
caption: "Training pipeline"
226+
id: fig2
227+
```
228+
229+
Relative paths in the manifest are resolved against the manifest file's directory; absolute paths are used as-is.
230+
231+
| Flag | Short | Description |
232+
|------|-------|-------------|
233+
| `--manifest` | `-m` | Path to manifest file (required) |
234+
| `--output-dir` | `-o` | Parent directory for batch run (default: outputs) |
235+
| `--config` | | Path to config YAML |
236+
| `--iterations` | `-n` | Refinement iterations per item |
237+
| `--optimize` | | Preprocess inputs for each item |
238+
| `--auto` | | Loop until critic satisfied per item |
239+
| `--format` | `-f` | Output image format (png, jpeg, webp) |
240+
| `--auto-download-data` | | Download expanded reference set if needed |
241+
208242
### `paperbanana evaluate` -- Quality Assessment
209243

210244
Comparative evaluation of a generated diagram against a human reference using VLM-as-a-Judge:

examples/batch_manifest.yaml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# Batch manifest example: generate multiple methodology diagrams.
2+
# Paths are relative to this file's directory.
3+
# Run: paperbanana batch --manifest examples/batch_manifest.yaml
4+
5+
items:
6+
- input: sample_inputs/transformer_method.txt
7+
caption: "Overview of the Transformer encoder-decoder architecture with multi-head attention"
8+
id: transformer
9+
- input: sample_inputs/mamba_method.txt
10+
caption: "Mamba block with selective state space and gating"
11+
id: mamba

paperbanana/cli.py

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,189 @@ async def _critic_run(*a, **kw):
467467
console.print(f" Run ID: [dim]{result.metadata.get('run_id', 'unknown')}[/dim]")
468468

469469

470+
def _batch_report_entry(item, *, run_id=None, output_path=None, **extra):
    """Build one ``batch_report.json`` record for a manifest item.

    ``extra`` carries the outcome-specific keys (``iterations`` on success,
    ``error`` on failure) and is appended after the common keys.
    """
    return {
        "id": item["id"],
        "input": item["input"],
        "caption": item["caption"],
        "run_id": run_id,
        "output_path": output_path,
        **extra,
    }


@app.command()
def batch(
    manifest: str = typer.Option(
        ..., "--manifest", "-m", help="Path to batch manifest (YAML or JSON)"
    ),
    output_dir: str = typer.Option(
        "outputs",
        "--output-dir",
        "-o",
        help="Parent directory for batch run (batch_<id> will be created here)",
    ),
    config: Optional[str] = typer.Option(None, "--config", help="Path to config YAML file"),
    vlm_provider: Optional[str] = typer.Option(None, "--vlm-provider", help="VLM provider"),
    vlm_model: Optional[str] = typer.Option(None, "--vlm-model", help="VLM model name"),
    image_provider: Optional[str] = typer.Option(
        None, "--image-provider", help="Image gen provider"
    ),
    image_model: Optional[str] = typer.Option(None, "--image-model", help="Image gen model name"),
    iterations: Optional[int] = typer.Option(
        None, "--iterations", "-n", help="Refinement iterations"
    ),
    auto: bool = typer.Option(
        False, "--auto", help="Loop until critic satisfied (with safety cap)"
    ),
    max_iterations: Optional[int] = typer.Option(
        None, "--max-iterations", help="Safety cap for --auto"
    ),
    optimize: bool = typer.Option(
        False, "--optimize", help="Preprocess inputs for better generation"
    ),
    format: str = typer.Option(
        "png", "--format", "-f", help="Output image format (png, jpeg, webp)"
    ),
    save_prompts: Optional[bool] = typer.Option(
        None, "--save-prompts/--no-save-prompts", help="Save prompts per run"
    ),
    auto_download_data: bool = typer.Option(
        False, "--auto-download-data", help="Auto-download reference set if needed"
    ),
    verbose: bool = typer.Option(False, "--verbose", "-v", help="Show detailed progress"),
):
    """Generate multiple methodology diagrams from a manifest file (YAML or JSON)."""
    # NOTE: the docstring above doubles as the CLI help text, so it is kept short.
    #
    # Flow: validate options -> load manifest -> build Settings (CLI flags override
    # config) -> run the pipeline once per manifest item -> write batch_report.json
    # into the batch directory. A failing item is recorded in the report and does
    # not stop the remaining items.
    if format not in ("png", "jpeg", "webp"):
        console.print(f"[red]Error: Format must be png, jpeg, or webp. Got: {format}[/red]")
        raise typer.Exit(1)

    configure_logging(verbose=verbose)
    manifest_path = Path(manifest)
    if not manifest_path.exists():
        console.print(f"[red]Error: Manifest not found: {manifest}[/red]")
        raise typer.Exit(1)

    from paperbanana.core.batch import generate_batch_id, load_batch_manifest
    from paperbanana.core.utils import ensure_dir, save_json

    try:
        items = load_batch_manifest(manifest_path)
    except (ValueError, FileNotFoundError, RuntimeError) as e:
        console.print(f"[red]Error loading manifest: {e}[/red]")
        raise typer.Exit(1)

    batch_id = generate_batch_id()
    batch_dir = Path(output_dir) / batch_id
    ensure_dir(batch_dir)

    # Only flags the user actually supplied are forwarded, so values from the
    # config file / environment are not clobbered by unset CLI defaults.
    overrides = {"output_dir": str(batch_dir), "output_format": format}
    if vlm_provider:
        overrides["vlm_provider"] = vlm_provider
    if vlm_model:
        overrides["vlm_model"] = vlm_model
    if image_provider:
        overrides["image_provider"] = image_provider
    if image_model:
        overrides["image_model"] = image_model
    if iterations is not None:
        overrides["refinement_iterations"] = iterations
    if auto:
        overrides["auto_refine"] = True
    if max_iterations is not None:
        overrides["max_iterations"] = max_iterations
    if optimize:
        overrides["optimize_inputs"] = True
    if save_prompts is not None:
        overrides["save_prompts"] = save_prompts

    if config:
        settings = Settings.from_yaml(config, **overrides)
    else:
        from dotenv import load_dotenv

        load_dotenv()
        settings = Settings(**overrides)

    if auto_download_data:
        from paperbanana.data.manager import DatasetManager

        dm = DatasetManager(cache_dir=settings.cache_dir)
        if not dm.is_downloaded():
            console.print(" [dim]Downloading expanded reference set...[/dim]")
            try:
                dm.download()
            except Exception as e:
                # Best effort: a failed download must not block generation.
                console.print(f" [yellow]Download failed: {e}, using built-in set[/yellow]")

    console.print(
        Panel.fit(
            f"[bold]PaperBanana[/bold] — Batch Generation\n\n"
            f"Manifest: {manifest_path.name}\n"
            f"Items: {len(items)}\n"
            f"Output: {batch_dir}",
            border_style="blue",
        )
    )
    console.print()

    from paperbanana.core.pipeline import PaperBananaPipeline

    report = {"batch_id": batch_id, "manifest": str(manifest_path), "items": []}
    total_start = time.perf_counter()

    for idx, item in enumerate(items):
        item_id = item["id"]
        input_path = Path(item["input"])
        if not input_path.exists():
            console.print(f"[red]Skipping item '{item_id}': input not found: {input_path}[/red]")
            report["items"].append(_batch_report_entry(item, error="input file not found"))
            continue

        console.print(f"[bold]Item {idx + 1}/{len(items)}[/bold] — {item_id}")
        try:
            # Reading the input lives inside the try so that an unreadable file
            # (directory, permission error, non-UTF-8 bytes) fails only this
            # item instead of aborting the batch before the report is written.
            source_context = input_path.read_text(encoding="utf-8")
            gen_input = GenerationInput(
                source_context=source_context,
                communicative_intent=item["caption"],
                diagram_type=DiagramType.METHODOLOGY,
            )
            pipeline = PaperBananaPipeline(settings=settings)
            result = asyncio.run(pipeline.generate(gen_input))
            report["items"].append(
                _batch_report_entry(
                    item,
                    run_id=result.metadata.get("run_id"),
                    output_path=result.image_path,
                    iterations=len(result.iterations),
                )
            )
            console.print(f" [green]✓[/green] [dim]{result.image_path}[/dim]\n")
        except Exception as e:
            # Top-level boundary for one item: record the failure and move on.
            console.print(f" [red]✗[/red] {e}\n")
            report["items"].append(_batch_report_entry(item, error=str(e)))

    total_elapsed = time.perf_counter() - total_start
    report["total_seconds"] = round(total_elapsed, 1)
    report_path = batch_dir / "batch_report.json"
    save_json(report, report_path)

    # An item counts as succeeded only if it produced an output path.
    succeeded = sum(1 for x in report["items"] if x.get("output_path"))
    console.print(
        f"[green]Batch complete.[/green] [dim]{total_elapsed:.1f}s · "
        f"{succeeded}/{len(items)} succeeded[/dim]"
    )
    console.print(f" Report: [bold]{report_path}[/bold]")
651+
652+
470653
@app.command()
471654
def plot(
472655
data: str = typer.Option(..., "--data", "-d", help="Path to data file (CSV or JSON)"),

paperbanana/core/batch.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
"""Batch generation: manifest loading and batch run id."""
2+
3+
from __future__ import annotations
4+
5+
import datetime
6+
import uuid
7+
from pathlib import Path
8+
from typing import Any
9+
10+
import structlog
11+
12+
logger = structlog.get_logger()
13+
14+
15+
def generate_batch_id() -> str:
    """Return a fresh batch run ID: ``batch_<YYYYmmdd_HHMMSS>_<6 hex chars>``.

    The timestamp is taken from the local clock (naive ``datetime.now()``) and
    the suffix is the first six hex digits of a random UUID4.
    """
    stamp = f"{datetime.datetime.now():%Y%m%d_%H%M%S}"
    suffix = uuid.uuid4().hex[:6]
    return "_".join(("batch", stamp, suffix))
20+
21+
22+
def load_batch_manifest(manifest_path: Path) -> list[dict[str, Any]]:
    """Load a batch manifest (YAML or JSON) and return a list of normalized items.

    The manifest is either a top-level list of items or an object with an
    ``items`` list. Each returned item is a dict with:
      - input: path to the methodology text file as ``str``; relative paths are
        resolved against the manifest file's directory, absolute paths are kept
      - caption: figure caption / communicative intent, as ``str``
      - id: string identifier for the item; defaults to ``item_<n>`` (1-based)
        when the entry has no ``id`` or it is null/empty

    Raises:
        FileNotFoundError: the manifest file does not exist.
        RuntimeError: a YAML manifest was given but PyYAML is not installed.
        ValueError: unsupported extension, unparsable content, or a manifest
            whose structure or item fields are invalid.
    """
    manifest_path = Path(manifest_path).resolve()
    if not manifest_path.exists():
        raise FileNotFoundError(f"Manifest not found: {manifest_path}")
    parent = manifest_path.parent
    raw = manifest_path.read_text(encoding="utf-8")
    suffix = manifest_path.suffix.lower()
    if suffix in (".yaml", ".yml"):
        try:
            import yaml
        except ImportError as err:
            raise RuntimeError(
                "PyYAML is required for YAML manifests. Install with: pip install pyyaml"
            ) from err
        try:
            data = yaml.safe_load(raw)
        except yaml.YAMLError as err:
            # Surface syntax errors as ValueError, the same family json.loads
            # raises, so callers need a single except clause for bad content.
            raise ValueError(f"Invalid YAML in manifest: {err}") from err
    elif suffix == ".json":
        import json

        data = json.loads(raw)
    else:
        raise ValueError(f"Manifest must be .yaml, .yml, or .json. Got: {manifest_path.suffix}")

    if data is None:
        raise ValueError("Manifest is empty")
    items = data.get("items") if isinstance(data, dict) else data
    # Also rejects `items: null` / scalar values, which would otherwise fail
    # later with a TypeError or a misleading per-item message.
    if not isinstance(items, list):
        raise ValueError("Manifest must be a list of items or an object with an 'items' list")

    result = []
    for i, entry in enumerate(items):
        if not isinstance(entry, dict):
            raise ValueError(f"Manifest item {i} must be an object, got {type(entry).__name__}")
        inp = entry.get("input")
        caption = entry.get("caption")
        if not inp or not caption:
            raise ValueError(f"Manifest item {i}: 'input' and 'caption' are required")
        if not isinstance(inp, str):
            raise ValueError(
                f"Manifest item {i}: 'input' must be a string path, got {type(inp).__name__}"
            )
        input_path = Path(inp)
        if not input_path.is_absolute():
            input_path = (parent / input_path).resolve()
        raw_id = entry.get("id")
        # An explicit null/empty id falls back to the index-based default, and
        # numeric ids are normalized to str so the id is always a string.
        item_id = f"item_{i + 1}" if raw_id is None or raw_id == "" else str(raw_id)
        result.append(
            {
                "input": str(input_path),
                "caption": str(caption),
                "id": item_id,
            }
        )
    return result

0 commit comments

Comments
 (0)