66
77class Graph :
88 def __init__ (self , max_steps : int = 10 ):
9- self .nodes = []
9+ self .nodes = {}
1010 self .edges = []
1111 self .start_node = None
1212 self .end_nodes = []
@@ -15,7 +15,9 @@ def __init__(self, max_steps: int = 10):
1515 self .max_steps = max_steps
1616
1717 def add_node (self , node ):
18- self .nodes .append (node )
18+ if node .name in self .nodes :
19+ raise Exception (f"Node with name '{ node .name } ' already exists." )
20+ self .nodes [node .name ] = node
1921 if node .is_start :
2022 if self .start_node is not None :
2123 raise Exception (
@@ -27,10 +29,37 @@ def add_node(self, node):
2729 if node .is_end :
2830 self .end_nodes .append (node )
2931
30- def add_edge (self , source : _Node , destination : _Node ):
31- # TODO add logic to check that source and destination are nodes
32- # and they exist in the graph object already
33- edge = Edge (source , destination )
32+ def add_edge (self , source : _Node | str , destination : _Node | str ):
33+ """Adds an edge between two nodes that already exist in the graph.
34+
35+ Args:
36+ source: The source node or its name.
37+ destination: The destination node or its name.
38+ """
39+ source_node , destination_node = None , None
40+ # get source node from graph
41+ if isinstance (source , str ):
42+ source_node = self .nodes .get (source )
43+ else :
44+ # Check if it's a node-like object by looking for required attributes
45+ if hasattr (source , 'name' ):
46+ source_node = self .nodes .get (source .name )
47+ if source_node is None :
48+ raise ValueError (
49+ f"Node with name '{ source .name if hasattr (source , 'name' ) else source } ' not found."
50+ )
51+ # get destination node from graph
52+ if isinstance (destination , str ):
53+ destination_node = self .nodes .get (destination )
54+ else :
55+ # Check if it's a node-like object by looking for required attributes
56+ if hasattr (destination , 'name' ):
57+ destination_node = self .nodes .get (destination .name )
58+ if destination_node is None :
59+ raise ValueError (
60+ f"Node with name '{ destination .name if hasattr (destination , 'name' ) else destination } ' not found."
61+ )
62+ edge = Edge (source_node , destination_node )
3463 self .edges .append (edge )
3564
3665 def add_router (self , sources : list [_Node ], router : _Node , destinations : List [_Node ]):
@@ -139,7 +168,7 @@ def visualize(self):
139168
140169 G = nx .DiGraph ()
141170
142- for node in self .nodes :
171+ for node in self .nodes . values () :
143172 G .add_node (node .name )
144173
145174 for edge in self .edges :
@@ -173,10 +202,11 @@ def visualize(self):
173202 pos [node ] = (pos [node ][0 ] - x_center , pos [node ][1 ])
174203
175204 # Scale the layout
176- max_x = max (abs (p [0 ]) for p in pos .values ())
177- max_y = max (abs (p [1 ]) for p in pos .values ())
178- scale = min (0.8 / max_x , 0.8 / max_y )
179- pos = {node : (x * scale , y * scale ) for node , (x , y ) in pos .items ()}
205+ max_x = max (abs (p [0 ]) for p in pos .values ()) if pos else 1
206+ max_y = max (abs (p [1 ]) for p in pos .values ()) if pos else 1
207+ if max_x > 0 and max_y > 0 :
208+ scale = min (0.8 / max_x , 0.8 / max_y )
209+ pos = {node : (x * scale , y * scale ) for node , (x , y ) in pos .items ()}
180210
181211 else :
182212 print ("Warning: The graph contains cycles. Visualization will use a spring layout." )
0 commit comments