Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 45 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ An agentic framework for generating publication-quality academic diagrams and st
- Auto-refine mode and run continuation with user feedback
- CLI, Python API, and MCP server for IDE integration
- **Batch generation** from a manifest file (YAML/JSON) for multiple diagrams in one run
- **Batch plots** — `paperbanana plot-batch` runs many statistical plots from one manifest (CSV/JSON per item)
- **PDF inputs** for methodology context (optional `paperbanana[pdf]` / PyMuPDF), with per-page selection
- **PaperBanana Studio** — local Gradio web UI (`paperbanana studio`) for diagrams, plots, evaluation, batch, and run browser
- Claude Code skills for `/generate-diagram`, `/generate-plot`, and `/evaluate-diagram`
Expand Down Expand Up @@ -123,7 +124,7 @@ pip install 'paperbanana[studio]'
paperbanana studio
```

Open the URL shown in the terminal (default `http://127.0.0.1:7860/`). The Studio exposes the same workflows as the CLI: methodology diagrams, statistical plots, comparative evaluation, continuing a prior run, batch manifests, and a simple browser for `run_*` / `batch_*` output folders. Use `--host`, `--port`, `--config`, and `--output-dir` as needed.
Open the URL shown in the terminal (default `http://127.0.0.1:7860/`). The Studio exposes the same workflows as the CLI: methodology diagrams, statistical plots, comparative evaluation, continuing a prior run, batch manifests (methodology or **plot** batch via the Batch tab), and a simple browser for `run_*` / `batch_*` output folders. Use `--host`, `--port`, `--config`, and `--output-dir` as needed.

---

