Skip to content

Commit ff0c635

Browse files
authored
Merge branch 'main' into ci/lint-scripts
2 parents 841dc9d + 28a3a9d commit ff0c635

File tree

14 files changed

+1672
-15
lines changed

14 files changed

+1672
-15
lines changed

paperbanana/agents/planner.py

Lines changed: 101 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,17 @@
22

33
from __future__ import annotations
44

5+
import asyncio
6+
import ipaddress
57
import re
8+
import socket
9+
from io import BytesIO
610
from pathlib import Path
11+
from urllib.parse import urlparse
712

13+
import httpx
814
import structlog
15+
from PIL import Image
916

1017
from paperbanana.agents.base import BaseAgent
1118
from paperbanana.core.types import DiagramType, ReferenceExample
@@ -55,7 +62,7 @@ async def run(
5562
examples_text = self._format_examples(examples)
5663

5764
# Load reference images for visual in-context learning
58-
example_images = self._load_example_images(examples)
65+
example_images = await asyncio.to_thread(self._load_example_images, examples)
5966

6067
prompt_type = "diagram" if diagram_type == DiagramType.METHODOLOGY else "plot"
6168
template = self.load_prompt(prompt_type)
@@ -113,32 +120,116 @@ def _format_examples(self, examples: list[ReferenceExample]) -> str:
113120
if ex.aspect_ratio:
114121
ratio_info = f"\n**Aspect Ratio**: {ex.aspect_ratio:.2f}"
115122

123+
structure_info = ""
124+
if ex.structure_hints:
125+
hints_text = str(ex.structure_hints)
126+
structure_info = f"\n**Structure Hints**: {hints_text[:240]}"
127+
116128
lines.append(
117129
f"### Example {i}\n"
118130
f"**Caption**: {ex.caption}\n"
119131
f"**Source Context**: {ex.source_context[:500]}"
120132
f"{ratio_info}"
133+
f"{structure_info}"
121134
f"{image_ref}\n"
122135
)
123136
return "\n".join(lines)
124137

125138
def _has_valid_image(self, example: ReferenceExample) -> bool:
126-
"""Check if a reference example has a valid image file."""
127-
if not example.image_path:
139+
"""Check if a reference example has a loadable image (local path or http(s) URL)."""
140+
if not example.image_path or not example.image_path.strip():
128141
return False
129-
return Path(example.image_path).exists()
142+
path = example.image_path.strip()
143+
if self._is_remote_url(path):
144+
return self._is_safe_remote_image_url(path)
145+
return Path(path).exists()
146+
147+
@staticmethod
148+
def _is_remote_url(path: str) -> bool:
149+
return path.startswith(("http://", "https://"))
150+
151+
@classmethod
def _is_safe_remote_image_url(cls, image_url: str) -> bool:
    """Statically vet a remote image URL without touching the network.

    The URL is accepted only when it uses the ``https`` scheme, has a
    hostname, carries no embedded credentials, does not name a host listed
    in ``_LOCAL_HOSTNAMES`` or ending in ``.local``, and, if the host is an
    IP literal, that address is globally routable. Hostnames that are not
    IP literals are accepted here without being resolved.

    Args:
        image_url: Candidate URL string.

    Returns:
        True when the URL passes every static check, False otherwise
        (including when the URL cannot be parsed at all).
    """
    try:
        parsed = urlparse(image_url)
        hostname = parsed.hostname
    except ValueError:
        # Malformed URLs (e.g. an unbalanced IPv6 bracket) are simply unsafe.
        return False

    if parsed.scheme != "https" or not hostname:
        return False
    if parsed.username or parsed.password:
        return False

    # A fully-qualified name may carry a trailing dot ("localhost.",
    # "127.0.0.1."); strip it so the checks below cannot be sidestepped.
    host = hostname.lower().rstrip(".")
    if not host:
        return False
    if host in cls._LOCAL_HOSTNAMES or host.endswith(".local"):
        return False

    try:
        ip = ipaddress.ip_address(host)
    except ValueError:
        # Not an IP literal: an ordinary hostname, accepted at this stage.
        return True
    return ip.is_global
170+
171+
@staticmethod
172+
def _hostname_resolves_to_global_addresses(hostname: str) -> bool:
173+
try:
174+
infos = socket.getaddrinfo(hostname, 443, type=socket.SOCK_STREAM)
175+
except socket.gaierror:
176+
return False
177+
if not infos:
178+
return False
179+
180+
for info in infos:
181+
address = info[4][0]
182+
try:
183+
ip = ipaddress.ip_address(address)
184+
except ValueError:
185+
return False
186+
if not ip.is_global:
187+
return False
188+
return True
189+
190+
def _fetch_remote_image(self, image_url: str) -> Image.Image:
    """Download a remote image and return it as an RGB PIL image.

    The hostname must resolve exclusively to globally routable addresses,
    redirects are refused, the response must declare an ``image/*`` content
    type, and the body may not exceed ``_MAX_REMOTE_IMAGE_BYTES``. The body
    is streamed so an oversized response is abandoned as soon as the limit
    is crossed instead of being buffered in full first.

    Args:
        image_url: URL of the image to download.

    Returns:
        The decoded image converted to RGB mode.

    Raises:
        ValueError: If the URL has no hostname, the host resolves to a
            non-public address, the server redirects, the content type is
            not an image, or the body exceeds the size limit.
        httpx.HTTPStatusError: If the server answers with an error status.
    """
    hostname = urlparse(image_url).hostname
    if not hostname:
        raise ValueError("remote image URL is missing hostname")
    if not self._hostname_resolves_to_global_addresses(hostname):
        raise ValueError("remote image hostname resolves to non-public address")
    # NOTE(review): the HTTP client resolves the hostname again on its own,
    # so the check above does not pin the address that is finally contacted.

    limit = self._MAX_REMOTE_IMAGE_BYTES
    too_large = f"remote image exceeds {limit} byte limit"

    with httpx.Client(
        timeout=self._REMOTE_IMAGE_TIMEOUT_SECONDS,
        follow_redirects=False,
    ) as client:
        with client.stream("GET", image_url) as response:
            if 300 <= response.status_code < 400:
                raise ValueError("remote image redirects are not allowed")
            response.raise_for_status()

            content_type = (response.headers.get("content-type") or "").lower()
            if not content_type.startswith("image/"):
                raise ValueError("remote URL did not return an image content type")

            # Cheap early exit when the server announces an oversized body.
            declared = response.headers.get("content-length")
            if declared is not None and declared.isdigit() and int(declared) > limit:
                raise ValueError(too_large)

            # The header may be absent or wrong, so enforce the cap while reading.
            payload = bytearray()
            for chunk in response.iter_bytes():
                payload.extend(chunk)
                if len(payload) > limit:
                    raise ValueError(too_large)

    return Image.open(BytesIO(bytes(payload))).convert("RGB")
130216

131217
def _load_example_images(self, examples: list[ReferenceExample]) -> list:
132-
"""Load reference images from disk for in-context learning.
218+
"""Load reference images from disk or URL for in-context learning.
133219
134220
Returns a list of PIL Image objects for examples that have valid images.
221+
Supports local paths and http(s) URLs (e.g. from external exemplar adapters).
135222
"""
136223
images = []
137224
for ex in examples:
138225
if not self._has_valid_image(ex):
139226
continue
140227
try:
141-
img = load_image(ex.image_path)
228+
path = ex.image_path.strip()
229+
if self._is_remote_url(path):
230+
img = self._fetch_remote_image(path)
231+
else:
232+
img = load_image(path)
142233
images.append(img)
143234
except Exception as e:
144235
logger.warning(
@@ -168,3 +259,7 @@ def _parse_ratio(cls, text: str) -> tuple[str, str | None]:
168259
return clean, ratio
169260
logger.warning("Planner returned invalid ratio", ratio=ratio)
170261
return text.strip(), None
262+
263+
_REMOTE_IMAGE_TIMEOUT_SECONDS = 10.0
264+
_MAX_REMOTE_IMAGE_BYTES = 5 * 1024 * 1024
265+
_LOCAL_HOSTNAMES = {"localhost", "localhost.localdomain"}

paperbanana/cli.py

Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,41 @@ def generate(
9292
"--auto-download-data",
9393
help="Auto-download expanded reference set (~257MB) on first run if not cached",
9494
),
95+
exemplar_retrieval: bool = typer.Option(
96+
False,
97+
"--exemplar-retrieval",
98+
help="Enable external exemplar retrieval before planning",
99+
),
100+
exemplar_endpoint: Optional[str] = typer.Option(
101+
None,
102+
"--exemplar-endpoint",
103+
help="External exemplar retrieval endpoint URL",
104+
),
105+
exemplar_mode: Optional[str] = typer.Option(
106+
None,
107+
"--exemplar-mode",
108+
help="Exemplar retrieval mode: external_then_rerank or external_only",
109+
),
110+
exemplar_top_k: Optional[int] = typer.Option(
111+
None,
112+
"--exemplar-top-k",
113+
help="Top-k exemplars requested from external retriever",
114+
),
115+
exemplar_timeout: Optional[float] = typer.Option(
116+
None,
117+
"--exemplar-timeout",
118+
help="External exemplar retrieval timeout (seconds)",
119+
),
120+
exemplar_retries: Optional[int] = typer.Option(
121+
None,
122+
"--exemplar-retries",
123+
help="Retry attempts for external exemplar retrieval on transient errors",
124+
),
125+
seed: Optional[int] = typer.Option(
126+
None,
127+
"--seed",
128+
help="Random seed for reproducible image generation",
129+
),
95130
verbose: bool = typer.Option(
96131
False, "--verbose", "-v", help="Show detailed agent progress and timing"
97132
),
@@ -104,6 +139,11 @@ def generate(
104139
if feedback and not continue_run and not continue_last:
105140
console.print("[red]Error: --feedback requires --continue or --continue-run[/red]")
106141
raise typer.Exit(1)
142+
if exemplar_mode and exemplar_mode not in ("external_then_rerank", "external_only"):
143+
console.print(
144+
"[red]Error: --exemplar-mode must be external_then_rerank or external_only[/red]"
145+
)
146+
raise typer.Exit(1)
107147

108148
configure_logging(verbose=verbose)
109149

@@ -128,6 +168,20 @@ def generate(
128168
if output:
129169
overrides["output_dir"] = str(Path(output).parent)
130170
overrides["output_format"] = format
171+
if exemplar_retrieval:
172+
overrides["exemplar_retrieval_enabled"] = True
173+
if exemplar_endpoint:
174+
overrides["exemplar_retrieval_endpoint"] = exemplar_endpoint
175+
if exemplar_mode:
176+
overrides["exemplar_retrieval_mode"] = exemplar_mode
177+
if exemplar_top_k is not None:
178+
overrides["exemplar_retrieval_top_k"] = exemplar_top_k
179+
if exemplar_timeout is not None:
180+
overrides["exemplar_retrieval_timeout_seconds"] = exemplar_timeout
181+
if exemplar_retries is not None:
182+
overrides["exemplar_retrieval_max_retries"] = exemplar_retries
183+
if seed is not None:
184+
overrides["seed"] = seed
131185

132186
if config:
133187
settings = Settings.from_yaml(config, **overrides)
@@ -615,6 +669,155 @@ async def _run():
615669
console.print(f"\n[bold]{dim}[/bold]: {result.reasoning}")
616670

617671

672+
@app.command("ablate-retrieval")
def ablate_retrieval(
    input: str = typer.Option(..., "--input", "-i", help="Path to methodology text file"),
    caption: str = typer.Option(
        ..., "--caption", "-c", help="Figure caption / communicative intent"
    ),
    exemplar_endpoint: str = typer.Option(
        ..., "--exemplar-endpoint", help="External exemplar retrieval endpoint URL"
    ),
    top_k: str = typer.Option(
        "1,3,5", "--top-k", help="Comma-separated top-k values (e.g., 1,3,5)"
    ),
    seed: Optional[int] = typer.Option(
        None,
        "--seed",
        help="Random seed used for all variants (default: 42 if omitted)",
    ),
    exemplar_retries: Optional[int] = typer.Option(
        None,
        "--exemplar-retries",
        help="Retry attempts for external exemplar retrieval on transient errors",
    ),
    reference: Optional[str] = typer.Option(
        None,
        "--reference",
        "-r",
        help="Optional human reference image for judge-based preference proxy",
    ),
    output_report: Optional[str] = typer.Option(
        None,
        "--output-report",
        "-o",
        help="Output JSON report path (default: outputs/retrieval_ablation_<runid>.json)",
    ),
    config: Optional[str] = typer.Option(None, "--config", help="Path to config YAML file"),
    vlm_provider: Optional[str] = typer.Option(
        None, "--vlm-provider", help="VLM provider override for generation and judge"
    ),
    image_provider: Optional[str] = typer.Option(
        None, "--image-provider", help="Image generation provider override"
    ),
    verbose: bool = typer.Option(
        False, "--verbose", "-v", help="Show detailed agent progress and timing"
    ),
):
    """Run baseline vs retrieval ablation (k sweep) and save a JSON report."""
    configure_logging(verbose=verbose)

    # Validate the user-supplied paths before the deferred imports below.
    source_file = Path(input)
    if not source_file.exists():
        console.print(f"[red]Error: Input file not found: {input}[/red]")
        raise typer.Exit(1)

    reference_path: Optional[Path] = Path(reference) if reference else None
    if reference_path is not None and not reference_path.exists():
        console.print(f"[red]Error: Reference image not found: {reference}[/red]")
        raise typer.Exit(1)

    from dotenv import load_dotenv

    load_dotenv()

    from paperbanana.core.types import DiagramType, GenerationInput
    from paperbanana.core.utils import generate_run_id
    from paperbanana.evaluation.retrieval_ablation import (
        RetrievalAblationRunner,
        parse_top_k_values,
    )

    try:
        k_values = parse_top_k_values(top_k)
    except ValueError as e:
        console.print(f"[red]Error: {e}[/red]")
        raise typer.Exit(1)

    # Retrieval is always enabled for this command; the remaining overrides
    # are applied only when the matching option was supplied.
    setting_overrides = {
        "exemplar_retrieval_endpoint": exemplar_endpoint,
        "exemplar_retrieval_enabled": True,
    }
    for key, value in (("vlm_provider", vlm_provider), ("image_provider", image_provider)):
        if value:
            setting_overrides[key] = value
    for key, value in (
        ("seed", seed),
        ("exemplar_retrieval_max_retries", exemplar_retries),
    ):
        if value is not None:
            setting_overrides[key] = value

    settings = (
        Settings.from_yaml(config, **setting_overrides)
        if config
        else Settings(**setting_overrides)
    )

    generation_input = GenerationInput(
        source_context=source_file.read_text(encoding="utf-8"),
        communicative_intent=caption,
        diagram_type=DiagramType.METHODOLOGY,
    )

    runner = RetrievalAblationRunner(
        settings,
        reference_image_path=str(reference_path) if reference_path else None,
    )

    async def _sweep():
        return await runner.run(generation_input, top_k_values=k_values)

    console.print(
        Panel.fit(
            f"[bold]PaperBanana[/bold] - Retrieval Ablation\n\n"
            f"Top-k sweep: {k_values}\n"
            f"Endpoint: {exemplar_endpoint}\n"
            f"Seed: {settings.seed if settings.seed is not None else 42}\n"
            f"Reference: {reference_path if reference_path else 'none'}",
            border_style="magenta",
        )
    )

    ablation_report = asyncio.run(_sweep())

    # The default path is always computed, even when --output-report wins.
    default_report_path = Path(settings.output_dir) / f"retrieval_ablation_{generate_run_id()}.json"
    report_path = Path(output_report) if output_report else default_report_path
    saved_path = runner.save_report(ablation_report, report_path)

    summary = ablation_report.summary
    if summary.get("best_human_preference_variant") is None:
        human_pref_line = ""
    else:
        human_pref_line = (
            f"Best human preference: {summary.get('best_human_preference_variant')} "
            f"({summary.get('best_human_preference_score')})\n"
        )
    console.print(
        Panel.fit(
            "[bold]Ablation Summary[/bold]\n\n"
            f"Best alignment: {summary.get('best_alignment_variant')} "
            f"({summary.get('best_alignment_score')})\n"
            f"{human_pref_line}"
            f"Fastest: {summary.get('fastest_variant')} "
            f"({summary.get('fastest_total_seconds')}s)\n"
            f"Fewest iterations: {summary.get('fewest_iterations_variant')} "
            f"({summary.get('fewest_iterations')})\n\n"
            f"Report: [bold]{saved_path}[/bold]",
            border_style="cyan",
        )
    )
819+
820+
618821
# ── Data subcommands ──────────────────────────────────────────────
619822

620823

0 commit comments

Comments
 (0)