Skip to content

Commit a5b7abe

Browse files
committed
feat: enhance visualization functions with optional type hints and improved handling of agents and handoffs
1 parent 9f7d596 commit a5b7abe

File tree

1 file changed

+26
-13
lines changed

1 file changed

+26
-13
lines changed

src/agents/extensions/visualization.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import graphviz
22

33
from agents import Agent
4+
from agents.handoffs import Handoff
45

56

67
def get_main_graph(agent: Agent) -> str:
@@ -29,7 +30,9 @@ def get_main_graph(agent: Agent) -> str:
2930
return "".join(parts)
3031

3132

32-
def get_all_nodes(agent: Agent, parent: Agent = None) -> str:
33+
from typing import Optional
34+
35+
def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str:
3336
"""
3437
Recursively generates the nodes for the given agent and its handoffs in DOT format.
3538
@@ -47,25 +50,29 @@ def get_all_nodes(agent: Agent, parent: Agent = None) -> str:
4750
"fillcolor=lightyellow, width=1.5, height=0.8];"
4851
)
4952

50-
# Smaller tools (ellipse, green)
5153
for tool in agent.tools:
5254
parts.append(
5355
f'"{tool.name}" [label="{tool.name}", shape=ellipse, style=filled, '
5456
f"fillcolor=lightgreen, width=0.5, height=0.3];"
5557
)
5658

57-
# Bigger handoffs (rounded box, yellow)
5859
for handoff in agent.handoffs:
59-
parts.append(
60-
f'"{handoff.name}" [label="{handoff.name}", shape=box, style=filled, style=rounded, '
61-
f"fillcolor=lightyellow, width=1.5, height=0.8];"
62-
)
63-
parts.append(get_all_nodes(handoff))
60+
if isinstance(handoff, Handoff):
61+
parts.append(
62+
f'"{handoff.agent_name}" [label="{handoff.agent_name}", shape=box, style=filled, style=rounded, '
63+
f"fillcolor=lightyellow, width=1.5, height=0.8];"
64+
)
65+
if isinstance(handoff, Agent):
66+
parts.append(
67+
f'"{handoff.name}" [label="{handoff.name}", shape=box, style=filled, style=rounded, '
68+
f"fillcolor=lightyellow, width=1.5, height=0.8];"
69+
)
70+
parts.append(get_all_nodes(handoff))
6471

6572
return "".join(parts)
6673

6774

68-
def get_all_edges(agent: Agent, parent: Agent = None) -> str:
75+
def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str:
6976
"""
7077
Recursively generates the edges for the given agent and its handoffs in DOT format.
7178
@@ -92,14 +99,20 @@ def get_all_edges(agent: Agent, parent: Agent = None) -> str:
9299
"{agent.name}" -> "__end__";""")
93100

94101
for handoff in agent.handoffs:
95-
parts.append(f"""
96-
"{agent.name}" -> "{handoff.name}";""")
97-
parts.append(get_all_edges(handoff, agent))
102+
if isinstance(handoff, Handoff):
103+
parts.append(f"""
104+
"{agent.name}" -> "{handoff.agent_name}";""")
105+
if isinstance(handoff, Agent):
106+
parts.append(f"""
107+
"{agent.name}" -> "{handoff.name}";""")
108+
parts.append(get_all_edges(handoff, agent))
98109

99110
return "".join(parts)
100111

101112

102-
def draw_graph(agent: Agent, filename: str = None) -> graphviz.Source:
113+
from typing import Optional
114+
115+
def draw_graph(agent: Agent, filename: Optional[str] = None) -> graphviz.Source:
103116
"""
104117
Draws the graph for the given agent and optionally saves it as a PNG file.
105118

0 commit comments

Comments
 (0)