22import asyncio
33from typing import Any , Iterable , Protocol
44from graphlib import TopologicalSorter , CycleError
5+
56from graphai .callback import Callback
67from graphai .utils import logger
78
@@ -64,6 +65,7 @@ def __init__(
6465 self .edges : list [Any ] = []
6566 self .start_node : NodeProtocol | None = None
6667 self .end_nodes : list [NodeProtocol ] = []
68+ self .join_nodes : set [NodeProtocol ] = set ()
6769 self .Callback : type [Callback ] = Callback
6870 self .max_steps = max_steps
6971 self .state = initial_state or {}
@@ -130,6 +132,18 @@ def add_node(self, node: NodeProtocol) -> Graph:
130132 self .end_nodes .append (node )
131133 return self
132134
def _get_node(self, node_candidate: NodeProtocol | str) -> NodeProtocol:
    """Resolve a node reference to the node registered in this graph.

    Args:
        node_candidate: Either a node name, or a node-like object exposing a
            ``name`` attribute.

    Returns:
        The node stored in ``self.nodes`` under the resolved name.

    Raises:
        ValueError: If no node is registered under the resolved name, or if
            the candidate is neither a string nor an object with ``name``.
    """
    # Start unbound-safe: every path below must leave `node` defined so the
    # "not found" check raises ValueError rather than UnboundLocalError.
    node = None
    if isinstance(node_candidate, str):
        node_name = node_candidate
        node = self.nodes.get(node_name)
    elif hasattr(node_candidate, "name"):
        # node-like object: look it up by its name, not by identity
        node_name = node_candidate.name
        node = self.nodes.get(node_name)
    else:
        # not a usable reference; keep a readable label for the error only
        node_name = str(node_candidate)
    if node is None:
        raise ValueError(f"Node with name '{node_name}' not found.")
    return node
146+
133147 def add_edge (
134148 self , source : NodeProtocol | str , destination : NodeProtocol | str
135149 ) -> Graph :
@@ -141,33 +155,10 @@ def add_edge(
141155 """
142156 source_node , destination_node = None , None
143157 # get source node from graph
144- source_name : str
145- if isinstance (source , str ):
146- source_node = self .nodes .get (source )
147- source_name = source
148- else :
149- # Check if it's a node-like object by looking for required attributes
150- if hasattr (source , "name" ):
151- source_node = self .nodes .get (source .name )
152- source_name = source .name
153- else :
154- source_name = str (source )
155- if source_node is None :
156- raise ValueError (f"Node with name '{ source_name } ' not found." )
158+ source_node = self ._get_node (node_candidate = source )
157159 # get destination node from graph
158- destination_name : str
159- if isinstance (destination , str ):
160- destination_node = self .nodes .get (destination )
161- destination_name = destination
162- else :
163- # Check if it's a node-like object by looking for required attributes
164- if hasattr (destination , "name" ):
165- destination_node = self .nodes .get (destination .name )
166- destination_name = destination .name
167- else :
168- destination_name = str (destination )
169- if destination_node is None :
170- raise ValueError (f"Node with name '{ destination_name } ' not found." )
160+ destination_node = self ._get_node (node_candidate = destination )
161+ # create edge
171162 edge = Edge (source_node , destination_node )
172163 self .edges .append (edge )
173164 return self
@@ -214,7 +205,6 @@ def compile(self, *, strict: bool = False) -> Graph:
214205 nodes = getattr (self , "nodes" , None )
215206 if not isinstance (nodes , dict ) or not nodes :
216207 raise GraphCompileError ("No nodes have been added to the graph" )
217-
218208 start_name : str | None = None
219209 # Bind and narrow the attribute for mypy
220210 start_node : _HasName | None = getattr (self , "start_node" , None )
@@ -230,21 +220,17 @@ def compile(self, *, strict: bool = False) -> Graph:
230220 raise GraphCompileError (f"Multiple start nodes defined: { starts } " )
231221 if len (starts ) == 1 :
232222 start_name = starts [0 ]
233-
234223 if not start_name :
235224 raise GraphCompileError ("No start node defined" )
236-
237225 # at least one end node
238226 if not any (
239227 getattr (n , "is_end" , False ) or getattr (n , "end" , False )
240228 for n in nodes .values ()
241229 ):
242230 raise GraphCompileError ("No end node defined" )
243-
244231 # normalize edges into adjacency {src: set(dst)}
245232 raw_edges = getattr (self , "edges" , None )
246233 adj : dict [str , set [str ]] = {name : set () for name in nodes .keys ()}
247-
248234 def _add_edge (src : str , dst : str ) -> None :
249235 if src not in nodes :
250236 raise GraphCompileError (f"Edge references unknown source node: { src } " )
@@ -253,7 +239,6 @@ def _add_edge(src: str, dst: str) -> None:
253239 f"Edge from { src } references unknown node(s): ['{ dst } ']"
254240 )
255241 adj [src ].add (dst )
256-
257242 if raw_edges is None :
258243 pass
259244 elif isinstance (raw_edges , dict ):
@@ -273,13 +258,11 @@ def _add_edge(src: str, dst: str) -> None:
273258 iterator = iter (raw_edges )
274259 except TypeError :
275260 raise GraphCompileError ("Internal edge map has unsupported type" )
276-
277261 for item in iterator :
278262 # (src, dst) OR (src, Iterable[dst])
279263 if isinstance (item , (tuple , list )) and len (item ) == 2 :
280264 raw_src , rhs = item
281265 src = _require_name (raw_src , "source" )
282-
283266 if isinstance (rhs , str ) or getattr (rhs , "name" , None ):
284267 dst = _require_name (rhs , "destination" )
285268 _add_edge (src , rhs )
@@ -294,7 +277,6 @@ def _add_edge(src: str, dst: str) -> None:
294277 "Edge tuple second item must be a destination or an iterable of destinations"
295278 )
296279 continue
297-
298280 # Mapping-style: {"source": "...", "destination": "..."} or {"src": "...", "dst": "..."}
299281 if isinstance (item , dict ):
300282 src = _require_name (item .get ("source" , item .get ("src" )), "source" )
@@ -303,7 +285,6 @@ def _add_edge(src: str, dst: str) -> None:
303285 )
304286 _add_edge (src , dst )
305287 continue
306-
307288 # Object with attributes .source/.destination (or .src/.dst)
308289 if hasattr (item , "source" ) or hasattr (item , "src" ):
309290 src = _require_name (
@@ -315,13 +296,11 @@ def _add_edge(src: str, dst: str) -> None:
315296 )
316297 _add_edge (src , dst )
317298 continue
318-
319299 # If none matched, this is an unsupported edge record
320300 raise GraphCompileError (
321301 "Edges must be dict[str, Iterable[str]] or an iterable of (src, dst), "
322302 "(src, Iterable[dst]), mapping{'source'/'destination'}, or objects with .source/.destination"
323303 )
324-
325304 # reachability from start
326305 seen : set [str ] = set ()
327306 stack = [start_name ]
@@ -331,11 +310,9 @@ def _add_edge(src: str, dst: str) -> None:
331310 continue
332311 seen .add (cur )
333312 stack .extend (adj .get (cur , ()))
334-
335313 unreachable = sorted (set (nodes .keys ()) - seen )
336314 if unreachable :
337315 raise GraphCompileError (f"Unreachable nodes: { unreachable } " )
338-
339316 # optional cycle detection (strict mode)
340317 if strict :
341318 preds : dict [str , set [str ]] = {n : set () for n in nodes .keys ()}
@@ -346,7 +323,6 @@ def _add_edge(src: str, dst: str) -> None:
346323 list (TopologicalSorter (preds ).static_order ())
347324 except CycleError as e :
348325 raise GraphCompileError ("cycle detected in graph (strict mode)" ) from e
349-
350326 return self
351327
352328 def _validate_output (self , output : dict [str , Any ], node_name : str ):
@@ -358,7 +334,13 @@ def _validate_output(self, output: dict[str, Any], node_name: str):
358334
def _get_next_nodes(self, current_node: NodeProtocol) -> list[NodeProtocol]:
    """Collect the destination of every regular edge leaving `current_node`."""
    successors: list[NodeProtocol] = []
    for candidate in self.edges:
        # Anything that is not a plain Edge (e.g. a JoinEdge) is skipped:
        # those have no single `source` and are followed up in the execute
        # logic instead.
        if not isinstance(candidate, Edge):
            continue
        if candidate.source == current_node:
            successors.append(candidate.destination)
    return successors
362344
363345 async def _invoke_node (
364346 self , node : NodeProtocol , state : dict [str , Any ], callback : Callback
@@ -379,6 +361,7 @@ async def _execute_branch(
379361 state : dict [str , Any ],
380362 callback : Callback ,
381363 steps : int ,
364+ stop_at_join : bool = False ,
382365 ):
383366 """Recursively execute a branch starting from `current_node`.
384367 When a node has multiple successors, run them concurrently and merge their outputs."""
@@ -392,6 +375,9 @@ async def _execute_branch(
392375 del output ["choice" ]
393376 current_node = self ._get_node_by_name (node_name = next_node_name )
394377 continue
378+ if stop_at_join and current_node in self .join_nodes :
379+ # for parallel branches, wait at JoinEdge until all branches are complete
380+ return state
395381
396382 next_nodes = self ._get_next_nodes (current_node )
397383 if not next_nodes :
@@ -404,17 +390,43 @@ async def _execute_branch(
404390 # Run each branch concurrently
405391 results = await asyncio .gather (
406392 * [
407- self ._execute_branch (n , state .copy (), callback , steps + 1 )
393+ self ._execute_branch (
394+ current_node = n ,
395+ state = state .copy (),
396+ callback = callback ,
397+ steps = steps + 1 ,
398+ stop_at_join = True , # force parallel branches to wait at JoinEdge
399+ )
408400 for n in next_nodes
409401 ]
410402 )
403+ # merge states returned by each branch
411404 merged = state .copy ()
412405 for res in results :
413- # merge states returned by each branch
414406 for k , v in res .items ():
415407 if k != "callback" :
416408 merged [k ] = v
417- return merged
409+ if set (next_nodes ) & self .join_nodes :
410+ # if any of the next nodes are join nodes, we need to continue from the
411+ # JoinEdge.destination node
412+ join_edge = next (
413+ (
414+ e for e in self .edges if isinstance (e , JoinEdge )
415+ and any (n in e .sources for n in next_nodes )
416+ ),
417+ None
418+ )
419+ if not join_edge :
420+ raise Exception ("No JoinEdge found for next_nodes" )
421+ # set current_node (for next iteration) to the JoinEdge.destination
422+ current_node = join_edge .destination
423+ # continue to the destination node with our merged state
424+ state = merged
425+ continue
426+ else :
427+ # if this happens we have multiple branches that do not join so we
428+ # can just return the merged states
429+ return merged
418430 steps += 1
419431 if steps >= self .max_steps :
420432 raise Exception (
@@ -502,20 +514,46 @@ def _get_node_by_name(self, node_name: str) -> NodeProtocol:
502514
def _get_next_node(self, current_node):
    """Return the destination of the first regular edge leaving `current_node`.

    Only `Edge` instances are considered; other edge kinds (such as JoinEdge)
    are ignored here and handled by the execute logic.

    Raises:
        Exception: If no regular edge starts at `current_node`.
    """
    # Private sentinel so that a destination of None is still a valid hit.
    not_found = object()
    successor = next(
        (
            link.destination
            for link in self.edges
            if isinstance(link, Edge) and link.source == current_node
        ),
        not_found,
    )
    if successor is not_found:
        raise Exception(
            f"No outgoing edge found for current node '{current_node.name}'."
        )
    return successor
510524
def add_parallel(
    self, source: NodeProtocol | str, destinations: list[NodeProtocol | str]
):
    """Fan out from one node to several nodes that run in parallel.

    One edge is registered via `add_edge` for each entry of `destinations`,
    in order.

    Args:
        source: The node (or node name) the branches start from.
        destinations: The nodes (or node names) each branch leads to.

    Returns:
        This graph, so calls can be chained.
    """
    register_edge = self.add_edge
    for branch_target in destinations:
        register_edge(source, branch_target)
    return self
518537
def add_join(
    self, sources: list[NodeProtocol | str], destination: NodeProtocol | str
):
    """Merge several parallel branches back into one destination node.

    Every source and the destination are resolved through `_get_node`, a
    `JoinEdge` is appended to `self.edges`, and the resolved sources are
    recorded in `self.join_nodes`.

    Args:
        sources: The nodes (or node names) whose branches are joined.
        destination: The node (or node name) execution continues from.

    Returns:
        This graph, so calls can be chained.
    """
    # resolve sources first, then the destination, before touching any state
    joined_from = [self._get_node(node_candidate=ref) for ref in sources]
    joined_into = self._get_node(node_candidate=destination)
    self.edges.append(JoinEdge(joined_from, joined_into))
    # remember which nodes are join sources
    self.join_nodes.update(joined_from)
    return self
556+
519557 def visualize (self , * , save_path : str | None = None ):
520558 """Render the current graph. If matplotlib is not installed,
521559 raise a helpful error telling users to install the viz extra.
@@ -611,3 +649,8 @@ class Edge:
611649 def __init__ (self , source , destination ):
612650 self .source = source
613651 self .destination = destination
652+
class JoinEdge:
    """Edge that leads from several source nodes into one destination node."""

    def __init__(self, sources, destination):
        # Both values are stored exactly as given (no copy, no validation).
        self.sources, self.destination = sources, destination
0 commit comments