Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,20 @@ GOOGLE_BASE_URL=
GOOGLE_VLM_MODEL=gemini-2.0-flash
GOOGLE_IMAGE_MODEL=gemini-3-pro-image-preview

# ── Local open-weight models ──────────────────────────────────────
# Ollama (local): https://ollama.com
# To use Ollama, set VLM_PROVIDER=ollama and pull a vision model:
#   ollama pull qwen2.5vl
# OLLAMA_BASE_URL=http://localhost:11434/v1
# OLLAMA_MODEL=qwen2.5vl
# OLLAMA_JSON_MODE=false

# vLLM / llama.cpp (OpenAI-compatible local server):
# To use vLLM, set VLM_PROVIDER=openai_local and start the server:
# vllm serve Qwen/Qwen2.5-VL-7B-Instruct
# OPENAI_BASE_URL=http://localhost:8000/v1
# OPENAI_LOCAL_JSON_MODE=false

# ── SSL ────────────────────────────────────────────────────────────
# Set to true to skip SSL certificate verification (corporate proxies)
SKIP_SSL_VERIFICATION=false
43 changes: 25 additions & 18 deletions paperbanana/agents/critic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@

from __future__ import annotations

import json
import re
from typing import Optional

import structlog

from paperbanana.agents.base import BaseAgent
from paperbanana.core.types import CritiqueResult, DiagramType
from paperbanana.core.utils import load_image
from paperbanana.core.utils import extract_json, load_image
from paperbanana.providers.base import VLMProvider

