Skip to content

Commit 9b972b3

Browse files
committed
Add unit tests for visualization functions in test_visualizations.py
1 parent cecdcd0 commit 9b972b3

File tree

1 file changed

+68
-0
lines changed

1 file changed

+68
-0
lines changed

tests/test_visualizations.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import pytest
2+
from unittest.mock import Mock
3+
from src.agents.visualizations import get_main_graph, get_all_nodes, get_all_edges, draw_graph
4+
from src.agents.agent import Agent
5+
import graphviz
6+
7+
@pytest.fixture
8+
def mock_agent():
9+
tool1 = Mock()
10+
tool1.name = "Tool1"
11+
tool2 = Mock()
12+
tool2.name = "Tool2"
13+
14+
handoff1 = Mock()
15+
handoff1.name = "Handoff1"
16+
handoff1.tools = []
17+
handoff1.handoffs = []
18+
19+
agent = Mock(spec=Agent)
20+
agent.name = "Agent1"
21+
agent.tools = [tool1, tool2]
22+
agent.handoffs = [handoff1]
23+
24+
return agent
25+
26+
def test_get_main_graph(mock_agent):
27+
result = get_main_graph(mock_agent)
28+
assert "digraph G" in result
29+
assert 'graph [splines=true];' in result
30+
assert 'node [fontname="Arial"];' in result
31+
assert 'edge [penwidth=1.5];' in result
32+
assert '"__start__" [shape=ellipse, style=filled, fillcolor=lightblue];' in result
33+
assert '"__end__" [shape=ellipse, style=filled, fillcolor=lightblue];' in result
34+
assert '"Agent1" [label="Agent1", shape=box, style=filled, fillcolor=lightyellow, width=1.5, height=0.8];' in result
35+
assert '"Tool1" [label="Tool1", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];' in result
36+
assert '"Tool2" [label="Tool2", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];' in result
37+
assert '"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, fillcolor=lightyellow, width=1.5, height=0.8];' in result
38+
39+
def test_get_all_nodes(mock_agent):
40+
result = get_all_nodes(mock_agent)
41+
assert '"Agent1" [label="Agent1", shape=box, style=filled, fillcolor=lightyellow, width=1.5, height=0.8];' in result
42+
assert '"Tool1" [label="Tool1", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];' in result
43+
assert '"Tool2" [label="Tool2", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];' in result
44+
assert '"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, fillcolor=lightyellow, width=1.5, height=0.8];' in result
45+
46+
def test_get_all_edges(mock_agent):
47+
result = get_all_edges(mock_agent)
48+
assert '"__start__" -> "Agent1";' in result
49+
assert '"Agent1" -> "Tool1" [style=dotted, penwidth=1.5];' in result
50+
assert '"Tool1" -> "Agent1" [style=dotted, penwidth=1.5];' in result
51+
assert '"Agent1" -> "Tool2" [style=dotted, penwidth=1.5];' in result
52+
assert '"Tool2" -> "Agent1" [style=dotted, penwidth=1.5];' in result
53+
assert '"Agent1" -> "Handoff1";' in result
54+
assert '"Handoff1" -> "__end__";' in result
55+
56+
def test_draw_graph(mock_agent):
57+
graph = draw_graph(mock_agent)
58+
assert isinstance(graph, graphviz.Source)
59+
assert "digraph G" in graph.source
60+
assert 'graph [splines=true];' in graph.source
61+
assert 'node [fontname="Arial"];' in graph.source
62+
assert 'edge [penwidth=1.5];' in graph.source
63+
assert '"__start__" [shape=ellipse, style=filled, fillcolor=lightblue];' in graph.source
64+
assert '"__end__" [shape=ellipse, style=filled, fillcolor=lightblue];' in graph.source
65+
assert '"Agent1" [label="Agent1", shape=box, style=filled, fillcolor=lightyellow, width=1.5, height=0.8];' in graph.source
66+
assert '"Tool1" [label="Tool1", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];' in graph.source
67+
assert '"Tool2" [label="Tool2", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];' in graph.source
68+
assert '"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, fillcolor=lightyellow, width=1.5, height=0.8];' in graph.source

0 commit comments

Comments
 (0)