2828import json
2929import textwrap
3030
31- from flyte ._image import PythonWheels , DIST_FOLDER
3231from flyteplugins .anthropic import function_tool , run_agent
3332
3433import flyte
3534import flyte .report
35+ from flyte ._image import DIST_FOLDER , PythonWheels
3636
3737# ---------------------------------------------------------------------------
3838# Environments
4242 "python-sandbox" ,
4343 resources = flyte .Resources (cpu = 1 , memory = "1Gi" ),
4444 image = (
45- flyte .Image .from_debian_base (python_version = (3 , 13 ))
46- .with_pip_packages ("numpy" , "pandas" , "scikit-learn" , "matplotlib" )
45+ flyte .Image .from_debian_base (python_version = (3 , 13 )).with_pip_packages (
46+ "numpy" , "pandas" , "scikit-learn" , "matplotlib"
47+ )
4748 ),
4849)
4950
6667# Token-budget constants (character limits, ~4 chars ≈ 1 token)
6768# ---------------------------------------------------------------------------
6869
69- MAX_SEARCH_SNIPPET_CHARS = 400 # per search result content snippet
70- MAX_SEARCH_RESULTS = 3 # default results per web search call
71- MAX_EXEC_OUTPUT_CHARS = 3000 # stdout/stderr cap from execute_python
72- MAX_SUMMARY_CHARS_FOR_CRITIC = 600 # per sub-result when evaluating quality
70+ MAX_SEARCH_SNIPPET_CHARS = 400 # per search result content snippet
71+ MAX_SEARCH_RESULTS = 3 # default results per web search call
72+ MAX_EXEC_OUTPUT_CHARS = 3000 # stdout/stderr cap from execute_python
73+ MAX_SUMMARY_CHARS_FOR_CRITIC = 600 # per sub-result when evaluating quality
7374MAX_SUMMARY_CHARS_FOR_SYNTHESIS = 1500 # per sub-result when synthesizing
7475
7576
@@ -94,6 +95,7 @@ async def web_search(query: str, max_results: int = MAX_SEARCH_RESULTS) -> str:
9495 snippets (each capped to ~400 chars to keep token usage low).
9596 """
9697 import os
98+
9799 from tavily import TavilyClient
98100
99101 client = TavilyClient (api_key = os .environ ["TAVILY_API_KEY" ])
@@ -104,11 +106,13 @@ async def web_search(query: str, max_results: int = MAX_SEARCH_RESULTS) -> str:
104106 )
105107 results = []
106108 for r in response .get ("results" , []):
107- results .append ({
108- "title" : r .get ("title" , "" ),
109- "url" : r .get ("url" , "" ),
110- "content" : _truncate (r .get ("content" , "" ), MAX_SEARCH_SNIPPET_CHARS ),
111- })
109+ results .append (
110+ {
111+ "title" : r .get ("title" , "" ),
112+ "url" : r .get ("url" , "" ),
113+ "content" : _truncate (r .get ("content" , "" ), MAX_SEARCH_SNIPPET_CHARS ),
114+ }
115+ )
112116 # Compact JSON (no indent) to save tokens
113117 return json .dumps (results , separators = ("," , ":" ))
114118
@@ -159,10 +163,7 @@ async def execute_python(code: str) -> str:
159163 fig = plt .figure (num )
160164 w , h = fig .get_size_inches ()
161165 n_axes = len (fig .axes )
162- output += (
163- f"\n [Generated Figure { i } : "
164- f"{ w :.0f} x{ h :.0f} in, { n_axes } axes]"
165- )
166+ output += f"\n [Generated Figure { i } : { w :.0f} x{ h :.0f} in, { n_axes } axes]"
166167 if fig_nums :
167168 plt .close ("all" )
168169 except Exception :
@@ -326,9 +327,7 @@ async def build_report_section(
326327 sources_html = ""
327328 if sources .strip ():
328329 source_items = [
329- f'<li><a href="{ s .strip ()} " target="_blank">{ s .strip ()} </a></li>'
330- for s in sources .split ("," )
331- if s .strip ()
330+ f'<li><a href="{ s .strip ()} " target="_blank">{ s .strip ()} </a></li>' for s in sources .split ("," ) if s .strip ()
332331 ]
333332 sources_html = f"""
334333 <div class="sources">
@@ -413,7 +412,7 @@ async def build_report_section(
413412]
414413
415414
416- @flyte . trace
415+ @agent_env . task
417416async def run_research_sub_agent (question : str , name : str ) -> dict :
418417 """Run a single research sub-agent for one sub-question."""
419418 with flyte .group (f"researcher-{ name } " ):
@@ -677,9 +676,7 @@ async def evaluate_quality(
677676 comprehensibility_score, groundedness_score, critique, follow_up_questions.
678677 """
679678 sub_summaries = "\n \n " .join (
680- f"### { r ['question' ]} \n "
681- f"{ _truncate (r ['summary' ], MAX_SUMMARY_CHARS_FOR_CRITIC )} "
682- for r in sub_results
679+ f"### { r ['question' ]} \n { _truncate (r ['summary' ], MAX_SUMMARY_CHARS_FOR_CRITIC )} " for r in sub_results
683680 )
684681 raw = await run_agent (
685682 prompt = (
@@ -719,7 +716,7 @@ def _build_section_htmls(sub_results: list[dict]) -> list[str]:
719716 content_html = md .markdown (result ["summary" ], extensions = md_extensions )
720717 fallback_section = f"""
721718 <div class="report-section">
722- <h3>{ result [' question' ]} </h3>
719+ <h3>{ result [" question" ]} </h3>
723720 <div class="section-content">
724721 { content_html }
725722 </div>
@@ -746,8 +743,7 @@ async def synthesize_summaries(
746743 if previous_summary :
747744 # Refinement mode
748745 new_summaries = "\n \n " .join (
749- f"### Follow-up: { r ['question' ]} \n "
750- f"{ _truncate (r ['summary' ], MAX_SUMMARY_CHARS_FOR_SYNTHESIS )} "
746+ f"### Follow-up: { r ['question' ]} \n { _truncate (r ['summary' ], MAX_SUMMARY_CHARS_FOR_SYNTHESIS )} "
751747 for r in sub_results
752748 )
753749 prompt = (
@@ -760,8 +756,7 @@ async def synthesize_summaries(
760756 else :
761757 # Initial synthesis
762758 summaries_text = "\n \n " .join (
763- f"### Sub-question: { r ['question' ]} \n "
764- f"{ _truncate (r ['summary' ], MAX_SUMMARY_CHARS_FOR_SYNTHESIS )} "
759+ f"### Sub-question: { r ['question' ]} \n { _truncate (r ['summary' ], MAX_SUMMARY_CHARS_FOR_SYNTHESIS )} "
765760 for r in sub_results
766761 )
767762 prompt = (
@@ -790,10 +785,7 @@ async def decompose_query(query: str) -> list[dict[str, str]]:
790785 Returns a list of dicts with "name" and "question" keys.
791786 """
792787 raw = await run_agent (
793- prompt = (
794- f"Decompose this research query into 3-5 specific "
795- f"sub-questions:\n \n { query } "
796- ),
788+ prompt = (f"Decompose this research query into 3-5 specific sub-questions:\n \n { query } " ),
797789 system = DECOMPOSE_SYSTEM_PROMPT ,
798790 model = "claude-sonnet-4-20250514" ,
799791 max_tokens = 1024 ,
@@ -808,15 +800,19 @@ async def decompose_query(query: str) -> list[dict[str, str]]:
808800 sub_questions : list [dict [str , str ]] = []
809801 for idx , item in enumerate (parsed , start = 1 ):
810802 if isinstance (item , dict ):
811- sub_questions .append ({
812- "name" : item .get ("name" , f"sub-q-{ idx } " ),
813- "question" : item .get ("question" , str (item )),
814- })
803+ sub_questions .append (
804+ {
805+ "name" : item .get ("name" , f"sub-q-{ idx } " ),
806+ "question" : item .get ("question" , str (item )),
807+ }
808+ )
815809 else :
816- sub_questions .append ({
817- "name" : f"sub-q-{ idx } " ,
818- "question" : str (item ),
819- })
810+ sub_questions .append (
811+ {
812+ "name" : f"sub-q-{ idx } " ,
813+ "question" : str (item ),
814+ }
815+ )
820816 except json .JSONDecodeError :
821817 # Fallback: treat the response as a single question
822818 sub_questions = [{"name" : "sub-q-1" , "question" : raw }]
@@ -847,10 +843,7 @@ async def deep_research_agent(query: str, max_refinements: int = 2) -> str:
847843
848844 # --- Step 2: Fan out parallel research sub-agents ---
849845 with flyte .group ("parallel-research" ):
850- tasks = [
851- run_research_sub_agent (sq ["question" ], sq ["name" ])
852- for sq in sub_questions
853- ]
846+ tasks = [run_research_sub_agent (sq ["question" ], sq ["name" ]) for sq in sub_questions ]
854847 sub_results = list (await asyncio .gather (* tasks ))
855848
856849 # --- Step 3: Synthesize into executive summary ---
@@ -864,7 +857,9 @@ async def deep_research_agent(query: str, max_refinements: int = 2) -> str:
864857 for refinement_round in range (max_refinements ):
865858 with flyte .group (f"evaluate-{ refinement_round } " ):
866859 evaluation = await evaluate_quality (
867- query , executive_summary , latest_sub_results ,
860+ query ,
861+ executive_summary ,
862+ latest_sub_results ,
868863 )
869864
870865 print (
@@ -886,10 +881,7 @@ async def deep_research_agent(query: str, max_refinements: int = 2) -> str:
886881 print ("Critic found issues but no follow-up questions — finalizing." )
887882 break
888883
889- print (
890- f"Quality check failed (round { refinement_round + 1 } ). "
891- f"Critique: { critique } "
892- )
884+ print (f"Quality check failed (round { refinement_round + 1 } ). Critique: { critique } " )
893885 print (f"Follow-up questions: { follow_ups } " )
894886
895887 # --- Step 4a: Targeted follow-up research ---
@@ -917,10 +909,7 @@ async def deep_research_agent(query: str, max_refinements: int = 2) -> str:
917909 critique = critique ,
918910 )
919911 else :
920- print (
921- f"Reached maximum refinement iterations ({ max_refinements } ). "
922- f"Finalizing with best available output."
923- )
912+ print (f"Reached maximum refinement iterations ({ max_refinements } ). Finalizing with best available output." )
924913
925914 # --- Step 5: Build and publish final HTML report ---
926915 section_htmls = _build_section_htmls (sub_results )
0 commit comments