logger = structlog.get_logger()
Expand Down Expand Up @@ -83,14 +82,15 @@ async def run(
except Exception:
logger.warning("Prompt recording failed", agent=self.agent_name, label=prompt_label)

logger.info("Running critic agent", image_path=image_path)
use_json = getattr(self.vlm, "supports_json_mode", True)
logger.info("Running critic agent", image_path=image_path, json_mode=use_json)

response = await self.vlm.generate(
prompt=prompt,
images=[image],
temperature=0.3,
max_tokens=4096,
response_format="json",
response_format="json" if use_json else None,
)

critique = self._parse_response(response)
Expand All @@ -110,17 +110,24 @@ def _prompt_label_from_image_path(image_path: str) -> str | None:
return f"critic_iter_{m.group(1)}"

def _parse_response(self, response: str) -> CritiqueResult:
"""Parse the VLM response into a CritiqueResult."""
try:
data = json.loads(response)
return CritiqueResult(
critic_suggestions=data.get("critic_suggestions", []),
revised_description=data.get("revised_description"),
)
except (json.JSONDecodeError, KeyError) as e:
logger.warning("Failed to parse critic response", error=str(e))
# Conservative fallback: empty suggestions means no revision needed
return CritiqueResult(
critic_suggestions=[],
revised_description=None,
)
"""Parse the VLM response into a CritiqueResult.

Uses extract_json for robust parsing — handles markdown fences
and conversational text that open-weight models often produce.
"""
data = extract_json(response)
if isinstance(data, dict):
try:
return CritiqueResult(
critic_suggestions=data.get("critic_suggestions", []),
revised_description=data.get("revised_description"),
)
except (KeyError, TypeError) as e:
logger.warning("Failed to build CritiqueResult from parsed JSON", error=str(e))

logger.warning("Failed to parse critic response as JSON")
# Conservative fallback: empty suggestions means no revision needed
return CritiqueResult(
critic_suggestions=[],
revised_description=None,
)
25 changes: 11 additions & 14 deletions paperbanana/agents/retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@

from __future__ import annotations

import json

import structlog

from paperbanana.agents.base import BaseAgent
from paperbanana.core.types import DiagramType, ReferenceExample
from paperbanana.core.utils import extract_json
from paperbanana.providers.base import VLMProvider

logger = structlog.get_logger()
Expand Down Expand Up @@ -77,16 +76,18 @@ async def run(
num_examples=num_examples,
)

# Call the VLM
# Call the VLM — only request JSON mode if the provider supports it.
use_json = getattr(self.vlm, "supports_json_mode", True)
logger.info(
"Running retriever agent",
num_candidates=len(candidates),
num_requested=num_examples,
json_mode=use_json,
)
response = await self.vlm.generate(
prompt=prompt,
temperature=0.3, # Low temperature for consistent selection
response_format="json",
response_format="json" if use_json else None,
)

# Parse response
Expand Down Expand Up @@ -117,19 +118,15 @@ def _parse_response(
Handles both 'selected_ids' (our format) and 'top_10_papers'/'top_10_plots'
(paper's format) JSON keys for robustness.
"""
try:
data = json.loads(response)
selected_ids = (
data.get("selected_ids")
or data.get("top_10_papers")
or data.get("top_10_plots")
or []
)
except json.JSONDecodeError:
data = extract_json(response)
if not isinstance(data, dict):
logger.warning("Failed to parse retriever response as JSON, using fallback")
# Fallback: return first N candidates
return candidates

selected_ids = (
data.get("selected_ids") or data.get("top_10_papers") or data.get("top_10_plots") or []
)

# Map IDs back to ReferenceExample objects
id_to_example = {c.id: c for c in candidates}
selected = []
Expand Down
8 changes: 8 additions & 0 deletions paperbanana/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,14 @@ class Settings(BaseSettings):
openai_vlm_model: Optional[str] = Field(default=None, alias="OPENAI_VLM_MODEL")
openai_image_model: Optional[str] = Field(default=None, alias="OPENAI_IMAGE_MODEL")

# Ollama settings (local open-weight models)
ollama_base_url: str = Field(default="http://localhost:11434/v1", alias="OLLAMA_BASE_URL")
ollama_model: Optional[str] = Field(default=None, alias="OLLAMA_MODEL")
ollama_json_mode: bool = Field(default=False, alias="OLLAMA_JSON_MODE")

# OpenAI-local settings (vLLM, llama.cpp, etc. behind OpenAI-compatible API)
openai_local_json_mode: bool = Field(default=False, alias="OPENAI_LOCAL_JSON_MODE")

# AWS Bedrock settings
aws_region: str = Field(default="us-east-1", alias="AWS_REGION")
aws_profile: Optional[str] = Field(default=None, alias="AWS_PROFILE")
Expand Down
73 changes: 73 additions & 0 deletions paperbanana/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import datetime
import hashlib
import json
import re
import uuid
from io import BytesIO
from pathlib import Path
Expand Down Expand Up @@ -171,6 +172,78 @@ def detect_image_mime_type(path: str | Path) -> str:
return mime or "application/octet-stream"


def extract_json(text: str) -> dict | list | None:
"""Best-effort JSON extraction from free-form VLM output.

Many open-weight models don't support structured JSON mode, so they
wrap JSON in markdown fences or add conversational text around it.
This tries, in order:
1. Direct json.loads on the full text
2. Extract from ```json ... ``` fenced blocks
3. Extract from bare ``` ... ``` fenced blocks
4. Find the first { ... } or [ ... ] substring via bracket matching

Returns the parsed object, or None if nothing worked.
"""
text = text.strip()

# 1) Maybe the entire response is valid JSON already.
try:
return json.loads(text)
except (json.JSONDecodeError, ValueError):
pass

# 2) Try ```json ... ``` fenced block.
m = re.search(r"```json\s*\n(.*?)```", text, re.DOTALL)
if m:
try:
return json.loads(m.group(1).strip())
except (json.JSONDecodeError, ValueError):
pass

# 3) Try bare ``` ... ``` fenced block.
m = re.search(r"```\s*\n(.*?)```", text, re.DOTALL)
if m:
try:
return json.loads(m.group(1).strip())
except (json.JSONDecodeError, ValueError):
pass

# 4) Find the outermost { ... } or [ ... ] substring.
for open_ch, close_ch in [("{", "}"), ("[", "]")]:
start = text.find(open_ch)
if start == -1:
continue
depth = 0
in_string = False
escape_next = False
for i in range(start, len(text)):
ch = text[i]
if escape_next:
escape_next = False
continue
if ch == "\\":
escape_next = True
continue
if ch == '"':
in_string = not in_string
continue
if in_string:
continue
if ch == open_ch:
depth += 1
elif ch == close_ch:
depth -= 1
if depth == 0:
try:
return json.loads(text[start : i + 1])
except (json.JSONDecodeError, ValueError):
break
# depth never reached zero or parse failed — fall through

return None


def find_prompt_dir() -> str:
"""Locate the prompts directory, handling CWD != project root.

Expand Down
40 changes: 21 additions & 19 deletions paperbanana/evaluation/judge.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from __future__ import annotations

import json
from pathlib import Path
from typing import Optional

Expand All @@ -14,7 +13,7 @@
DimensionResult,
EvaluationScore,
)
from paperbanana.core.utils import load_image
from paperbanana.core.utils import extract_json, load_image
from paperbanana.providers.base import VLMProvider

logger = structlog.get_logger()
Expand Down Expand Up @@ -66,8 +65,9 @@ async def evaluate(

results: dict[str, DimensionResult] = {}

use_json = getattr(self.vlm, "supports_json_mode", True)
for dim in DIMENSIONS:
logger.info("Evaluating dimension", dimension=dim)
logger.info("Evaluating dimension", dimension=dim, json_mode=use_json)

prompt = self._load_eval_prompt(dim, source_context, caption)

Expand All @@ -76,7 +76,7 @@ async def evaluate(
images=images,
temperature=0.1,
max_tokens=1024,
response_format="json",
response_format="json" if use_json else None,
)

results[dim] = self._parse_result(response, dim)
Expand Down Expand Up @@ -104,13 +104,16 @@ def _load_eval_prompt(self, dimension: str, source_context: str, caption: str) -
return template.format(source_context=source_context, caption=caption)

def _parse_result(self, response: str, dimension: str) -> DimensionResult:
"""Parse a comparative result from VLM response."""
try:
data = json.loads(response)
"""Parse a comparative result from VLM response.

Uses extract_json for robust parsing so open-weight models that
wrap JSON in markdown or conversational text still work.
"""
data = extract_json(response)
if isinstance(data, dict):
winner = data.get("winner", "Both are good")
reasoning = data.get("comparison_reasoning", "")

# Validate winner value
if winner not in VALID_WINNERS:
logger.warning(
"Invalid winner value, defaulting to tie",
Expand All @@ -121,17 +124,16 @@ def _parse_result(self, response: str, dimension: str) -> DimensionResult:

score = WINNER_SCORE_MAP.get(winner, 50.0)
return DimensionResult(winner=winner, score=score, reasoning=reasoning)
except (json.JSONDecodeError, ValueError, TypeError) as e:
logger.warning(
"Failed to parse evaluation response",
dimension=dimension,
error=str(e),
)
return DimensionResult(
winner="Both are good",
score=50.0,
reasoning="Could not parse evaluation response.",
)

logger.warning(
"Failed to parse evaluation response",
dimension=dimension,
)
return DimensionResult(
winner="Both are good",
score=50.0,
reasoning="Could not parse evaluation response.",
)

def _hierarchical_aggregate(self, results: dict[str, DimensionResult]) -> str:
"""Apply hierarchical aggregation per paper Section 4.2.
Expand Down
10 changes: 10 additions & 0 deletions paperbanana/providers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,16 @@ async def generate(
"""
...

@property
def supports_json_mode(self) -> bool:
    """Whether this provider reliably handles response_format='json'.

    Hosted APIs (OpenAI, Gemini, Anthropic) support this natively.
    Local/open-weight backends often don't — override to return False
    so agents fall back to prompt-based JSON and robust parsing.

    Returns:
        True by default; subclasses backed by local servers should
        override this to return False when structured output is
        unsupported or unreliable.
    """
    return True

def is_available(self) -> bool:
"""Check if this provider is configured and available."""
return True
Expand Down
24 changes: 23 additions & 1 deletion paperbanana/providers/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,32 @@ def create_vlm(settings: Settings) -> VLMProvider:
api_key=settings.anthropic_api_key,
model=settings.vlm_model,
)
elif provider == "ollama":
from paperbanana.providers.vlm.ollama import OllamaVLM

return OllamaVLM(
model=settings.ollama_model or settings.vlm_model,
base_url=settings.ollama_base_url,
json_mode=settings.ollama_json_mode,
)
elif provider == "openai_local":
# OpenAI-compatible local server (vLLM, llama.cpp, etc.)
# Uses the OpenAI SDK but skips API key validation and
# disables JSON mode by default.
from paperbanana.providers.vlm.openai import OpenAIVLM

return OpenAIVLM(
api_key=settings.openai_api_key or "not-needed",
model=settings.openai_vlm_model or settings.vlm_model,
base_url=settings.openai_base_url,
json_mode=settings.openai_local_json_mode,
provider_name="openai_local",
)
else:
raise ValueError(
"Unknown VLM provider: "
f"{provider}. Available: gemini, openrouter, openai, bedrock, anthropic"
f"{provider}. Available: gemini, openrouter, openai, openai_local, "
f"bedrock, anthropic, ollama"
)

@staticmethod
Expand Down
Loading
Loading