Skip to content

Commit 22deffc

Browse files
authored
Merge pull request #75 from claytonlin1110/feat/save-prompts-run-artifacts
feat: persist formatted agent prompts in run outputs
2 parents 261d680 + 2e481f2 commit 22deffc

File tree

12 files changed

+331
-22
lines changed

12 files changed

+331
-22
lines changed

paperbanana/agents/base.py

Lines changed: 27 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -20,9 +20,15 @@ class BaseAgent(ABC):
2020
a specific role in the generation process.
2121
"""
2222

23-
def __init__(self, vlm_provider: VLMProvider, prompt_dir: str = "prompts"):
23+
def __init__(
24+
self,
25+
vlm_provider: VLMProvider,
26+
prompt_dir: str = "prompts",
27+
prompt_recorder: Any | None = None,
28+
):
2429
self.vlm = vlm_provider
2530
self.prompt_dir = Path(prompt_dir)
31+
self._prompt_recorder = prompt_recorder
2632

2733
@property
2834
@abstractmethod
@@ -50,5 +56,23 @@ def load_prompt(self, diagram_type: str = "diagram") -> str:
5056
return path.read_text(encoding="utf-8")
5157

5258
def format_prompt(self, template: str, **kwargs: Any) -> str:
53-
"""Format a prompt template with the given values."""
54-
return template.format(**kwargs)
59+
"""Format a prompt template with the given values.
60+
61+
If a prompt recorder is configured, this method will write the formatted
62+
prompt to the active run directory.
63+
"""
64+
# Reserved internal argument (not forwarded into template.format()).
65+
prompt_label = kwargs.pop("prompt_label", None)
66+
67+
formatted = template.format(**kwargs)
68+
if self._prompt_recorder is not None:
69+
try:
70+
self._prompt_recorder.record(
71+
agent_name=self.agent_name,
72+
label=str(prompt_label) if prompt_label else None,
73+
prompt=formatted,
74+
)
75+
except Exception:
76+
# Recording is best-effort; do not break generation on I/O issues.
77+
logger.warning("Prompt recording failed", agent=self.agent_name, label=prompt_label)
78+
return formatted

paperbanana/agents/critic.py

Lines changed: 27 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
import json
6+
import re
67
from typing import Optional
78

89
import structlog
@@ -22,8 +23,10 @@ class CriticAgent(BaseAgent):
2223
faithfulness, conciseness, readability, and aesthetic issues.
2324
"""
2425

25-
def __init__(self, vlm_provider: VLMProvider, prompt_dir: str = "prompts"):
26-
super().__init__(vlm_provider, prompt_dir)
26+
def __init__(
27+
self, vlm_provider: VLMProvider, prompt_dir: str = "prompts", prompt_recorder=None
28+
):
29+
super().__init__(vlm_provider, prompt_dir, prompt_recorder=prompt_recorder)
2730