Expand Down Expand Up @@ -268,6 +269,8 @@ paperbanana batch-report --batch-dir outputs/batch_20250109_123456_abc --format
paperbanana batch-report --batch-id batch_20250109_123456_abc --format html --output report.html
```

Diagram batch reports include `batch_kind: methodology`; plot batches use `batch_kind: statistical_plot`. Human-readable reports (`paperbanana batch-report`) show the batch kind when present.

| Flag | Short | Description |
|------|-------|-------------|
| `--manifest` | `-m` | Path to manifest file (required) |
Expand All @@ -279,6 +282,47 @@ paperbanana batch-report --batch-id batch_20250109_123456_abc --format html --ou
| `--format` | `-f` | Output image format (png, jpeg, webp) |
| `--auto-download-data` | | Download expanded reference set if needed |

### `paperbanana plot-batch` -- Batch Statistical Plots

Generate multiple plots from a manifest (YAML or JSON). Each item specifies a **data** file (CSV or JSON) and an **intent** string, mirroring `paperbanana plot`. Outputs live under `outputs/batch_<id>/run_<id>/` with the same `batch_report.json` and `paperbanana batch-report` workflow as diagram batches.

```bash
paperbanana plot-batch --manifest examples/plot_batch_manifest.yaml --optimize
```

Manifest format (`items` list):

```yaml
items:
- data: path/to/results.csv
intent: "Bar chart comparing accuracy across models"
id: fig_acc
- data: other.json
intent: "Scatter plot with trend line"
aspect_ratio: "16:9" # optional per item; CLI --aspect-ratio is the default when omitted
```

Paths are resolved relative to the manifest file’s directory.

| Flag | Short | Description |
|------|-------|-------------|
| `--manifest` | `-m` | Path to manifest (required) |
| `--output-dir` | `-o` | Parent directory for `batch_*` (default: outputs) |
| `--config` | | Path to config YAML |
| `--vlm-provider` | | VLM provider (default: gemini) |
| `--vlm-model` | | VLM model override |
| `--image-provider` | | Image gen provider |
| `--image-model` | | Image gen model |
| `--iterations` | `-n` | Refinement iterations per item |
| `--auto` | | Loop until critic satisfied per item |
| `--max-iterations` | | Safety cap for `--auto` |
| `--optimize` | | Input optimization per item |
| `--format` | `-f` | png, jpeg, or webp |
| `--save-prompts` / `--no-save-prompts` | | Persist prompts (default: on, same as `plot`) |
| `--venue` | | Venue style (neurips, icml, acl, ieee, custom) |
| `--aspect-ratio` | `-ar` | Default aspect ratio when not set in the manifest |
| `--verbose` | `-v` | Verbose logging |

### `paperbanana evaluate` -- Quality Assessment

Comparative evaluation of a generated diagram against a human reference using VLM-as-a-Judge:
Expand Down
10 changes: 10 additions & 0 deletions examples/plot_batch_manifest.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Statistical plot batch — same idea as `paperbanana batch` but each item is data + intent.
# Run: paperbanana plot-batch --manifest examples/plot_batch_manifest.yaml
items:
- data: sample_data/benchmark_slice.csv
intent: "Bar chart of accuracy by model; clear axis labels and NeurIPS-style colors."
id: acc_bar
- data: sample_data/benchmark_slice.csv
intent: "Side-by-side bars comparing accuracy and F1 for each model."
id: acc_f1_grouped
aspect_ratio: "16:9"
5 changes: 5 additions & 0 deletions examples/sample_data/benchmark_slice.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
model,accuracy,f1
GPT-4o,0.91,0.89
Claude-3,0.88,0.86
Gemini,0.85,0.83
Llama-3,0.79,0.77
246 changes: 230 additions & 16 deletions paperbanana/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,7 +688,12 @@ def batch(

from paperbanana.core.pipeline import PaperBananaPipeline

report = {"batch_id": batch_id, "manifest": str(manifest_path), "items": []}
report = {
"batch_id": batch_id,
"manifest": str(manifest_path),
"batch_kind": "methodology",
"items": [],
}
total_start = time.perf_counter()

for idx, item in enumerate(items):
Expand Down Expand Up @@ -846,6 +851,224 @@ def batch_report(
raise typer.Exit(1)


@app.command("plot-batch")
def plot_batch(
manifest: str = typer.Option(
..., "--manifest", "-m", help="Path to plot batch manifest (YAML or JSON)"
),
output_dir: str = typer.Option(
"outputs",
"--output-dir",
"-o",
help="Parent directory for batch run (batch_<id> will be created here)",
),
config: Optional[str] = typer.Option(None, "--config", help="Path to config YAML file"),
vlm_provider: Optional[str] = typer.Option(
None, "--vlm-provider", help="VLM provider (default: gemini)"
),
vlm_model: Optional[str] = typer.Option(None, "--vlm-model", help="VLM model name"),
image_provider: Optional[str] = typer.Option(
None, "--image-provider", help="Image gen provider"
),
image_model: Optional[str] = typer.Option(None, "--image-model", help="Image gen model name"),
iterations: Optional[int] = typer.Option(
None, "--iterations", "-n", help="Refinement iterations per plot"
),
auto: bool = typer.Option(
False, "--auto", help="Loop until critic satisfied per item (with safety cap)"
),
max_iterations: Optional[int] = typer.Option(
None, "--max-iterations", help="Safety cap for --auto"
),
optimize: bool = typer.Option(
False, "--optimize", help="Preprocess inputs per item (enrich context, sharpen intent)"
),
format: str = typer.Option(
"png", "--format", "-f", help="Output image format (png, jpeg, webp)"
),
save_prompts: Optional[bool] = typer.Option(
None,
"--save-prompts/--no-save-prompts",
help="Save prompts per run",
),
venue: Optional[str] = typer.Option(
None,
"--venue",
help="Target venue style (neurips, icml, acl, ieee, custom)",
),
aspect_ratio: Optional[str] = typer.Option(
None,
"--aspect-ratio",
"-ar",
help="Default aspect ratio when not set per manifest item",
),
verbose: bool = typer.Option(False, "--verbose", "-v", help="Show detailed progress"),
):
"""Generate multiple statistical plots from a manifest (data + intent per item)."""
if format not in ("png", "jpeg", "webp"):
console.print(f"[red]Error: Format must be png, jpeg, or webp. Got: {format}[/red]")
raise typer.Exit(1)
if venue and venue.lower() not in ("neurips", "icml", "acl", "ieee", "custom"):
console.print(
f"[red]Error: --venue must be neurips, icml, acl, ieee, or custom. Got: {venue}[/red]"
)
raise typer.Exit(1)

configure_logging(verbose=verbose)
manifest_path = Path(manifest)
if not manifest_path.exists():
console.print(f"[red]Error: Manifest not found: {manifest}[/red]")
raise typer.Exit(1)

from paperbanana.core.batch import generate_batch_id, load_plot_batch_manifest
from paperbanana.core.plot_data import load_statistical_plot_payload
from paperbanana.core.utils import ensure_dir, save_json

try:
items = load_plot_batch_manifest(manifest_path)
except (ValueError, FileNotFoundError, RuntimeError) as e:
console.print(f"[red]Error loading manifest: {e}[/red]")
raise typer.Exit(1)

batch_id = generate_batch_id()
batch_dir = Path(output_dir) / batch_id
ensure_dir(batch_dir)

overrides: dict = {
"output_dir": str(batch_dir),
"output_format": format,
"optimize_inputs": optimize,
"auto_refine": auto,
}
if vlm_provider:
overrides["vlm_provider"] = vlm_provider
if vlm_model:
overrides["vlm_model"] = vlm_model
if image_provider:
overrides["image_provider"] = image_provider
if image_model:
overrides["image_model"] = image_model
if iterations is not None:
overrides["refinement_iterations"] = iterations
if max_iterations is not None:
overrides["max_iterations"] = max_iterations
overrides["save_prompts"] = True if save_prompts is None else save_prompts
if venue:
overrides["venue"] = venue
if not vlm_provider:
overrides.setdefault("vlm_provider", "gemini")

if config:
settings = Settings.from_yaml(config, **overrides)
else:
from dotenv import load_dotenv

load_dotenv()
settings = Settings(**overrides)

console.print(
Panel.fit(
f"[bold]PaperBanana[/bold] — Batch Plot Generation\n\n"
f"Manifest: {manifest_path.name}\n"
f"Items: {len(items)}\n"
f"Output: {batch_dir}",
border_style="green",
)
)
console.print()

from paperbanana.core.pipeline import PaperBananaPipeline

report: dict = {
"batch_id": batch_id,
"manifest": str(manifest_path),
"batch_kind": "statistical_plot",
"items": [],
}
total_start = time.perf_counter()

for idx, item in enumerate(items):
item_id = item["id"]
data_path = Path(item["data"])
if not data_path.exists():
console.print(f"[red]Skipping item '{item_id}': data file not found: {data_path}[/red]")
report["items"].append(
{
"id": item_id,
"data": item["data"],
"caption": item["intent"],
"run_id": None,
"output_path": None,
"error": "data file not found",
}
)
continue

try:
source_context, raw_data = load_statistical_plot_payload(data_path)
except (ValueError, OSError) as e:
console.print(f"[red]Skipping item '{item_id}': {e}[/red]")
report["items"].append(
{
"id": item_id,
"data": item["data"],
"caption": item["intent"],
"run_id": None,
"output_path": None,
"error": str(e),
}
)
continue

ar = item.get("aspect_ratio") or aspect_ratio
gen_input = GenerationInput(
source_context=source_context,
communicative_intent=item["intent"],
diagram_type=DiagramType.STATISTICAL_PLOT,
raw_data={"data": raw_data},
aspect_ratio=ar,
)
console.print(f"[bold]Item {idx + 1}/{len(items)}[/bold] — {item_id}")
pipeline = PaperBananaPipeline(settings=settings)
try:
result = asyncio.run(pipeline.generate(gen_input))
report["items"].append(
{
"id": item_id,
"data": item["data"],
"caption": item["intent"],
"run_id": result.metadata.get("run_id"),
"output_path": result.image_path,
"iterations": len(result.iterations),
}
)
console.print(f" [green]✓[/green] [dim]{result.image_path}[/dim]\n")
except Exception as e:
console.print(f" [red]✗[/red] {e}\n")
report["items"].append(
{
"id": item_id,
"data": item["data"],
"caption": item["intent"],
"run_id": None,
"output_path": None,
"error": str(e),
}
)

total_elapsed = time.perf_counter() - total_start
report["total_seconds"] = round(total_elapsed, 1)
report_path = batch_dir / "batch_report.json"
save_json(report, report_path)

succeeded = sum(1 for x in report["items"] if x.get("output_path"))
console.print(
f"[green]Plot batch complete.[/green] [dim]{total_elapsed:.1f}s · "
f"{succeeded}/{len(items)} succeeded[/dim]"
)
console.print(f" Report: [bold]{report_path}[/bold]")


@app.command()
def plot(
data: str = typer.Option(..., "--data", "-d", help="Path to data file (CSV or JSON)"),
Expand Down Expand Up @@ -901,22 +1124,13 @@ def plot(
console.print(f"[red]Error: Data file not found: {data}[/red]")
raise typer.Exit(1)

# Load data
import json as json_mod

if data_path.suffix == ".csv":
import pandas as pd
from paperbanana.core.plot_data import load_statistical_plot_payload

df = pd.read_csv(data_path)
raw_data = df.to_dict(orient="records")
source_context = (
f"CSV data with columns: {list(df.columns)}\n"
f"Rows: {len(df)}\nSample:\n{df.head().to_string()}"
)
else:
with open(data_path) as f:
raw_data = json_mod.load(f)
source_context = f"JSON data:\n{json_mod.dumps(raw_data, indent=2)[:2000]}"
try:
source_context, raw_data = load_statistical_plot_payload(data_path)
except (FileNotFoundError, ValueError) as e:
console.print(f"[red]Error: {e}[/red]")
raise typer.Exit(1)

from dotenv import load_dotenv

Expand Down
Loading
Loading