Skip to content

Commit 623063b

Browse files
committed
refactor: clean up visualization functions by removing unused nodes and improving type hints
1 parent a5b7abe commit 623063b

File tree

2 files changed

+10
-26
lines changed

2 files changed

+10
-26
lines changed

src/agents/extensions/visualization.py

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Optional
2+
13
import graphviz
24

35
from agents import Agent
@@ -20,8 +22,6 @@ def get_main_graph(agent: Agent) -> str:
2022
graph [splines=true];
2123
node [fontname="Arial"];
2224
edge [penwidth=1.5];
23-
"__start__" [shape=ellipse, style=filled, fillcolor=lightblue];
24-
"__end__" [shape=ellipse, style=filled, fillcolor=lightblue];
2525
"""
2626
]
2727
parts.append(get_all_nodes(agent))
@@ -30,8 +30,6 @@ def get_main_graph(agent: Agent) -> str:
3030
return "".join(parts)
3131

3232

33-
from typing import Optional
34-
3533
def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str:
3634
"""
3735
Recursively generates the nodes for the given agent and its handoffs in DOT format.
@@ -59,12 +57,14 @@ def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str:
5957
for handoff in agent.handoffs:
6058
if isinstance(handoff, Handoff):
6159
parts.append(
62-
f'"{handoff.agent_name}" [label="{handoff.agent_name}", shape=box, style=filled, style=rounded, '
60+
f'"{handoff.agent_name}" [label="{handoff.agent_name}", shape=box, '
61+
f"shape=box, style=filled, style=rounded, "
6362
f"fillcolor=lightyellow, width=1.5, height=0.8];"
6463
)
6564
if isinstance(handoff, Agent):
6665
parts.append(
67-
f'"{handoff.name}" [label="{handoff.name}", shape=box, style=filled, style=rounded, '
66+
f'"{handoff.name}" [label="{handoff.name}", '
67+
f"shape=box, style=filled, style=rounded, "
6868
f"fillcolor=lightyellow, width=1.5, height=0.8];"
6969
)
7070
parts.append(get_all_nodes(handoff))
@@ -85,19 +85,11 @@ def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str:
8585
"""
8686
parts = []
8787

88-
if not parent:
89-
parts.append(f"""
90-
"__start__" -> "{agent.name}";""")
91-
9288
for tool in agent.tools:
9389
parts.append(f"""
9490
"{agent.name}" -> "{tool.name}" [style=dotted, penwidth=1.5];
9591
"{tool.name}" -> "{agent.name}" [style=dotted, penwidth=1.5];""")
9692

97-
if not agent.handoffs:
98-
parts.append(f"""
99-
"{agent.name}" -> "__end__";""")
100-
10193
for handoff in agent.handoffs:
10294
if isinstance(handoff, Handoff):
10395
parts.append(f"""
@@ -110,8 +102,6 @@ def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str:
110102
return "".join(parts)
111103

112104

113-
from typing import Optional
114-
115105
def draw_graph(agent: Agent, filename: Optional[str] = None) -> graphviz.Source:
116106
"""
117107
Draws the graph for the given agent and optionally saves it as a PNG file.

tests/test_visualization.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
get_all_nodes,
1111
get_main_graph,
1212
)
13+
from agents.handoffs import Handoff
1314

1415

1516
@pytest.fixture
@@ -19,10 +20,8 @@ def mock_agent():
1920
tool2 = Mock()
2021
tool2.name = "Tool2"
2122

22-
handoff1 = Mock()
23-
handoff1.name = "Handoff1"
24-
handoff1.tools = []
25-
handoff1.handoffs = []
23+
handoff1 = Mock(spec=Handoff)
24+
handoff1.agent_name = "Handoff1"
2625

2726
agent = Mock(spec=Agent)
2827
agent.name = "Agent1"
@@ -34,12 +33,11 @@ def mock_agent():
3433

3534
def test_get_main_graph(mock_agent):
3635
result = get_main_graph(mock_agent)
36+
print(result)
3737
assert "digraph G" in result
3838
assert "graph [splines=true];" in result
3939
assert 'node [fontname="Arial"];' in result
4040
assert "edge [penwidth=1.5];" in result
41-
assert '"__start__" [shape=ellipse, style=filled, fillcolor=lightblue];' in result
42-
assert '"__end__" [shape=ellipse, style=filled, fillcolor=lightblue];' in result
4341
assert (
4442
'"Agent1" [label="Agent1", shape=box, style=filled, '
4543
"fillcolor=lightyellow, width=1.5, height=0.8];" in result
@@ -80,13 +78,11 @@ def test_get_all_nodes(mock_agent):
8078

8179
def test_get_all_edges(mock_agent):
8280
result = get_all_edges(mock_agent)
83-
assert '"__start__" -> "Agent1";' in result
8481
assert '"Agent1" -> "Tool1" [style=dotted, penwidth=1.5];' in result
8582
assert '"Tool1" -> "Agent1" [style=dotted, penwidth=1.5];' in result
8683
assert '"Agent1" -> "Tool2" [style=dotted, penwidth=1.5];' in result
8784
assert '"Tool2" -> "Agent1" [style=dotted, penwidth=1.5];' in result
8885
assert '"Agent1" -> "Handoff1";' in result
89-
assert '"Handoff1" -> "__end__";' in result
9086

9187

9288
def test_draw_graph(mock_agent):
@@ -96,8 +92,6 @@ def test_draw_graph(mock_agent):
9692
assert "graph [splines=true];" in graph.source
9793
assert 'node [fontname="Arial"];' in graph.source
9894
assert "edge [penwidth=1.5];" in graph.source
99-
assert '"__start__" [shape=ellipse, style=filled, fillcolor=lightblue];' in graph.source
100-
assert '"__end__" [shape=ellipse, style=filled, fillcolor=lightblue];' in graph.source
10195
assert (
10296
'"Agent1" [label="Agent1", shape=box, style=filled, '
10397
"fillcolor=lightyellow, width=1.5, height=0.8];" in graph.source

0 commit comments

Comments
 (0)