2831
@property
2932
def agent_name(self) -> str:
@@ -56,8 +59,9 @@ async def run(
5659

5760
prompt_type = "diagram" if diagram_type == DiagramType.METHODOLOGY else "plot"
5861
template = self.load_prompt(prompt_type)
59-
prompt = self.format_prompt(
60-
template,
62+
prompt_label = self._prompt_label_from_image_path(image_path) or "critic"
63+
# Build prompt manually so we record once after appending user_feedback.
64+
prompt = template.format(
6165
source_context=source_context,
6266
caption=caption,
6367
description=description,
@@ -68,6 +72,17 @@ async def run(
6872
f"\n\nAdditional user feedback to consider in your evaluation:\n{user_feedback}"
6973
)
7074

75+
# Record the exact prompt sent to the model (including user_feedback in continue-run flows)
76+
if self._prompt_recorder is not None:
77+
try:
78+
self._prompt_recorder.record(
79+
agent_name=self.agent_name,
80+
label=prompt_label,
81+
prompt=prompt,
82+
)
83+
except Exception:
84+
logger.warning("Prompt recording failed", agent=self.agent_name, label=prompt_label)
85+
7186
logger.info("Running critic agent", image_path=image_path)
7287

7388
response = await self.vlm.generate(
@@ -86,6 +101,14 @@ async def run(
86101
)
87102
return critique
88103

104+
@staticmethod
105+
def _prompt_label_from_image_path(image_path: str) -> str | None:
106+
"""Best-effort label (e.g. critic_iter_3) derived from output image filename."""
107+
m = re.search(r"(?:diagram|plot)_iter_(\d+)\.", image_path)
108+
if not m:
109+
return None
110+
return f"critic_iter_{m.group(1)}"
111+
89112
def _parse_response(self, response: str) -> CritiqueResult:
90113
"""Parse the VLM response into a CritiqueResult."""
91114
try:

paperbanana/agents/optimizer.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -26,8 +26,10 @@ class InputOptimizerAgent(BaseAgent):
2626
- Caption sharpener: converts vague intent into precise visual specification
2727
"""
2828

29-
def __init__(self, vlm_provider: VLMProvider, prompt_dir: str = "prompts"):
30-
super().__init__(vlm_provider, prompt_dir)
29+
def __init__(
30+
self, vlm_provider: VLMProvider, prompt_dir: str = "prompts", prompt_recorder=None
31+
):
32+
super().__init__(vlm_provider, prompt_dir, prompt_recorder=prompt_recorder)
3133

3234
@property
3335
def agent_name(self) -> str:
@@ -57,11 +59,13 @@ async def run(
5759

5860
context_prompt = self.format_prompt(
5961
context_template,
62+
prompt_label="context_enricher",
6063
source_context=source_context,
6164
caption=caption,
6265
)
6366
caption_prompt = self.format_prompt(
6467
caption_template,
68+
prompt_label="caption_sharpener",
6569
source_context=source_context,
6670
caption=caption,
6771
)

paperbanana/agents/planner.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -30,8 +30,10 @@ class PlannerAgent(BaseAgent):
3030
can render. Matches paper equation 4: P = VLM_plan(S, C, {(S_i, C_i, I_i)}).
3131
"""
3232

33-
def __init__(self, vlm_provider: VLMProvider, prompt_dir: str = "prompts"):
34-
super().__init__(vlm_provider, prompt_dir)
33+
def __init__(
34+
self, vlm_provider: VLMProvider, prompt_dir: str = "prompts", prompt_recorder=None
35+
):
36+
super().__init__(vlm_provider, prompt_dir, prompt_recorder=prompt_recorder)
3537

3638
@property
3739
def agent_name(self) -> str:
@@ -70,6 +72,7 @@ async def run(
7072
ratios_str = ", ".join(supported_ratios) if supported_ratios else "1:1, 16:9"
7173
prompt = self.format_prompt(
7274
template,
75+
prompt_label="planner",
7376
source_context=source_context,
7477
caption=caption,
7578
examples=examples_text,

paperbanana/agents/retriever.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -20,8 +20,10 @@ class RetrieverAgent(BaseAgent):
2020
reference examples are most useful for generating the target diagram.
2121
"""
2222

23-
def __init__(self, vlm_provider: VLMProvider, prompt_dir: str = "prompts"):
24-
super().__init__(vlm_provider, prompt_dir)
23+
def __init__(
24+
self, vlm_provider: VLMProvider, prompt_dir: str = "prompts", prompt_recorder=None
25+
):
26+
super().__init__(vlm_provider, prompt_dir, prompt_recorder=prompt_recorder)
2527

2628
@property
2729
def agent_name(self) -> str:
@@ -68,6 +70,7 @@ async def run(
6870
template = self.load_prompt(prompt_type)
6971
prompt = self.format_prompt(
7072
template,
73+
prompt_label="retriever",
7174
source_context=source_context,
7275
caption=caption,
7376
candidates=candidates_text,

paperbanana/agents/stylist.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -23,8 +23,9 @@ def __init__(
2323
vlm_provider: VLMProvider,
2424
guidelines: str = "",
2525
prompt_dir: str = "prompts",
26+
prompt_recorder=None,
2627
):
27-
super().__init__(vlm_provider, prompt_dir)
28+
super().__init__(vlm_provider, prompt_dir, prompt_recorder=prompt_recorder)
2829
self.guidelines = guidelines
2930

3031
@property
@@ -59,6 +60,7 @@ async def run(
5960
template = self.load_prompt(prompt_type)
6061
prompt = self.format_prompt(
6162
template,
63+
prompt_label="stylist",
6264
description=description,
6365
guidelines=style_guidelines,
6466
source_context=source_context,

paperbanana/agents/visualizer.py

Lines changed: 12 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -33,8 +33,9 @@ def __init__(
3333
vlm_provider: VLMProvider,
3434
prompt_dir: str = "prompts",
3535
output_dir: str = "outputs",
36+
prompt_recorder=None,
3637
):
37-
super().__init__(vlm_provider, prompt_dir)
38+
super().__init__(vlm_provider, prompt_dir, prompt_recorder=prompt_recorder)
3839
self.image_gen = image_gen
3940
self.output_dir = Path(output_dir)
4041

@@ -89,7 +90,11 @@ async def _generate_diagram(
8990
) -> str:
9091
"""Generate a methodology diagram using the image generation model."""
9192
template = self.load_prompt("diagram")
92-
prompt = self.format_prompt(template, description=description)
93+
prompt = self.format_prompt(
94+
template,
95+
prompt_label=f"visualizer_diagram_iter_{iteration}",
96+
description=description,
97+
)
9398

9499
logger.info("Generating diagram image", iteration=iteration)
95100

@@ -144,7 +149,11 @@ async def _generate_plot(
144149

145150
# Load and format the plot visualizer prompt template
146151
template = self.load_prompt("plot")
147-
code_prompt = self.format_prompt(template, description=full_description)
152+
code_prompt = self.format_prompt(
153+
template,
154+
prompt_label=f"visualizer_plot_iter_{iteration}",
155+
description=full_description,
156+
)
148157

149158
logger.info("Generating plot code", iteration=iteration)
150159

paperbanana/cli.py

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -82,6 +82,11 @@ def generate(
8282
help="Output image format (png, jpeg, or webp)",
8383
),
8484
config: Optional[str] = typer.Option(None, "--config", help="Path to config YAML file"),
85+
save_prompts: Optional[bool] = typer.Option(
86+
None,
87+
"--save-prompts/--no-save-prompts",
88+
help="Save formatted prompts into the run directory (for debugging)",
89+
),
8590
dry_run: bool = typer.Option(
8691
False,
8792
"--dry-run",
@@ -165,6 +170,8 @@ def generate(
165170
overrides["max_iterations"] = max_iterations
166171
if optimize:
167172
overrides["optimize_inputs"] = True
173+
if save_prompts is not None:
174+
overrides["save_prompts"] = save_prompts
168175
if output:
169176
overrides["output_dir"] = str(Path(output).parent)
170177
overrides["output_format"] = format
@@ -488,6 +495,11 @@ def plot(
488495
auto: bool = typer.Option(
489496
False, "--auto", help="Let critic loop until satisfied (max 30 iterations)"
490497
),
498+
save_prompts: Optional[bool] = typer.Option(
499+
None,
500+
"--save-prompts/--no-save-prompts",
501+
help="Save formatted prompts into the run directory (for debugging)",
502+
),
491503
):
492504
"""Generate a statistical plot from data."""
493505
if format not in ("png", "jpeg", "webp"):
@@ -527,6 +539,7 @@ def plot(
527539
output_format=format,
528540
optimize_inputs=optimize,
529541
auto_refine=auto,
542+
save_prompts=True if save_prompts is None else save_prompts,
530543
)
531544

532545
gen_input = GenerationInput(

paperbanana/core/config.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -88,6 +88,7 @@ class Settings(BaseSettings):
8888
output_dir: str = "outputs"
8989
output_format: OutputFormat = "png"
9090
save_iterations: bool = True
91+
save_prompts: bool = True
9192

9293
# API Keys (loaded from environment)
9394
google_api_key: Optional[str] = Field(default=None, alias="GOOGLE_API_KEY")
@@ -208,6 +209,7 @@ def _flatten_yaml(config: dict, prefix: str = "") -> dict:
208209
"output.dir": "output_dir",
209210
"output.format": "output_format",
210211
"output.save_iterations": "save_iterations",
212+
"output.save_prompts": "save_prompts",
211213
}
212214

213215
def _recurse(d: dict, prefix: str = "") -> None:

paperbanana/core/pipeline.py

Lines changed: 23 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616
from paperbanana.agents.stylist import StylistAgent
1717
from paperbanana.agents.visualizer import VisualizerAgent
1818
from paperbanana.core.config import Settings
19+
from paperbanana.core.prompt_recorder import PromptRecorder
1920
from paperbanana.core.types import (
2021
DiagramType,
2122
GenerationInput,
@@ -117,6 +118,11 @@ def __init__(
117118
if self.settings.skip_ssl_verification:
118119
_apply_ssl_skip()
119120

121+
# Prompt recorder (writes formatted prompts to outputs/<run_id>/prompts/)
122+
self._prompt_recorder = None
123+
if self.settings.save_prompts:
124+
self._prompt_recorder = PromptRecorder(run_dir_provider=lambda: self._run_dir)
125+
120126
# Initialize providers
121127
if vlm_client is not None:
122128
# Demo mode: use provided clients
@@ -145,19 +151,31 @@ def __init__(
145151

146152
# Initialize agents
147153
prompt_dir = self._find_prompt_dir()
148-
self.optimizer = InputOptimizerAgent(self._vlm, prompt_dir=prompt_dir)
149-
self.retriever = RetrieverAgent(self._vlm, prompt_dir=prompt_dir)
150-
self.planner = PlannerAgent(self._vlm, prompt_dir=prompt_dir)
154+
self.optimizer = InputOptimizerAgent(
155+
self._vlm, prompt_dir=prompt_dir, prompt_recorder=self._prompt_recorder
156+
)
157+
self.retriever = RetrieverAgent(
158+
self._vlm, prompt_dir=prompt_dir, prompt_recorder=self._prompt_recorder
159+
)
160+
self.planner = PlannerAgent(
161+
self._vlm, prompt_dir=prompt_dir, prompt_recorder=self._prompt_recorder
162+
)
151163
self.stylist = StylistAgent(
152-
self._vlm, guidelines=self._methodology_guidelines, prompt_dir=prompt_dir
164+
self._vlm,
165+
guidelines=self._methodology_guidelines,
166+
prompt_dir=prompt_dir,
167+
prompt_recorder=self._prompt_recorder,
153168
)
154169
self.visualizer = VisualizerAgent(
155170
self._image_gen,
156171
self._vlm,
157172
prompt_dir=prompt_dir,
158173
output_dir=str(self._run_dir),
174+
prompt_recorder=self._prompt_recorder,
175+
)
176+
self.critic = CriticAgent(
177+
self._vlm, prompt_dir=prompt_dir, prompt_recorder=self._prompt_recorder
159178
)
160-
self.critic = CriticAgent(self._vlm, prompt_dir=prompt_dir)
161179

162180
logger.info(
163181
"Pipeline initialized",

0 commit comments

Comments
 (0)