@@ -59,6 +59,9 @@ def __init__(self, name, base_dir=None):
59
59
super (Workflow , self ).__init__ (name , base_dir )
60
60
self ._graph = nx .DiGraph ()
61
61
62
+ self ._nodes_cache = set ()
63
+ self ._nested_workflows_cache = set ()
64
+
62
65
# PUBLIC API
63
66
def clone (self , name ):
64
67
"""Clone a workflow
@@ -269,6 +272,8 @@ def connect(self, *args, **kwargs):
269
272
"(%s, %s): new edge data: %s" , srcnode , destnode , str (edge_data )
270
273
)
271
274
275
+ self ._update_node_cache ()
276
+
272
277
def disconnect (self , * args ):
273
278
"""Disconnect nodes
274
279
See the docstring for connect for format.
@@ -314,6 +319,8 @@ def disconnect(self, *args):
314
319
else :
315
320
self ._graph .add_edges_from ([(srcnode , dstnode , edge_data )])
316
321
322
+ self ._update_node_cache ()
323
+
317
324
def add_nodes (self , nodes ):
318
325
""" Add nodes to a workflow
319
326
@@ -346,6 +353,7 @@ def add_nodes(self, nodes):
346
353
if node ._hierarchy is None :
347
354
node ._hierarchy = self .name
348
355
self ._graph .add_nodes_from (newnodes )
356
+ self ._update_node_cache ()
349
357
350
358
def remove_nodes (self , nodes ):
351
359
""" Remove nodes from a workflow
@@ -356,6 +364,7 @@ def remove_nodes(self, nodes):
356
364
A list of EngineBase-based objects
357
365
"""
358
366
self ._graph .remove_nodes_from (nodes )
367
+ self ._update_node_cache ()
359
368
360
369
# Input-Output access
361
370
@property
@@ -903,23 +912,32 @@ def _set_node_input(self, node, param, source, sourceinfo):
903
912
node .set_input (param , deepcopy (newval ))
904
913
905
914
def _get_all_nodes (self ):
906
- allnodes = []
907
- for node in self ._graph .nodes ():
908
- if isinstance (node , Workflow ):
909
- allnodes .extend (node ._get_all_nodes ())
910
- else :
911
- allnodes .append (node )
915
+ allnodes = [
916
+ * self ._nodes_cache .difference (self ._nested_workflows_cache )
917
+ ] # all nodes that are not workflows
918
+ for node in self ._nested_workflows_cache :
919
+ allnodes .extend (node ._get_all_nodes ())
912
920
return allnodes
913
921
922
+ def _update_node_cache (self ):
923
+ nodes = set (self ._graph )
924
+
925
+ added_nodes = nodes .difference (self ._nodes_cache )
926
+ removed_nodes = self ._nodes_cache .difference (nodes )
927
+
928
+ self ._nodes_cache = nodes
929
+ self ._nested_workflows_cache .difference_update (removed_nodes )
930
+
931
+ for node in added_nodes :
932
+ if isinstance (node , Workflow ):
933
+ self ._nested_workflows_cache .add (node )
934
+
914
935
def _has_node (self , wanted_node ):
915
- if wanted_node in self ._graph :
916
- return True # best case scenario
917
- for node in self ._graph : # iterate otherwise
918
- if wanted_node == node :
936
+ if wanted_node in self ._nodes_cache :
937
+ return True
938
+ for node in self ._nested_workflows_cache :
939
+ if node . _has_node ( wanted_node ) :
919
940
return True
920
- if hasattr (node , "_has_node" ): # hasattr is faster than isinstance
921
- if node ._has_node (wanted_node ):
922
- return True
923
941
return False
924
942
925
943
def _create_flat_graph (self ):
0 commit comments