Skip to content

Commit cecdcd0

Browse files
committed
Add visualization functions for agents using Graphviz
1 parent 5865c6f commit cecdcd0

File tree

1 file changed

+113
-0
lines changed

1 file changed

+113
-0
lines changed

src/agents/visualizations.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
import graphviz
2+
3+
from src.agents.agent import Agent
4+
5+
6+
def get_main_graph(agent: Agent) -> str:
7+
"""
8+
Generates the main graph structure in DOT format for the given agent.
9+
10+
Args:
11+
agent (Agent): The agent for which the graph is to be generated.
12+
13+
Returns:
14+
str: The DOT format string representing the graph.
15+
"""
16+
parts = ["""
17+
digraph G {
18+
graph [splines=true];
19+
node [fontname="Arial"];
20+
edge [penwidth=1.5];
21+
22+
"__start__" [shape=ellipse, style=filled, fillcolor=lightblue];
23+
"__end__" [shape=ellipse, style=filled, fillcolor=lightblue];
24+
"""]
25+
parts.append(get_all_nodes(agent))
26+
parts.append(get_all_edges(agent))
27+
parts.append("}")
28+
return "".join(parts)
29+
30+
31+
def get_all_nodes(agent: Agent, parent: Agent = None) -> str:
32+
"""
33+
Recursively generates the nodes for the given agent and its handoffs in DOT format.
34+
35+
Args:
36+
agent (Agent): The agent for which the nodes are to be generated.
37+
38+
Returns:
39+
str: The DOT format string representing the nodes.
40+
"""
41+
parts = []
42+
43+
# Ensure parent agent node is colored
44+
if not parent:
45+
parts.append(f"""
46+
"{agent.name}" [label="{agent.name}", shape=box, style=filled, fillcolor=lightyellow, width=1.5, height=0.8];""")
47+
48+
# Smaller tools (ellipse, green)
49+
for tool in agent.tools:
50+
parts.append(f"""
51+
"{tool.name}" [label="{tool.name}", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];""")
52+
53+
# Bigger handoffs (rounded box, yellow)
54+
for handoff in agent.handoffs:
55+
parts.append(f"""
56+
"{handoff.name}" [label="{handoff.name}", shape=box, style=filled, style=rounded, fillcolor=lightyellow, width=1.5, height=0.8];""")
57+
parts.append(get_all_nodes(handoff))
58+
59+
return "".join(parts)
60+
61+
62+
def get_all_edges(agent: Agent, parent: Agent = None) -> str:
63+
"""
64+
Recursively generates the edges for the given agent and its handoffs in DOT format.
65+
66+
Args:
67+
agent (Agent): The agent for which the edges are to be generated.
68+
parent (Agent, optional): The parent agent. Defaults to None.
69+
70+
Returns:
71+
str: The DOT format string representing the edges.
72+
"""
73+
parts = []
74+
75+
if not parent:
76+
parts.append(f"""
77+
"__start__" -> "{agent.name}";""")
78+
79+
for tool in agent.tools:
80+
parts.append(f"""
81+
"{agent.name}" -> "{tool.name}" [style=dotted, penwidth=1.5];
82+
"{tool.name}" -> "{agent.name}" [style=dotted, penwidth=1.5];""")
83+
84+
if not agent.handoffs:
85+
parts.append(f"""
86+
"{agent.name}" -> "__end__";""")
87+
88+
for handoff in agent.handoffs:
89+
parts.append(f"""
90+
"{agent.name}" -> "{handoff.name}";""")
91+
parts.append(get_all_edges(handoff, agent))
92+
93+
return "".join(parts)
94+
95+
96+
def draw_graph(agent: Agent, filename: str = None) -> graphviz.Source:
97+
"""
98+
Draws the graph for the given agent and optionally saves it as a PNG file.
99+
100+
Args:
101+
agent (Agent): The agent for which the graph is to be drawn.
102+
filename (str): The name of the file to save the graph as a PNG.
103+
104+
Returns:
105+
graphviz.Source: The graphviz Source object representing the graph.
106+
"""
107+
dot_code = get_main_graph(agent)
108+
graph = graphviz.Source(dot_code)
109+
110+
if filename:
111+
graph.render(filename, format='png')
112+
113+
return graph

0 commit comments

Comments
 (0)