Skip to content

Commit 1c789ad

Browse files
Add a cache for nested workflows
- Update every time we connect/disconnect or add/remove a node
- Keep track of which nodes are workflows and which are not
- As a result, we do not need to iterate all nodes to determine the result of `_has_node`, we can use O(1) set operations
1 parent 176cff0 commit 1c789ad

File tree

1 file changed

+31
-13
lines changed

1 file changed

+31
-13
lines changed

nipype/pipeline/engine/workflows.py

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ def __init__(self, name, base_dir=None):
5959
super(Workflow, self).__init__(name, base_dir)
6060
self._graph = nx.DiGraph()
6161

62+
self._nodes_cache = set()
63+
self._nested_workflows_cache = set()
64+
6265
# PUBLIC API
6366
def clone(self, name):
6467
"""Clone a workflow
@@ -269,6 +272,8 @@ def connect(self, *args, **kwargs):
269272
"(%s, %s): new edge data: %s", srcnode, destnode, str(edge_data)
270273
)
271274

275+
self._update_node_cache()
276+
272277
def disconnect(self, *args):
273278
"""Disconnect nodes
274279
See the docstring for connect for format.
@@ -314,6 +319,8 @@ def disconnect(self, *args):
314319
else:
315320
self._graph.add_edges_from([(srcnode, dstnode, edge_data)])
316321

322+
self._update_node_cache()
323+
317324
def add_nodes(self, nodes):
318325
""" Add nodes to a workflow
319326
@@ -346,6 +353,7 @@ def add_nodes(self, nodes):
346353
if node._hierarchy is None:
347354
node._hierarchy = self.name
348355
self._graph.add_nodes_from(newnodes)
356+
self._update_node_cache()
349357

350358
def remove_nodes(self, nodes):
351359
""" Remove nodes from a workflow
@@ -356,6 +364,7 @@ def remove_nodes(self, nodes):
356364
A list of EngineBase-based objects
357365
"""
358366
self._graph.remove_nodes_from(nodes)
367+
self._update_node_cache()
359368

360369
# Input-Output access
361370
@property
@@ -903,23 +912,32 @@ def _set_node_input(self, node, param, source, sourceinfo):
903912
node.set_input(param, deepcopy(newval))
904913

905914
def _get_all_nodes(self):
906-
allnodes = []
907-
for node in self._graph.nodes():
908-
if isinstance(node, Workflow):
909-
allnodes.extend(node._get_all_nodes())
910-
else:
911-
allnodes.append(node)
915+
allnodes = [
916+
*self._nodes_cache.difference(self._nested_workflows_cache)
917+
] # all nodes that are not workflows
918+
for node in self._nested_workflows_cache:
919+
allnodes.extend(node._get_all_nodes())
912920
return allnodes
913921

922+
def _update_node_cache(self):
923+
nodes = set(self._graph)
924+
925+
added_nodes = nodes.difference(self._nodes_cache)
926+
removed_nodes = self._nodes_cache.difference(nodes)
927+
928+
self._nodes_cache = nodes
929+
self._nested_workflows_cache.difference_update(removed_nodes)
930+
931+
for node in added_nodes:
932+
if isinstance(node, Workflow):
933+
self._nested_workflows_cache.add(node)
934+
914935
def _has_node(self, wanted_node):
915-
if wanted_node in self._graph:
916-
return True # best case scenario
917-
for node in self._graph: # iterate otherwise
918-
if wanted_node == node:
936+
if wanted_node in self._nodes_cache:
937+
return True
938+
for node in self._nested_workflows_cache:
939+
if node._has_node(wanted_node):
919940
return True
920-
if hasattr(node, "_has_node"): # hasattr is faster than isinstance
921-
if node._has_node(wanted_node):
922-
return True
923941
return False
924942

925943
def _create_flat_graph(self):

0 commit comments

Comments
 (0)