11"""Workflow splicing operations."""
22
3+ import logging
34import typing as ty
45
56from nipype .pipeline import Workflow
67from nipype .pipeline .engine .base import EngineBase
78
89
9- def tag (tag : str ) -> 'EngineBase' :
10+ def tag (tag : str ) -> ty . Callable :
1011 """
1112 Decorator to set a tag on an `init_...wf` function.
1213
1314 This is used to mark nodes or workflows for replacement in the splicing process.
1415 """
1516
1617 def _decorator (func , * args , ** kwargs ) -> ty .Callable :
17- def _tag () -> ' EngineBase' :
18+ def _tag () -> EngineBase :
1819 node = func (* args , ** kwargs )
1920 node ._tag = tag
2021 return node
@@ -26,7 +27,7 @@ def _tag() -> 'EngineBase':
2627
2728def splice_workflow (
2829 root_wf : Workflow ,
29- replacements : dict [str , ' EngineBase' ],
30+ replacements : dict [str , EngineBase ],
3031 * ,
3132 write_graph : bool = False ,
3233 debug : bool = False ,
@@ -48,19 +49,19 @@ def splice_workflow(
4849
4950
5051def _get_substitutions (
51- wf : Workflow ,
52- replacements : dict [str , ' EngineBase' ],
53- ) -> dict [' EngineBase' , ' EngineBase' ]:
52+ workflow : Workflow ,
53+ replacements : dict [str , EngineBase ],
54+ ) -> dict [EngineBase , EngineBase ]:
5455 """ "Query tags in workflow, and return a list of substitutions to make"""
5556 substitutions = {}
56- tagged_wfs = _fetch_tags (wf )
57+ tagged_wfs = _fetch_tags (workflow )
5758 for tag in tagged_wfs :
5859 if tag in replacements :
5960 substitutions [tagged_wfs [tag ]] = replacements [tag ]
6061 return substitutions
6162
6263
63- def _fetch_tags (wf : Workflow ) -> dict [str , ' EngineBase' ]:
64+ def _fetch_tags (wf : Workflow ) -> dict [str , EngineBase ]:
6465 """Query all nodes in a workflow and return a dictionary of tags and nodes."""
6566 tagged = {}
6667 for node in wf ._graph .nodes :
@@ -74,7 +75,7 @@ def _fetch_tags(wf: Workflow) -> dict[str, 'EngineBase']:
7475
7576def _splice_components (
7677 workflow : Workflow ,
77- substitutions : dict [' EngineBase' , ' EngineBase' ],
78+ substitutions : dict [EngineBase , EngineBase ],
7879 debug : bool = False ,
7980) -> tuple [list , list ]:
8081 """Query all connections and return a list of removals and additions to be made."""
@@ -86,7 +87,7 @@ def _splice_components(
8687
8788 to_replace = [x .fullname for x in substitutions ]
8889
89- for src , dst in workflow ._graph .edges : # will not expand workflows, but needs to
90+ for src , dst in workflow ._graph .edges :
9091 if dst .fullname in to_replace :
9192 edge_data = workflow ._graph .get_edge_data (src , dst )
9293 alt_dst = substitutions [dst ]
@@ -112,10 +113,13 @@ def _splice_components(
112113 _expanded_workflows .add (src )
113114 _splice_components (src , substitutions , debug = debug )
114115
115- if debug :
116- print (f'Workflow: { workflow } ' )
117- print (f'- Removing: { edge_removals } ' )
118- print (f'+ Adding: { edge_connects } ' )
116+ logger = logging .getLogger ('nipype.workflow' )
117+ logger .debug (
118+ 'Workflow: %s, \n - edge_removals: %s, \n + edge_connects: %s' ,
119+ workflow ,
120+ edge_removals ,
121+ edge_connects ,
122+ )
119123
120124 workflow ._graph .remove_edges_from (edge_removals )
121125 workflow .remove_nodes (node_removals )
0 commit comments