1
1
import graphviz
2
2
3
3
from agents import Agent
4
+ from agents .handoffs import Handoff
4
5
5
6
6
7
def get_main_graph (agent : Agent ) -> str :
@@ -29,7 +30,9 @@ def get_main_graph(agent: Agent) -> str:
29
30
return "" .join (parts )
30
31
31
32
32
- def get_all_nodes (agent : Agent , parent : Agent = None ) -> str :
33
+ from typing import Optional
34
+
35
+ def get_all_nodes (agent : Agent , parent : Optional [Agent ] = None ) -> str :
33
36
"""
34
37
Recursively generates the nodes for the given agent and its handoffs in DOT format.
35
38
@@ -47,25 +50,29 @@ def get_all_nodes(agent: Agent, parent: Agent = None) -> str:
47
50
"fillcolor=lightyellow, width=1.5, height=0.8];"
48
51
)
49
52
50
- # Smaller tools (ellipse, green)
51
53
for tool in agent .tools :
52
54
parts .append (
53
55
f'"{ tool .name } " [label="{ tool .name } ", shape=ellipse, style=filled, '
54
56
f"fillcolor=lightgreen, width=0.5, height=0.3];"
55
57
)
56
58
57
- # Bigger handoffs (rounded box, yellow)
58
59
for handoff in agent .handoffs :
59
- parts .append (
60
- f'"{ handoff .name } " [label="{ handoff .name } ", shape=box, style=filled, style=rounded, '
61
- f"fillcolor=lightyellow, width=1.5, height=0.8];"
62
- )
63
- parts .append (get_all_nodes (handoff ))
60
+ if isinstance (handoff , Handoff ):
61
+ parts .append (
62
+ f'"{ handoff .agent_name } " [label="{ handoff .agent_name } ", shape=box, style=filled, style=rounded, '
63
+ f"fillcolor=lightyellow, width=1.5, height=0.8];"
64
+ )
65
+ if isinstance (handoff , Agent ):
66
+ parts .append (
67
+ f'"{ handoff .name } " [label="{ handoff .name } ", shape=box, style=filled, style=rounded, '
68
+ f"fillcolor=lightyellow, width=1.5, height=0.8];"
69
+ )
70
+ parts .append (get_all_nodes (handoff ))
64
71
65
72
return "" .join (parts )
66
73
67
74
68
- def get_all_edges (agent : Agent , parent : Agent = None ) -> str :
75
+ def get_all_edges (agent : Agent , parent : Optional [ Agent ] = None ) -> str :
69
76
"""
70
77
Recursively generates the edges for the given agent and its handoffs in DOT format.
71
78
@@ -92,14 +99,20 @@ def get_all_edges(agent: Agent, parent: Agent = None) -> str:
92
99
"{ agent .name } " -> "__end__";""" )
93
100
94
101
for handoff in agent .handoffs :
95
- parts .append (f"""
96
- "{ agent .name } " -> "{ handoff .name } ";""" )
97
- parts .append (get_all_edges (handoff , agent ))
102
+ if isinstance (handoff , Handoff ):
103
+ parts .append (f"""
104
+ "{ agent .name } " -> "{ handoff .agent_name } ";""" )
105
+ if isinstance (handoff , Agent ):
106
+ parts .append (f"""
107
+ "{ agent .name } " -> "{ handoff .name } ";""" )
108
+ parts .append (get_all_edges (handoff , agent ))
98
109
99
110
return "" .join (parts )
100
111
101
112
102
- def draw_graph (agent : Agent , filename : str = None ) -> graphviz .Source :
113
+ from typing import Optional
114
+
115
+ def draw_graph (agent : Agent , filename : Optional [str ] = None ) -> graphviz .Source :
103
116
"""
104
117
Draws the graph for the given agent and optionally saves it as a PNG file.
105
118
0 commit comments