@@ -124,38 +124,36 @@ def is_any(t: Any) -> bool:
 
 
 def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
-    if not from_type:
-        return False
-    if not to_type:
+    if not from_type or not to_type:
         return False
 
-    # TODO: this is pretty forgiving on generic types. Clean that up (need to handle optionals and such)
-    if from_type and to_type:
-        # Ports are compatible
-        if from_type == to_type or is_any(from_type) or is_any(to_type):
-            return True
+    # Ports are compatible
+    if from_type == to_type or is_any(from_type) or is_any(to_type):
+        return True
 
-        if from_type in get_args(to_type):
-            return True
+    if from_type in get_args(to_type):
+        return True
 
-        if to_type in get_args(from_type):
-            return True
+    if to_type in get_args(from_type):
+        return True
 
-        # allow int -> float, pydantic will cast for us
-        if from_type is int and to_type is float:
-            return True
+    # allow int -> float, pydantic will cast for us
+    if from_type is int and to_type is float:
+        return True
 
-        # allow int|float -> str, pydantic will cast for us
-        if (from_type is int or from_type is float) and to_type is str:
-            return True
+    # allow int|float -> str, pydantic will cast for us
+    if (from_type is int or from_type is float) and to_type is str:
+        return True
 
-        # if not issubclass(from_type, to_type):
-        if not is_union_subtype(from_type, to_type):
-            return False
-    else:
-        return False
+    # Prefer issubclass when both are real classes
+    try:
+        if isinstance(from_type, type) and isinstance(to_type, type):
+            return issubclass(from_type, to_type)
+    except TypeError:
+        pass
 
-    return True
+    # Union-to-Union (or Union-to-non-Union) handling
+    return is_union_subtype(from_type, to_type)
 
 
 def are_connections_compatible(
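For orientation (not part of the patch): a rough sketch of how the rewritten check resolves a few common type pairs, assuming the module helpers (is_any, get_args, is_union_subtype) behave as the diff shows.

```python
from typing import Optional

# Hypothetical calls; expected results follow the patched logic above.
are_connection_types_compatible(int, float)          # True: explicit int -> float allowance
are_connection_types_compatible(float, str)          # True: int|float -> str allowance
are_connection_types_compatible(int, Optional[int])  # True: int is in get_args(Optional[int])
are_connection_types_compatible(bool, int)           # True: issubclass(bool, int) on real classes
are_connection_types_compatible(None, int)           # False: falsy from_type short-circuits
```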
@@ -654,6 +652,9 @@ def _is_iterator_connection_valid(
         if new_output is not None:
             outputs.append(new_output)
 
+        if len(inputs) == 0:
+            return "Iterator must have a collection input edge"
+
         # Only one input is allowed for iterators
         if len(inputs) > 1:
             return "Iterator may only have one input edge"
@@ -675,9 +676,13 @@ def _is_iterator_connection_valid(
 
         # Collector input type must match all iterator output types
         if isinstance(input_node, CollectInvocation):
+            collector_inputs = self._get_input_edges(input_node.id, ITEM_FIELD)
+            if len(collector_inputs) == 0:
+                return "Iterator input collector must have at least one item input edge"
+
             # Traverse the graph to find the first collector input edge. Collectors validate that their collection
             # inputs are all of the same type, so we can use the first input edge to determine the collector's type
-            first_collector_input_edge = self._get_input_edges(input_node.id, ITEM_FIELD)[0]
+            first_collector_input_edge = collector_inputs[0]
             first_collector_input_type = get_output_field_type(
                 self.get_node(first_collector_input_edge.source.node_id), first_collector_input_edge.source.field
             )
@@ -751,21 +756,12 @@ def nx_graph(self) -> nx.DiGraph:
         g.add_edges_from({(e.source.node_id, e.destination.node_id) for e in self.edges})
         return g
 
-    def nx_graph_with_data(self) -> nx.DiGraph:
-        """Returns a NetworkX DiGraph representing the data and layout of this graph"""
-        g = nx.DiGraph()
-        g.add_nodes_from(list(self.nodes.items()))
-        g.add_edges_from({(e.source.node_id, e.destination.node_id) for e in self.edges})
-        return g
-
     def nx_graph_flat(self, nx_graph: Optional[nx.DiGraph] = None) -> nx.DiGraph:
         """Returns a flattened NetworkX DiGraph, including all subgraphs (but not with iterations expanded)"""
         g = nx_graph or nx.DiGraph()
 
         # Add all nodes from this graph except graph/iteration nodes
-        g.add_nodes_from([n.id for n in self.nodes.values() if not isinstance(n, IterateInvocation)])
-
-        # TODO: figure out if iteration nodes need to be expanded
+        g.add_nodes_from([n.id for n in self.nodes.values()])
 
         unique_edges = {(e.source.node_id, e.destination.node_id) for e in self.edges}
         g.add_edges_from(unique_edges)
@@ -816,10 +812,57 @@ class GraphExecutionState(BaseModel):
     # Optional priority; others follow in name order
     ready_order: list[str] = Field(default_factory=list)
     indegree: dict[str, int] = Field(default_factory=dict, description="Remaining unmet input count for exec nodes")
+    _iteration_path_cache: dict[str, tuple[int, ...]] = PrivateAttr(default_factory=dict)
 
     def _type_key(self, node_obj: BaseInvocation) -> str:
         return node_obj.__class__.__name__
 
+    def _get_iteration_path(self, exec_node_id: str) -> tuple[int, ...]:
+        """Best-effort outer->inner iteration indices for an execution node, stopping at collectors."""
+        cached = self._iteration_path_cache.get(exec_node_id)
+        if cached is not None:
+            return cached
+
+        # Only prepared execution nodes participate; otherwise treat as non-iterated.
+        source_node_id = self.prepared_source_mapping.get(exec_node_id)
+        if source_node_id is None:
+            self._iteration_path_cache[exec_node_id] = ()
+            return ()
+
+        # Source-graph iterator ancestry, with edges into collectors removed so iteration context doesn't leak.
+        it_g = self._iterator_graph(self.graph.nx_graph())
+        iterator_sources = [
+            n for n in nx.ancestors(it_g, source_node_id) if isinstance(self.graph.get_node(n), IterateInvocation)
+        ]
+
+        # Order iterators outer->inner via topo order of the iterator graph.
+        topo = list(nx.topological_sort(it_g))
+        topo_index = {n: i for i, n in enumerate(topo)}
+        iterator_sources.sort(key=lambda n: topo_index.get(n, 0))
+
+        # Map iterator source nodes to the prepared iterator exec nodes that are ancestors of exec_node_id.
+        eg = self.execution_graph.nx_graph()
+        path: list[int] = []
+        for it_src in iterator_sources:
+            prepared = self.source_prepared_mapping.get(it_src)
+            if not prepared:
+                continue
+            it_exec = next((p for p in prepared if nx.has_path(eg, p, exec_node_id)), None)
+            if it_exec is None:
+                continue
+            it_node = self.execution_graph.nodes.get(it_exec)
+            if isinstance(it_node, IterateInvocation):
+                path.append(it_node.index)
+
+        # If this exec node is itself an iterator, include its own index as the innermost element.
+        node_obj = self.execution_graph.nodes.get(exec_node_id)
+        if isinstance(node_obj, IterateInvocation):
+            path.append(node_obj.index)
+
+        result = tuple(path)
+        self._iteration_path_cache[exec_node_id] = result
+        return result
+
     def _queue_for(self, cls_name: str) -> Deque[str]:
         q = self._ready_queues.get(cls_name)
         if q is None:
@@ -843,7 +886,15 @@ def _enqueue_if_ready(self, nid: str) -> None:
         if self.indegree[nid] != 0 or nid in self.executed:
             return
         node_obj = self.execution_graph.nodes[nid]
-        self._queue_for(self._type_key(node_obj)).append(nid)
+        q = self._queue_for(self._type_key(node_obj))
+        nid_path = self._get_iteration_path(nid)
+        # Insert in lexicographic outer->inner order; preserve FIFO for equal paths.
+        for i, existing in enumerate(q):
+            if self._get_iteration_path(existing) > nid_path:
+                q.insert(i, nid)
+                break
+        else:
+            q.append(nid)
 
     model_config = ConfigDict(
         json_schema_extra={
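A standalone sketch of the ordered enqueue above, with plain tuples standing in for _get_iteration_path(...); the queue contents and paths are invented for illustration.

```python
from collections import deque

def enqueue(q: deque, nid: str, paths: dict) -> None:
    # Mirror of the patched _enqueue_if_ready ordering: lexicographic
    # outer->inner comparison of iteration paths, FIFO for equal paths.
    nid_path = paths[nid]
    for i, existing in enumerate(q):
        if paths[existing] > nid_path:
            q.insert(i, nid)
            break
    else:
        q.append(nid)

q = deque()
paths = {"a": (1, 0), "b": (0, 2), "c": (0, 2)}
for n in ("a", "b", "c"):
    enqueue(q, n, paths)
print(list(q))  # ['b', 'c', 'a']: (0, 2) sorts before (1, 0); b and c keep FIFO order
```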
@@ -1083,12 +1134,12 @@ def no_unexecuted_iter_ancestors(n: str) -> bool:
 
         # Select the correct prepared parents for each iteration
         # For every iterator, the parent must either not be a child of that iterator, or must match the prepared iteration for that iterator
-        # TODO: Handle a node mapping to none
         eg = self.execution_graph.nx_graph_flat()
         prepared_parent_mappings = [
             [(n, self._get_iteration_node(n, g, eg, it)) for n in next_node_parents]
             for it in iterator_node_prepared_combinations
         ]  # type: ignore
+        prepared_parent_mappings = [m for m in prepared_parent_mappings if all(p[1] is not None for p in m)]
 
         # Create execution node for each iteration
         for iteration_mappings in prepared_parent_mappings:
@@ -1110,15 +1161,17 @@ def _get_iteration_node(
         if len(prepared_nodes) == 1:
             return next(iter(prepared_nodes))
 
-        # Check if the requested node is an iterator
-        prepared_iterator = next((n for n in prepared_nodes if n in prepared_iterator_nodes), None)
-        if prepared_iterator is not None:
-            return prepared_iterator
-
         # Filter to only iterator nodes that are a parent of the specified node, in tuple format (prepared, source)
         iterator_source_node_mapping = [(n, self.prepared_source_mapping[n]) for n in prepared_iterator_nodes]
         parent_iterators = [itn for itn in iterator_source_node_mapping if nx.has_path(graph, itn[1], source_node_id)]
 
+        # If the requested node is an iterator, only accept it if it is compatible with all parent iterators
+        prepared_iterator = next((n for n in prepared_nodes if n in prepared_iterator_nodes), None)
+        if prepared_iterator is not None:
+            if all(nx.has_path(execution_graph, pit[0], prepared_iterator) for pit in parent_iterators):
+                return prepared_iterator
+            return None
+
         return next(
             (n for n in prepared_nodes if all(nx.has_path(execution_graph, pit[0], n) for pit in parent_iterators)),
             None,
@@ -1156,11 +1209,10 @@ def _prepare_inputs(self, node: BaseInvocation):
         # Inputs must be deep-copied, else if a node mutates the object, other nodes that get the same input
         # will see the mutation.
         if isinstance(node, CollectInvocation):
-            output_collection = [
-                copydeep(getattr(self.results[edge.source.node_id], edge.source.field))
-                for edge in input_edges
-                if edge.destination.field == ITEM_FIELD
-            ]
+            item_edges = [e for e in input_edges if e.destination.field == ITEM_FIELD]
+            item_edges.sort(key=lambda e: (self._get_iteration_path(e.source.node_id), e.source.node_id))
+
+            output_collection = [copydeep(getattr(self.results[e.source.node_id], e.source.field)) for e in item_edges]
             node.collection = output_collection
         else:
             for edge in input_edges:
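Illustrative only: with nested iterators, _get_iteration_path would produce (outer_index, inner_index) tuples, so the (path, source node id) sort key above restores outer-major ordering of collected items. The node ids and paths below are invented.

```python
item_edges = [
    ("exec-7", (1, 0)),  # (source node id, its iteration path)
    ("exec-3", (0, 1)),
    ("exec-1", (0, 0)),
    ("exec-9", (1, 1)),
]
item_edges.sort(key=lambda e: (e[1], e[0]))
print([nid for nid, _ in item_edges])  # ['exec-1', 'exec-3', 'exec-7', 'exec-9']
```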