Skip to content

Commit fd328d2

Browse files
committed
lint
Signed-off-by: Niels Bantilan <niels.bantilan@gmail.com>
1 parent 903d67c commit fd328d2

File tree

2 files changed

+46
-56
lines changed

2 files changed

+46
-56
lines changed

examples/genai/anthropic_deep_research_agent.py

Lines changed: 42 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,11 @@
2828
import json
2929
import textwrap
3030

31-
from flyte._image import PythonWheels, DIST_FOLDER
3231
from flyteplugins.anthropic import function_tool, run_agent
3332

3433
import flyte
3534
import flyte.report
35+
from flyte._image import DIST_FOLDER, PythonWheels
3636

3737
# ---------------------------------------------------------------------------
3838
# Environments
@@ -42,8 +42,9 @@
4242
"python-sandbox",
4343
resources=flyte.Resources(cpu=1, memory="1Gi"),
4444
image=(
45-
flyte.Image.from_debian_base(python_version=(3, 13))
46-
.with_pip_packages("numpy", "pandas", "scikit-learn", "matplotlib")
45+
flyte.Image.from_debian_base(python_version=(3, 13)).with_pip_packages(
46+
"numpy", "pandas", "scikit-learn", "matplotlib"
47+
)
4748
),
4849
)
4950

@@ -66,10 +67,10 @@
6667
# Token-budget constants (character limits, ~4 chars ≈ 1 token)
6768
# ---------------------------------------------------------------------------
6869

69-
MAX_SEARCH_SNIPPET_CHARS = 400 # per search result content snippet
70-
MAX_SEARCH_RESULTS = 3 # default results per web search call
71-
MAX_EXEC_OUTPUT_CHARS = 3000 # stdout/stderr cap from execute_python
72-
MAX_SUMMARY_CHARS_FOR_CRITIC = 600 # per sub-result when evaluating quality
70+
MAX_SEARCH_SNIPPET_CHARS = 400 # per search result content snippet
71+
MAX_SEARCH_RESULTS = 3 # default results per web search call
72+
MAX_EXEC_OUTPUT_CHARS = 3000 # stdout/stderr cap from execute_python
73+
MAX_SUMMARY_CHARS_FOR_CRITIC = 600 # per sub-result when evaluating quality
7374
MAX_SUMMARY_CHARS_FOR_SYNTHESIS = 1500 # per sub-result when synthesizing
7475

7576

@@ -94,6 +95,7 @@ async def web_search(query: str, max_results: int = MAX_SEARCH_RESULTS) -> str:
9495
snippets (each capped to ~400 chars to keep token usage low).
9596
"""
9697
import os
98+
9799
from tavily import TavilyClient
98100

99101
client = TavilyClient(api_key=os.environ["TAVILY_API_KEY"])
@@ -104,11 +106,13 @@ async def web_search(query: str, max_results: int = MAX_SEARCH_RESULTS) -> str:
104106
)
105107
results = []
106108
for r in response.get("results", []):
107-
results.append({
108-
"title": r.get("title", ""),
109-
"url": r.get("url", ""),
110-
"content": _truncate(r.get("content", ""), MAX_SEARCH_SNIPPET_CHARS),
111-
})
109+
results.append(
110+
{
111+
"title": r.get("title", ""),
112+
"url": r.get("url", ""),
113+
"content": _truncate(r.get("content", ""), MAX_SEARCH_SNIPPET_CHARS),
114+
}
115+
)
112116
# Compact JSON (no indent) to save tokens
113117
return json.dumps(results, separators=(",", ":"))
114118

@@ -159,10 +163,7 @@ async def execute_python(code: str) -> str:
159163
fig = plt.figure(num)
160164
w, h = fig.get_size_inches()
161165
n_axes = len(fig.axes)
162-
output += (
163-
f"\n[Generated Figure {i}: "
164-
f"{w:.0f}x{h:.0f} in, {n_axes} axes]"
165-
)
166+
output += f"\n[Generated Figure {i}: {w:.0f}x{h:.0f} in, {n_axes} axes]"
166167
if fig_nums:
167168
plt.close("all")
168169
except Exception:
@@ -326,9 +327,7 @@ async def build_report_section(
326327
sources_html = ""
327328
if sources.strip():
328329
source_items = [
329-
f'<li><a href="{s.strip()}" target="_blank">{s.strip()}</a></li>'
330-
for s in sources.split(",")
331-
if s.strip()
330+
f'<li><a href="{s.strip()}" target="_blank">{s.strip()}</a></li>' for s in sources.split(",") if s.strip()
332331
]
333332
sources_html = f"""
334333
<div class="sources">
@@ -413,7 +412,7 @@ async def build_report_section(
413412
]
414413

