|
| 1 | +import pytest |
| 2 | +from unittest.mock import Mock |
| 3 | +from src.agents.visualizations import get_main_graph, get_all_nodes, get_all_edges, draw_graph |
| 4 | +from src.agents.agent import Agent |
| 5 | +import graphviz |
| 6 | + |
| 7 | +@pytest.fixture |
| 8 | +def mock_agent(): |
| 9 | + tool1 = Mock() |
| 10 | + tool1.name = "Tool1" |
| 11 | + tool2 = Mock() |
| 12 | + tool2.name = "Tool2" |
| 13 | + |
| 14 | + handoff1 = Mock() |
| 15 | + handoff1.name = "Handoff1" |
| 16 | + handoff1.tools = [] |
| 17 | + handoff1.handoffs = [] |
| 18 | + |
| 19 | + agent = Mock(spec=Agent) |
| 20 | + agent.name = "Agent1" |
| 21 | + agent.tools = [tool1, tool2] |
| 22 | + agent.handoffs = [handoff1] |
| 23 | + |
| 24 | + return agent |
| 25 | + |
| 26 | +def test_get_main_graph(mock_agent): |
| 27 | + result = get_main_graph(mock_agent) |
| 28 | + assert "digraph G" in result |
| 29 | + assert 'graph [splines=true];' in result |
| 30 | + assert 'node [fontname="Arial"];' in result |
| 31 | + assert 'edge [penwidth=1.5];' in result |
| 32 | + assert '"__start__" [shape=ellipse, style=filled, fillcolor=lightblue];' in result |
| 33 | + assert '"__end__" [shape=ellipse, style=filled, fillcolor=lightblue];' in result |
| 34 | + assert '"Agent1" [label="Agent1", shape=box, style=filled, fillcolor=lightyellow, width=1.5, height=0.8];' in result |
| 35 | + assert '"Tool1" [label="Tool1", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];' in result |
| 36 | + assert '"Tool2" [label="Tool2", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];' in result |
| 37 | + assert '"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, fillcolor=lightyellow, width=1.5, height=0.8];' in result |
| 38 | + |
| 39 | +def test_get_all_nodes(mock_agent): |
| 40 | + result = get_all_nodes(mock_agent) |
| 41 | + assert '"Agent1" [label="Agent1", shape=box, style=filled, fillcolor=lightyellow, width=1.5, height=0.8];' in result |
| 42 | + assert '"Tool1" [label="Tool1", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];' in result |
| 43 | + assert '"Tool2" [label="Tool2", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];' in result |
| 44 | + assert '"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, fillcolor=lightyellow, width=1.5, height=0.8];' in result |
| 45 | + |
| 46 | +def test_get_all_edges(mock_agent): |
| 47 | + result = get_all_edges(mock_agent) |
| 48 | + assert '"__start__" -> "Agent1";' in result |
| 49 | + assert '"Agent1" -> "Tool1" [style=dotted, penwidth=1.5];' in result |
| 50 | + assert '"Tool1" -> "Agent1" [style=dotted, penwidth=1.5];' in result |
| 51 | + assert '"Agent1" -> "Tool2" [style=dotted, penwidth=1.5];' in result |
| 52 | + assert '"Tool2" -> "Agent1" [style=dotted, penwidth=1.5];' in result |
| 53 | + assert '"Agent1" -> "Handoff1";' in result |
| 54 | + assert '"Handoff1" -> "__end__";' in result |
| 55 | + |
| 56 | +def test_draw_graph(mock_agent): |
| 57 | + graph = draw_graph(mock_agent) |
| 58 | + assert isinstance(graph, graphviz.Source) |
| 59 | + assert "digraph G" in graph.source |
| 60 | + assert 'graph [splines=true];' in graph.source |
| 61 | + assert 'node [fontname="Arial"];' in graph.source |
| 62 | + assert 'edge [penwidth=1.5];' in graph.source |
| 63 | + assert '"__start__" [shape=ellipse, style=filled, fillcolor=lightblue];' in graph.source |
| 64 | + assert '"__end__" [shape=ellipse, style=filled, fillcolor=lightblue];' in graph.source |
| 65 | + assert '"Agent1" [label="Agent1", shape=box, style=filled, fillcolor=lightyellow, width=1.5, height=0.8];' in graph.source |
| 66 | + assert '"Tool1" [label="Tool1", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];' in graph.source |
| 67 | + assert '"Tool2" [label="Tool2", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];' in graph.source |
| 68 | + assert '"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, fillcolor=lightyellow, width=1.5, height=0.8];' in graph.source |
0 commit comments