@@ -92,6 +92,41 @@ def generate(
9292 "--auto-download-data" ,
9393 help = "Auto-download expanded reference set (~257MB) on first run if not cached" ,
9494 ),
95+ exemplar_retrieval : bool = typer .Option (
96+ False ,
97+ "--exemplar-retrieval" ,
98+ help = "Enable external exemplar retrieval before planning" ,
99+ ),
100+ exemplar_endpoint : Optional [str ] = typer .Option (
101+ None ,
102+ "--exemplar-endpoint" ,
103+ help = "External exemplar retrieval endpoint URL" ,
104+ ),
105+ exemplar_mode : Optional [str ] = typer .Option (
106+ None ,
107+ "--exemplar-mode" ,
108+ help = "Exemplar retrieval mode: external_then_rerank or external_only" ,
109+ ),
110+ exemplar_top_k : Optional [int ] = typer .Option (
111+ None ,
112+ "--exemplar-top-k" ,
113+ help = "Top-k exemplars requested from external retriever" ,
114+ ),
115+ exemplar_timeout : Optional [float ] = typer .Option (
116+ None ,
117+ "--exemplar-timeout" ,
118+ help = "External exemplar retrieval timeout (seconds)" ,
119+ ),
120+ exemplar_retries : Optional [int ] = typer .Option (
121+ None ,
122+ "--exemplar-retries" ,
123+ help = "Retry attempts for external exemplar retrieval on transient errors" ,
124+ ),
125+ seed : Optional [int ] = typer .Option (
126+ None ,
127+ "--seed" ,
128+ help = "Random seed for reproducible image generation" ,
129+ ),
95130 verbose : bool = typer .Option (
96131 False , "--verbose" , "-v" , help = "Show detailed agent progress and timing"
97132 ),
@@ -104,6 +139,11 @@ def generate(
104139 if feedback and not continue_run and not continue_last :
105140 console .print ("[red]Error: --feedback requires --continue or --continue-run[/red]" )
106141 raise typer .Exit (1 )
142+ if exemplar_mode and exemplar_mode not in ("external_then_rerank" , "external_only" ):
143+ console .print (
144+ "[red]Error: --exemplar-mode must be external_then_rerank or external_only[/red]"
145+ )
146+ raise typer .Exit (1 )
107147
108148 configure_logging (verbose = verbose )
109149
@@ -128,6 +168,20 @@ def generate(
128168 if output :
129169 overrides ["output_dir" ] = str (Path (output ).parent )
130170 overrides ["output_format" ] = format
171+ if exemplar_retrieval :
172+ overrides ["exemplar_retrieval_enabled" ] = True
173+ if exemplar_endpoint :
174+ overrides ["exemplar_retrieval_endpoint" ] = exemplar_endpoint
175+ if exemplar_mode :
176+ overrides ["exemplar_retrieval_mode" ] = exemplar_mode
177+ if exemplar_top_k is not None :
178+ overrides ["exemplar_retrieval_top_k" ] = exemplar_top_k
179+ if exemplar_timeout is not None :
180+ overrides ["exemplar_retrieval_timeout_seconds" ] = exemplar_timeout
181+ if exemplar_retries is not None :
182+ overrides ["exemplar_retrieval_max_retries" ] = exemplar_retries
183+ if seed is not None :
184+ overrides ["seed" ] = seed
131185
132186 if config :
133187 settings = Settings .from_yaml (config , ** overrides )
@@ -615,6 +669,155 @@ async def _run():
615669 console .print (f"\n [bold]{ dim } [/bold]: { result .reasoning } " )
616670
617671
@app.command("ablate-retrieval")
def ablate_retrieval(
    input: str = typer.Option(..., "--input", "-i", help="Path to methodology text file"),
    caption: str = typer.Option(
        ..., "--caption", "-c", help="Figure caption / communicative intent"
    ),
    exemplar_endpoint: str = typer.Option(
        ..., "--exemplar-endpoint", help="External exemplar retrieval endpoint URL"
    ),
    top_k: str = typer.Option(
        "1,3,5", "--top-k", help="Comma-separated top-k values (e.g., 1,3,5)"
    ),
    seed: Optional[int] = typer.Option(
        None,
        "--seed",
        help="Random seed used for all variants (default: 42 if omitted)",
    ),
    exemplar_retries: Optional[int] = typer.Option(
        None,
        "--exemplar-retries",
        help="Retry attempts for external exemplar retrieval on transient errors",
    ),
    reference: Optional[str] = typer.Option(
        None,
        "--reference",
        "-r",
        help="Optional human reference image for judge-based preference proxy",
    ),
    output_report: Optional[str] = typer.Option(
        None,
        "--output-report",
        "-o",
        help="Output JSON report path (default: outputs/retrieval_ablation_<runid>.json)",
    ),
    config: Optional[str] = typer.Option(None, "--config", help="Path to config YAML file"),
    vlm_provider: Optional[str] = typer.Option(
        None, "--vlm-provider", help="VLM provider override for generation and judge"
    ),
    image_provider: Optional[str] = typer.Option(
        None, "--image-provider", help="Image generation provider override"
    ),
    verbose: bool = typer.Option(
        False, "--verbose", "-v", help="Show detailed agent progress and timing"
    ),
) -> None:
    """Run baseline vs retrieval ablation (k sweep) and save a JSON report."""
    # NOTE: Typer renders the docstring above as this command's --help text, so
    # it is kept to one line; per-option documentation lives in the help=
    # strings of the signature.
    #
    # Flow: validate paths -> parse the top-k sweep -> build Settings from the
    # CLI overrides (layered on the YAML config when --config is given) -> run
    # the ablation -> save the JSON report -> print a summary panel.
    #
    # Exits with status 1 when the input or reference path does not exist, the
    # input file cannot be read as UTF-8 text, or --top-k cannot be parsed.
    configure_logging(verbose=verbose)

    input_path = Path(input)
    if not input_path.exists():
        console.print(f"[red]Error: Input file not found: {input}[/red]")
        raise typer.Exit(1)

    reference_path: Optional[Path] = None
    if reference:
        reference_path = Path(reference)
        if not reference_path.exists():
            console.print(f"[red]Error: Reference image not found: {reference}[/red]")
            raise typer.Exit(1)

    # Imported lazily, after the cheap path validation, so a bad invocation
    # fails before the heavier project modules are loaded.
    from dotenv import load_dotenv

    load_dotenv()

    from paperbanana.core.types import DiagramType, GenerationInput
    from paperbanana.core.utils import generate_run_id
    from paperbanana.evaluation.retrieval_ablation import (
        RetrievalAblationRunner,
        parse_top_k_values,
    )

    try:
        k_values = parse_top_k_values(top_k)
    except ValueError as e:
        console.print(f"[red]Error: {e}[/red]")
        raise typer.Exit(1) from e

    # Retrieval is always enabled for this command; the remaining overrides
    # are only set when the user supplied them, so config/default values
    # are otherwise left alone.
    overrides = {
        "exemplar_retrieval_endpoint": exemplar_endpoint,
        "exemplar_retrieval_enabled": True,
    }
    if vlm_provider:
        overrides["vlm_provider"] = vlm_provider
    if image_provider:
        overrides["image_provider"] = image_provider
    if seed is not None:
        overrides["seed"] = seed
    if exemplar_retries is not None:
        overrides["exemplar_retrieval_max_retries"] = exemplar_retries

    if config:
        settings = Settings.from_yaml(config, **overrides)
    else:
        settings = Settings(**overrides)

    # exists() above also passes for directories and unreadable or non-UTF-8
    # files; report those in the CLI's usual error style instead of letting a
    # raw traceback escape.
    try:
        source_context = input_path.read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError) as e:
        console.print(f"[red]Error: Could not read input file {input}: {e}[/red]")
        raise typer.Exit(1) from e

    gen_input = GenerationInput(
        source_context=source_context,
        communicative_intent=caption,
        diagram_type=DiagramType.METHODOLOGY,
    )

    runner = RetrievalAblationRunner(
        settings,
        reference_image_path=str(reference_path) if reference_path else None,
    )

    async def _run():
        return await runner.run(gen_input, top_k_values=k_values)

    # The fallback 42 mirrors the default advertised in the --seed help text.
    # NOTE(review): presumably the runner applies that same default when
    # settings.seed is None -- confirm against RetrievalAblationRunner.
    console.print(
        Panel.fit(
            f"[bold]PaperBanana[/bold] - Retrieval Ablation\n\n"
            f"Top-k sweep: {k_values}\n"
            f"Endpoint: {exemplar_endpoint}\n"
            f"Seed: {settings.seed if settings.seed is not None else 42}\n"
            f"Reference: {reference_path if reference_path else 'none'}",
            border_style="magenta",
        )
    )

    report = asyncio.run(_run())

    default_report_path = Path(settings.output_dir) / f"retrieval_ablation_{generate_run_id()}.json"
    report_path = Path(output_report) if output_report else default_report_path
    saved_path = runner.save_report(report, report_path)

    summary = report.summary
    # The human-preference line is only shown when the summary carries that
    # entry; otherwise it collapses to an empty string.
    human_pref_line = ""
    if summary.get("best_human_preference_variant") is not None:
        human_pref_line = (
            f"Best human preference: {summary.get('best_human_preference_variant')} "
            f"({summary.get('best_human_preference_score')})\n"
        )
    console.print(
        Panel.fit(
            "[bold]Ablation Summary[/bold]\n\n"
            f"Best alignment: {summary.get('best_alignment_variant')} "
            f"({summary.get('best_alignment_score')})\n"
            f"{human_pref_line}"
            f"Fastest: {summary.get('fastest_variant')} "
            f"({summary.get('fastest_total_seconds')}s)\n"
            f"Fewest iterations: {summary.get('fewest_iterations_variant')} "
            f"({summary.get('fewest_iterations')})\n\n"
            f"Report: [bold]{saved_path}[/bold]",
            border_style="cyan",
        )
    )
819+
820+
618821# ── Data subcommands ──────────────────────────────────────────────
619822
620823
0 commit comments