415414

416-
@flyte.trace
415+
@agent_env.task
417416
async def run_research_sub_agent(question: str, name: str) -> dict:
418417
"""Run a single research sub-agent for one sub-question."""
419418
with flyte.group(f"researcher-{name}"):
@@ -677,9 +676,7 @@ async def evaluate_quality(
677676
comprehensibility_score, groundedness_score, critique, follow_up_questions.
678677
"""
679678
sub_summaries = "\n\n".join(
680-
f"### {r['question']}\n"
681-
f"{_truncate(r['summary'], MAX_SUMMARY_CHARS_FOR_CRITIC)}"
682-
for r in sub_results
679+
f"### {r['question']}\n{_truncate(r['summary'], MAX_SUMMARY_CHARS_FOR_CRITIC)}" for r in sub_results
683680
)
684681
raw = await run_agent(
685682
prompt=(
@@ -719,7 +716,7 @@ def _build_section_htmls(sub_results: list[dict]) -> list[str]:
719716
content_html = md.markdown(result["summary"], extensions=md_extensions)
720717
fallback_section = f"""
721718
<div class="report-section">
722-
<h3>{result['question']}</h3>
719+
<h3>{result["question"]}</h3>
723720
<div class="section-content">
724721
{content_html}
725722
</div>
@@ -746,8 +743,7 @@ async def synthesize_summaries(
746743
if previous_summary:
747744
# Refinement mode
748745
new_summaries = "\n\n".join(
749-
f"### Follow-up: {r['question']}\n"
750-
f"{_truncate(r['summary'], MAX_SUMMARY_CHARS_FOR_SYNTHESIS)}"
746+
f"### Follow-up: {r['question']}\n{_truncate(r['summary'], MAX_SUMMARY_CHARS_FOR_SYNTHESIS)}"
751747
for r in sub_results
752748
)
753749
prompt = (
@@ -760,8 +756,7 @@ async def synthesize_summaries(
760756
else:
761757
# Initial synthesis
762758
summaries_text = "\n\n".join(
763-
f"### Sub-question: {r['question']}\n"
764-
f"{_truncate(r['summary'], MAX_SUMMARY_CHARS_FOR_SYNTHESIS)}"
759+
f"### Sub-question: {r['question']}\n{_truncate(r['summary'], MAX_SUMMARY_CHARS_FOR_SYNTHESIS)}"
765760
for r in sub_results
766761
)
767762
prompt = (
@@ -790,10 +785,7 @@ async def decompose_query(query: str) -> list[dict[str, str]]:
790785
Returns a list of dicts with "name" and "question" keys.
791786
"""
792787
raw = await run_agent(
793-
prompt=(
794-
f"Decompose this research query into 3-5 specific "
795-
f"sub-questions:\n\n{query}"
796-
),
788+
prompt=(f"Decompose this research query into 3-5 specific sub-questions:\n\n{query}"),
797789
system=DECOMPOSE_SYSTEM_PROMPT,
798790
model="claude-sonnet-4-20250514",
799791
max_tokens=1024,
@@ -808,15 +800,19 @@ async def decompose_query(query: str) -> list[dict[str, str]]:
808800
sub_questions: list[dict[str, str]] = []
809801
for idx, item in enumerate(parsed, start=1):
810802
if isinstance(item, dict):
811-
sub_questions.append({
812-
"name": item.get("name", f"sub-q-{idx}"),
813-
"question": item.get("question", str(item)),
814-
})
803+
sub_questions.append(
804+
{
805+
"name": item.get("name", f"sub-q-{idx}"),
806+
"question": item.get("question", str(item)),
807+
}
808+
)
815809
else:
816-
sub_questions.append({
817-
"name": f"sub-q-{idx}",
818-
"question": str(item),
819-
})
810+
sub_questions.append(
811+
{
812+
"name": f"sub-q-{idx}",
813+
"question": str(item),
814+
}
815+
)
820816
except json.JSONDecodeError:
821817
# Fallback: treat the response as a single question
822818
sub_questions = [{"name": "sub-q-1", "question": raw}]
@@ -847,10 +843,7 @@ async def deep_research_agent(query: str, max_refinements: int = 2) -> str:
847843

848844
# --- Step 2: Fan out parallel research sub-agents ---
849845
with flyte.group("parallel-research"):
850-
tasks = [
851-
run_research_sub_agent(sq["question"], sq["name"])
852-
for sq in sub_questions
853-
]
846+
tasks = [run_research_sub_agent(sq["question"], sq["name"]) for sq in sub_questions]
854847
sub_results = list(await asyncio.gather(*tasks))
855848

856849
# --- Step 3: Synthesize into executive summary ---
@@ -864,7 +857,9 @@ async def deep_research_agent(query: str, max_refinements: int = 2) -> str:
864857
for refinement_round in range(max_refinements):
865858
with flyte.group(f"evaluate-{refinement_round}"):
866859
evaluation = await evaluate_quality(
867-
query, executive_summary, latest_sub_results,
860+
query,
861+
executive_summary,
862+
latest_sub_results,
868863
)
869864

870865
print(
@@ -886,10 +881,7 @@ async def deep_research_agent(query: str, max_refinements: int = 2) -> str:
886881
print("Critic found issues but no follow-up questions — finalizing.")
887882
break
888883

889-
print(
890-
f"Quality check failed (round {refinement_round + 1}). "
891-
f"Critique: {critique}"
892-
)
884+
print(f"Quality check failed (round {refinement_round + 1}). Critique: {critique}")
893885
print(f"Follow-up questions: {follow_ups}")
894886

895887
# --- Step 4a: Targeted follow-up research ---
@@ -917,10 +909,7 @@ async def deep_research_agent(query: str, max_refinements: int = 2) -> str:
917909
critique=critique,
918910
)
919911
else:
920-
print(
921-
f"Reached maximum refinement iterations ({max_refinements}). "
922-
f"Finalizing with best available output."
923-
)
912+
print(f"Reached maximum refinement iterations ({max_refinements}). Finalizing with best available output.")
924913

925914
# --- Step 5: Build and publish final HTML report ---
926915
section_htmls = _build_section_htmls(sub_results)

examples/genai/anthropic_pbj_agent.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,19 @@
88
import asyncio
99
from typing import Optional
1010

11-
from flyte._image import PythonWheels, DIST_FOLDER
1211
from flyteplugins.anthropic import function_tool, run_agent
1312

1413
import flyte
14+
from flyte._image import DIST_FOLDER, PythonWheels
1515

1616
agent_env = flyte.TaskEnvironment(
1717
"anthropic-agent",
1818
resources=flyte.Resources(cpu=1),
1919
secrets=[flyte.Secret(key="niels-anthropic-api-key", as_env_var="ANTHROPIC_API_KEY")],
2020
image=(
21-
flyte.Image.from_debian_base(python_version=(3, 13))
22-
.clone(addl_layer=PythonWheels(wheel_dir=DIST_FOLDER, package_name="flyteplugins-anthropic", pre=True))
21+
flyte.Image.from_debian_base(python_version=(3, 13)).clone(
22+
addl_layer=PythonWheels(wheel_dir=DIST_FOLDER, package_name="flyteplugins-anthropic", pre=True)
23+
)
2324
),
2425
)
2526

0 commit comments

Comments
 (0)