-
Notifications
You must be signed in to change notification settings - Fork 2.8k
Implemented ordering for expanded iterators #8741
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 7 commits
f9d38c1
fb8de3f
526d082
9c5ee7a
5bcbafb
8e0d092
f75d715
071b6c0
453e02b
5a1dcfc
6f9f717
2783198
07ca01f
3291fec
fb77e84
365220a
af580ce
a1999b9
d961698
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -816,10 +816,57 @@ class GraphExecutionState(BaseModel): | |
| # Optional priority; others follow in name order | ||
| ready_order: list[str] = Field(default_factory=list) | ||
| indegree: dict[str, int] = Field(default_factory=dict, description="Remaining unmet input count for exec nodes") | ||
| _iteration_path_cache: dict[str, tuple[int, ...]] = PrivateAttr(default_factory=dict) | ||
|
|
||
| def _type_key(self, node_obj: BaseInvocation) -> str: | ||
| return node_obj.__class__.__name__ | ||
|
|
||
| def _get_iteration_path(self, exec_node_id: str) -> tuple[int, ...]: | ||
| """Best-effort outer->inner iteration indices for an execution node, stopping at collectors.""" | ||
| cached = self._iteration_path_cache.get(exec_node_id) | ||
| if cached is not None: | ||
| return cached | ||
|
|
||
| # Only prepared execution nodes participate; otherwise treat as non-iterated. | ||
| source_node_id = self.prepared_source_mapping.get(exec_node_id) | ||
| if source_node_id is None: | ||
| self._iteration_path_cache[exec_node_id] = () | ||
| return () | ||
|
|
||
| # Source-graph iterator ancestry, with edges into collectors removed so iteration context doesn't leak. | ||
| it_g = self._iterator_graph(self.graph.nx_graph()) | ||
| iterator_sources = [ | ||
| n for n in nx.ancestors(it_g, source_node_id) if isinstance(self.graph.get_node(n), IterateInvocation) | ||
| ] | ||
|
|
||
| # Order iterators outer->inner via topo order of the iterator graph. | ||
| topo = list(nx.topological_sort(it_g)) | ||
| topo_index = {n: i for i, n in enumerate(topo)} | ||
| iterator_sources.sort(key=lambda n: topo_index.get(n, 0)) | ||
|
|
||
| # Map iterator source nodes to the prepared iterator exec nodes that are ancestors of exec_node_id. | ||
| eg = self.execution_graph.nx_graph() | ||
| path: list[int] = [] | ||
| for it_src in iterator_sources: | ||
| prepared = self.source_prepared_mapping.get(it_src) | ||
| if not prepared: | ||
| continue | ||
| it_exec = next((p for p in prepared if nx.has_path(eg, p, exec_node_id)), None) | ||
| if it_exec is None: | ||
| continue | ||
| it_node = self.execution_graph.nodes.get(it_exec) | ||
| if isinstance(it_node, IterateInvocation): | ||
| path.append(it_node.index) | ||
|
Comment on lines
+850
to
+859
|
||
|
|
||
| # If this exec node is itself an iterator, include its own index as the innermost element. | ||
| node_obj = self.execution_graph.nodes.get(exec_node_id) | ||
| if isinstance(node_obj, IterateInvocation): | ||
| path.append(node_obj.index) | ||
|
|
||
| result = tuple(path) | ||
| self._iteration_path_cache[exec_node_id] = result | ||
| return result | ||
|
|
||
| def _queue_for(self, cls_name: str) -> Deque[str]: | ||
| q = self._ready_queues.get(cls_name) | ||
| if q is None: | ||
|
|
@@ -843,7 +890,15 @@ def _enqueue_if_ready(self, nid: str) -> None: | |
| if self.indegree[nid] != 0 or nid in self.executed: | ||
| return | ||
| node_obj = self.execution_graph.nodes[nid] | ||
| self._queue_for(self._type_key(node_obj)).append(nid) | ||
| q = self._queue_for(self._type_key(node_obj)) | ||
| nid_path = self._get_iteration_path(nid) | ||
| # Insert in lexicographic outer->inner order; preserve FIFO for equal paths. | ||
| for i, existing in enumerate(q): | ||
| if self._get_iteration_path(existing) > nid_path: | ||
| q.insert(i, nid) | ||
| break | ||
| else: | ||
| q.append(nid) | ||
|
Comment on lines
+893
to
+901
|
||
|
|
||
| model_config = ConfigDict( | ||
| json_schema_extra={ | ||
|
|
@@ -1083,12 +1138,12 @@ def no_unexecuted_iter_ancestors(n: str) -> bool: | |
|
|
||
| # Select the correct prepared parents for each iteration | ||
| # For every iterator, the parent must either not be a child of that iterator, or must match the prepared iteration for that iterator | ||
| # TODO: Handle a node mapping to none | ||
| eg = self.execution_graph.nx_graph_flat() | ||
| prepared_parent_mappings = [ | ||
| [(n, self._get_iteration_node(n, g, eg, it)) for n in next_node_parents] | ||
| for it in iterator_node_prepared_combinations | ||
| ] # type: ignore | ||
| prepared_parent_mappings = [m for m in prepared_parent_mappings if all(p[1] is not None for p in m)] | ||
|
|
||
| # Create execution node for each iteration | ||
| for iteration_mappings in prepared_parent_mappings: | ||
|
|
@@ -1110,15 +1165,17 @@ def _get_iteration_node( | |
| if len(prepared_nodes) == 1: | ||
| return next(iter(prepared_nodes)) | ||
|
|
||
| # Check if the requested node is an iterator | ||
| prepared_iterator = next((n for n in prepared_nodes if n in prepared_iterator_nodes), None) | ||
| if prepared_iterator is not None: | ||
| return prepared_iterator | ||
|
|
||
| # Filter to only iterator nodes that are a parent of the specified node, in tuple format (prepared, source) | ||
| iterator_source_node_mapping = [(n, self.prepared_source_mapping[n]) for n in prepared_iterator_nodes] | ||
| parent_iterators = [itn for itn in iterator_source_node_mapping if nx.has_path(graph, itn[1], source_node_id)] | ||
|
|
||
| # If the requested node is an iterator, only accept it if it is compatible with all parent iterators | ||
| prepared_iterator = next((n for n in prepared_nodes if n in prepared_iterator_nodes), None) | ||
| if prepared_iterator is not None: | ||
| if all(nx.has_path(execution_graph, pit[0], prepared_iterator) for pit in parent_iterators): | ||
| return prepared_iterator | ||
| return None | ||
|
|
||
| return next( | ||
| (n for n in prepared_nodes if all(nx.has_path(execution_graph, pit[0], n) for pit in parent_iterators)), | ||
| None, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not just use a
@cacheddecorator here?