Skip to content

Commit 5ad53d8

Browse files
committed
Add start and end nodes to graph visualization and update edge generation
1 parent 29103ca commit 5ad53d8

File tree

2 files changed

+48
-2
lines changed

2 files changed

+48
-2
lines changed

src/agents/extensions/visualization.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from agents import Agent
66
from agents.handoffs import Handoff
7+
from agents.tool import Tool
78

89

910
def get_main_graph(agent: Agent) -> str:
@@ -41,6 +42,14 @@ def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str:
4142
str: The DOT format string representing the nodes.
4243
"""
4344
parts = []
45+
46+
# Start and end the graph
47+
parts.append(
48+
f'"__start__" [label="__start__", shape=ellipse, style=filled, '
49+
f"fillcolor=lightblue, width=0.5, height=0.3];"
50+
f'"__end__" [label="__end__", shape=ellipse, style=filled, '
51+
f"fillcolor=lightblue, width=0.5, height=0.3];"
52+
)
4453
# Ensure parent agent node is colored
4554
if not parent:
4655
parts.append(
@@ -72,7 +81,7 @@ def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str:
7281
return "".join(parts)
7382

7483

75-
def get_all_edges(agent: Agent) -> str:
84+
def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str:
7685
"""
7786
Recursively generates the edges for the given agent and its handoffs in DOT format.
7887
@@ -85,6 +94,11 @@ def get_all_edges(agent: Agent) -> str:
8594
"""
8695
parts = []
8796

97+
if not parent:
98+
parts.append(
99+
f'"__start__" -> "{agent.name}";'
100+
)
101+
88102
for tool in agent.tools:
89103
parts.append(f"""
90104
"{agent.name}" -> "{tool.name}" [style=dotted, penwidth=1.5];
@@ -97,7 +111,12 @@ def get_all_edges(agent: Agent) -> str:
97111
if isinstance(handoff, Agent):
98112
parts.append(f"""
99113
"{agent.name}" -> "{handoff.name}";""")
100-
parts.append(get_all_edges(handoff))
114+
parts.append(get_all_edges(handoff, agent))
115+
116+
if not agent.handoffs and not isinstance(agent, Tool):
117+
parts.append(
118+
f'"{agent.name}" -> "__end__";'
119+
)
101120

102121
return "".join(parts)
103122

tests/test_visualization.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,14 @@ def test_get_main_graph(mock_agent):
3838
assert "graph [splines=true];" in result
3939
assert 'node [fontname="Arial"];' in result
4040
assert "edge [penwidth=1.5];" in result
41+
assert (
42+
'"__start__" [label="__start__", shape=ellipse, style=filled, '
43+
"fillcolor=lightblue, width=0.5, height=0.3];" in result
44+
)
45+
assert (
46+
'"__end__" [label="__end__", shape=ellipse, style=filled, '
47+
"fillcolor=lightblue, width=0.5, height=0.3];" in result
48+
)
4149
assert (
4250
'"Agent1" [label="Agent1", shape=box, style=filled, '
4351
"fillcolor=lightyellow, width=1.5, height=0.8];" in result
@@ -58,6 +66,14 @@ def test_get_main_graph(mock_agent):
5866

5967
def test_get_all_nodes(mock_agent):
6068
result = get_all_nodes(mock_agent)
69+
assert (
70+
'"__start__" [label="__start__", shape=ellipse, style=filled, '
71+
"fillcolor=lightblue, width=0.5, height=0.3];" in result
72+
)
73+
assert (
74+
'"__end__" [label="__end__", shape=ellipse, style=filled, '
75+
"fillcolor=lightblue, width=0.5, height=0.3];" in result
76+
)
6177
assert (
6278
'"Agent1" [label="Agent1", shape=box, style=filled, '
6379
"fillcolor=lightyellow, width=1.5, height=0.8];" in result
@@ -78,20 +94,31 @@ def test_get_all_nodes(mock_agent):
7894

7995
def test_get_all_edges(mock_agent):
8096
result = get_all_edges(mock_agent)
97+
assert '"__start__" -> "Agent1";' in result
98+
assert '"Agent1" -> "__end__";'
8199
assert '"Agent1" -> "Tool1" [style=dotted, penwidth=1.5];' in result
82100
assert '"Tool1" -> "Agent1" [style=dotted, penwidth=1.5];' in result
83101
assert '"Agent1" -> "Tool2" [style=dotted, penwidth=1.5];' in result
84102
assert '"Tool2" -> "Agent1" [style=dotted, penwidth=1.5];' in result
85103
assert '"Agent1" -> "Handoff1";' in result
86104

87105

106+
88107
def test_draw_graph(mock_agent):
89108
graph = draw_graph(mock_agent)
90109
assert isinstance(graph, graphviz.Source)
91110
assert "digraph G" in graph.source
92111
assert "graph [splines=true];" in graph.source
93112
assert 'node [fontname="Arial"];' in graph.source
94113
assert "edge [penwidth=1.5];" in graph.source
114+
assert (
115+
'"__start__" [label="__start__", shape=ellipse, style=filled, '
116+
"fillcolor=lightblue, width=0.5, height=0.3];" in graph.source
117+
)
118+
assert (
119+
'"__end__" [label="__end__", shape=ellipse, style=filled, '
120+
"fillcolor=lightblue, width=0.5, height=0.3];" in graph.source
121+
)
95122
assert (
96123
'"Agent1" [label="Agent1", shape=box, style=filled, '
97124
"fillcolor=lightyellow, width=1.5, height=0.8];" in graph.source

0 commit comments

Comments
 (0)