@@ -341,7 +341,7 @@ def join(
341341 return join (reducer_type = reducer_factory , node_id = node_id )
342342
343343 # Edge building
344- def add (self , * edges : EdgePath [StateT , DepsT ]) -> None :
344+ def add (self , * edges : EdgePath [StateT , DepsT ]) -> None : # noqa C901
345345 """Add one or more edge paths to the graph.
346346
347347 This method processes edge paths and automatically creates any necessary
@@ -359,12 +359,12 @@ def _handle_path(p: Path):
359359 """
360360 for item in p .items :
361361 if isinstance (item , BroadcastMarker ):
362- new_node = Fork [Any , Any ](id = item .fork_id , is_map = False )
362+ new_node = Fork [Any , Any ](id = item .fork_id , is_map = False , downstream_join_id = None )
363363 self ._insert_node (new_node )
364364 for path in item .paths :
365365 _handle_path (Path (items = [* path .items ]))
366366 elif isinstance (item , MapMarker ):
367- new_node = Fork [Any , Any ](id = item .fork_id , is_map = True )
367+ new_node = Fork [Any , Any ](id = item .fork_id , is_map = True , downstream_join_id = item . downstream_join_id )
368368 self ._insert_node (new_node )
369369 elif isinstance (item , DestinationMarker ):
370370 pass
@@ -710,6 +710,7 @@ def build(self) -> Graph[StateT, DepsT, GraphInputT, GraphOutputT]:
710710 # TODO(P3): Consider doing a deepcopy here to prevent modifications to the underlying nodes and edges
711711 nodes = self ._nodes
712712 edges_by_source = self ._edges_by_source
713+ nodes , edges_by_source = _flatten_paths (nodes , edges_by_source )
713714 nodes , edges_by_source = _normalize_forks (nodes , edges_by_source )
714715 parent_forks = _collect_dominating_forks (nodes , edges_by_source )
715716
@@ -726,6 +727,52 @@ def build(self) -> Graph[StateT, DepsT, GraphInputT, GraphOutputT]:
726727 )
727728
728729
730+ def _flatten_paths (
731+ nodes : dict [NodeID , AnyNode ], edges : dict [NodeID , list [Path ]]
732+ ) -> tuple [dict [NodeID , AnyNode ], dict [NodeID , list [Path ]]]:
733+ new_nodes = nodes .copy ()
734+ new_edges : dict [NodeID , list [Path ]] = defaultdict (list )
735+
736+ paths_to_handle : list [tuple [NodeID , Path ]] = []
737+
738+ def _split_at_first_fork (path : Path ) -> tuple [Path , list [tuple [NodeID , Path ]]]:
739+ for i , item in enumerate (path .items ):
740+ if isinstance (item , MapMarker ):
741+ if item .fork_id not in nodes :
742+ new_nodes [item .fork_id ] = Fork (
743+ id = item .fork_id , is_map = True , downstream_join_id = item .downstream_join_id
744+ )
745+ upstream = Path (list (path .items [:i ]) + [DestinationMarker (item .fork_id )])
746+ downstream = Path (path .items [i + 1 :])
747+ return upstream , [(item .fork_id , downstream )]
748+
749+ if isinstance (item , BroadcastMarker ):
750+ if item .fork_id not in nodes :
751+ new_nodes [item .fork_id ] = Fork (id = item .fork_id , is_map = True , downstream_join_id = None )
752+ upstream = Path (list (path .items [:i ]) + [DestinationMarker (item .fork_id )])
753+ return upstream , [(item .fork_id , p ) for p in item .paths ]
754+ return path , []
755+
756+ for node in new_nodes .values ():
757+ if isinstance (node , Decision ):
758+ for branch in node .branches :
759+ upstream , downstreams = _split_at_first_fork (branch .path )
760+ branch .path = upstream
761+ paths_to_handle .extend (downstreams )
762+
763+ for source_id , edges_from_source in edges .items ():
764+ for path in edges_from_source :
765+ paths_to_handle .append ((source_id , path ))
766+
767+ while paths_to_handle :
768+ source_id , path = paths_to_handle .pop ()
769+ upstream , downstreams = _split_at_first_fork (path )
770+ new_edges [source_id ].append (upstream )
771+ paths_to_handle .extend (downstreams )
772+
773+ return new_nodes , dict (new_edges )
774+
775+
729776def _normalize_forks (
730777 nodes : dict [NodeID , AnyNode ], edges : dict [NodeID , list [Path ]]
731778) -> tuple [dict [NodeID , AnyNode ], dict [NodeID , list [Path ]]]:
@@ -756,25 +803,11 @@ def _normalize_forks(
756803 if len (edges_from_source ) == 1 :
757804 new_edges [source_id ] = edges_from_source
758805 continue
759- new_fork = Fork [Any , Any ](id = ForkID (NodeID (f'{ node .id } _broadcast_fork' )), is_map = False )
806+ new_fork = Fork [Any , Any ](id = ForkID (NodeID (f'{ node .id } _broadcast_fork' )), is_map = False , downstream_join_id = None )
760807 new_nodes [new_fork .id ] = new_fork
761808 new_edges [source_id ] = [Path (items = [BroadcastMarker (fork_id = new_fork .id , paths = edges_from_source )])]
762809 new_edges [new_fork .id ] = edges_from_source
763810
764- while paths_to_handle :
765- path = paths_to_handle .pop ()
766- for item in path .items :
767- if isinstance (item , MapMarker ):
768- assert item .fork_id in new_nodes
769- new_edges [item .fork_id ] = [path .next_path ]
770- paths_to_handle .append (path .next_path )
771- break
772- elif isinstance (item , BroadcastMarker ):
773- assert item .fork_id in new_nodes
774- new_edges [item .fork_id ] = [* item .paths ]
775- paths_to_handle .extend (item .paths )
776- break
777-
778811 return new_nodes , new_edges
779812
780813
@@ -808,7 +841,6 @@ def _collect_dominating_forks(
808841
809842 if isinstance (node , Fork ):
810843 fork_ids .add (node .id )
811- continue
812844
813845 def _handle_path (path : Path , last_source_id : NodeID ):
814846 """Process a path and collect edges and fork information.
0 commit comments