Skip to content

Commit 9e23f15

Browse files
committed
Merge pull request #517 from satra/enh/saveworkflow
Saving workflow to a file
2 parents 4b86d97 + ba84e10 commit 9e23f15

File tree

6 files changed

+266
-8
lines changed

6 files changed

+266
-8
lines changed

CHANGES

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ Next release
33

44
* ENH: New interfaces: nipy.Trim, fsl.GLM, fsl.SigLoss, spm.VBMSegment, fsl.InvWarp
55
* ENH: Allow control over terminal output for commandline interfaces
6+
* ENH: Added preliminary support for generating Python code from Workflows.
67
* ENH: New workflows for dMRI and fMRI pre-processing: added motion artifact correction
78
with rotation of the B-matrix, and susceptibility correction for EPI imaging using
89
fieldmaps. Updated eddy_correct pipeline to support both dMRI and fMRI, and new parameters.

doc/users/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
function_interface
3232
mapnode_and_iterables
3333
model_specification
34+
saving_workflows
3435

3536

3637

doc/users/saving_workflows.rst

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
.. _saving_workflows:
2+
3+
===================================================
4+
Saving Workflows and Nodes to a file (experimental)
5+
===================================================
6+
7+
On top of the standard way of saving (i.e. serializing) objects in Python
8+
(see `pickle <http://docs.python.org/2/library/pickle.html>`_) Nipype
9+
provides methods to turn Workflows and nodes into human readable code.
10+
This is useful if you want to save a Workflow that you have generated
11+
on the fly for future use.
12+
13+
To generate Python code for a Workflow use the export method:
14+
15+
.. testcode::
16+
17+
from nipype.interfaces.fsl import BET, ImageMaths
18+
from nipype.pipeline.engine import Workflow, Node, MapNode, format_node
19+
from nipype.interfaces.utility import Function, IdentityInterface
20+
21+
bet = Node(BET(), name='bet')
22+
bet.iterables = ('frac', [0.3, 0.4])
23+
24+
bet2 = MapNode(BET(), name='bet2', iterfield=['infile'])
25+
bet2.iterables = ('frac', [0.4, 0.5])
26+
27+
maths = Node(ImageMaths(), name='maths')
28+
29+
def testfunc(in1):
30+
"""dummy func
31+
"""
32+
out = in1 + 'foo' + "out1"
33+
return out
34+
35+
funcnode = Node(Function(input_names=['a'], output_names=['output'], function=testfunc),
36+
name='testfunc')
37+
funcnode.inputs.in1 = '-sub'
38+
func = lambda x: x
39+
40+
inode = Node(IdentityInterface(fields=['a']), name='inode')
41+
42+
wf = Workflow('testsave')
43+
wf.add_nodes([bet2])
44+
wf.connect(bet, 'mask_file', maths, 'in_file')
45+
wf.connect(bet2, ('mask_file', func), maths, 'in_file2')
46+
wf.connect(inode, 'a', funcnode, 'in1')
47+
wf.connect(funcnode, 'output', maths, 'op_string')
48+
49+
wf.export()
50+
51+
This will create a file "outputtestsave.py" with the following content:
52+
53+
.. testcode::
54+
55+
from nipype.pipeline.engine import Workflow, Node, MapNode
56+
from nipype.interfaces.utility import IdentityInterface
57+
from nipype.interfaces.utility import Function
58+
from nipype.utils.misc import getsource
59+
from nipype.interfaces.fsl.preprocess import BET
60+
from nipype.interfaces.fsl.utils import ImageMaths
61+
# Functions
62+
func = lambda x: x
63+
# Workflow
64+
testsave = Workflow("testsave")
65+
# Node: testsave.inode
66+
inode = Node(IdentityInterface(fields=['a'], mandatory_inputs=True), name="inode")
67+
# Node: testsave.testfunc
68+
testfunc = Node(Function(input_names=['a'], output_names=['output']), name="testfunc")
69+
def testfunc_1(in1):
70+
"""dummy func
71+
"""
72+
out = in1 + 'foo' + "out1"
73+
return out
74+
75+
testfunc.inputs.function_str = getsource(testfunc_1)
76+
testfunc.inputs.ignore_exception = False
77+
testfunc.inputs.in1 = '-sub'
78+
testsave.connect(inode, "a", testfunc, "in1")
79+
# Node: testsave.bet2
80+
bet2 = MapNode(BET(), iterfield=['infile'], name="bet2")
81+
bet2.iterables = ('frac', [0.4, 0.5])
82+
bet2.inputs.environ = {'FSLOUTPUTTYPE': 'NIFTI_GZ'}
83+
bet2.inputs.ignore_exception = False
84+
bet2.inputs.output_type = 'NIFTI_GZ'
85+
bet2.inputs.terminal_output = 'stream'
86+
# Node: testsave.bet
87+
bet = Node(BET(), name="bet")
88+
bet.iterables = ('frac', [0.3, 0.4])
89+
bet.inputs.environ = {'FSLOUTPUTTYPE': 'NIFTI_GZ'}
90+
bet.inputs.ignore_exception = False
91+
bet.inputs.output_type = 'NIFTI_GZ'
92+
bet.inputs.terminal_output = 'stream'
93+
# Node: testsave.maths
94+
maths = Node(ImageMaths(), name="maths")
95+
maths.inputs.environ = {'FSLOUTPUTTYPE': 'NIFTI_GZ'}
96+
maths.inputs.ignore_exception = False
97+
maths.inputs.output_type = 'NIFTI_GZ'
98+
maths.inputs.terminal_output = 'stream'
99+
testsave.connect(bet2, ('mask_file', func), maths, "in_file2")
100+
testsave.connect(bet, "mask_file", maths, "in_file")
101+
testsave.connect(testfunc, "output", maths, "op_string")
102+
103+
The file is ready to use and includes all the necessary imports.
104+
105+
.. include:: ../links_names.txt

nipype/interfaces/io.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,7 @@ def __init__(self, infields=None, outfields=None, **kwargs):
439439
undefined_traits = {}
440440
# used for mandatory inputs check
441441
self._infields = infields
442+
self._outfields = outfields
442443
if infields:
443444
for key in infields:
444445
self.inputs.add_trait(key, traits.Any)

nipype/interfaces/utility.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,14 +108,14 @@ class Merge(IOBase):
108108

109109
def __init__(self, numinputs=0, **inputs):
110110
super(Merge, self).__init__(**inputs)
111-
self.numinputs = numinputs
111+
self._numinputs = numinputs
112112
add_traits(self.inputs, ['in%d' % (i + 1) for i in range(numinputs)])
113113

114114
def _list_outputs(self):
115115
outputs = self._outputs().get()
116116
out = []
117117
if self.inputs.axis == 'vstack':
118-
for idx in range(self.numinputs):
118+
for idx in range(self._numinputs):
119119
value = getattr(self.inputs, 'in%d' % (idx + 1))
120120
if isdefined(value):
121121
if isinstance(value, list) and not self.inputs.no_flatten:
@@ -125,7 +125,7 @@ def _list_outputs(self):
125125
else:
126126
for i in range(len(filename_to_list(self.inputs.in1))):
127127
out.insert(i, [])
128-
for j in range(self.numinputs):
128+
for j in range(self._numinputs):
129129
out[i].append(filename_to_list(getattr(self.inputs, 'in%d' % (j + 1)))[i])
130130
if out:
131131
outputs['out'] = out

nipype/pipeline/engine.py

Lines changed: 155 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import gzip
1717
from copy import deepcopy
1818
import cPickle
19+
import inspect
1920
import os
2021
import shutil
2122
from shutil import rmtree
@@ -37,7 +38,7 @@
3738
Undefined, TraitedSpec, DynamicTraitedSpec,
3839
Bunch, InterfaceResult, md5, Interface,
3940
TraitDictObject, TraitListObject, isdefined)
40-
from ..utils.misc import getsource
41+
from ..utils.misc import getsource, create_function_from_source
4142
from ..utils.filemanip import (save_json, FileNotFoundError,
4243
filename_to_list, list_to_filename,
4344
copyfiles, fnames_presuffix, loadpkl,
@@ -48,8 +49,80 @@
4849
from .utils import (generate_expanded_graph, modify_paths,
4950
export_graph, make_output_dir,
5051
clean_working_directory, format_dot,
51-
get_print_name, merge_dict,
52-
evaluate_connect_function)
52+
get_print_name, merge_dict, evaluate_connect_function)
53+
54+
55+
def _write_inputs(node):
56+
lines = []
57+
nodename = node.fullname.replace('.', '_')
58+
for key, _ in node.inputs.items():
59+
val = getattr(node.inputs, key)
60+
if isdefined(val):
61+
if type(val) == str:
62+
try:
63+
func = create_function_from_source(val)
64+
except RuntimeError, e:
65+
lines.append("%s.inputs.%s = '%s'" % (nodename, key, val))
66+
else:
67+
funcname = [name for name in func.func_globals if name != '__builtins__'][0]
68+
lines.append(cPickle.loads(val))
69+
if funcname == nodename:
70+
lines[-1] = lines[-1].replace(' %s(' % funcname,
71+
' %s_1(' % funcname)
72+
funcname = '%s_1' % funcname
73+
lines.append('from nipype.utils.misc import getsource')
74+
lines.append("%s.inputs.%s = getsource(%s)" % (nodename,
75+
key,
76+
funcname))
77+
else:
78+
lines.append('%s.inputs.%s = %s' % (nodename, key, val))
79+
return lines
80+
81+
82+
def format_node(node, format='python', include_config=False):
83+
"""Format a node in a given output syntax
84+
"""
85+
lines = []
86+
name = node.fullname.replace('.', '_')
87+
if format == 'python':
88+
klass = node._interface
89+
importline = 'from %s import %s' % (klass.__module__,
90+
klass.__class__.__name__)
91+
comment = '# Node: %s' % node.fullname
92+
spec = inspect.getargspec(node._interface.__init__)
93+
args = spec.args[1:]
94+
if args:
95+
filled_args = []
96+
for arg in args:
97+
if hasattr(node._interface, '_%s' % arg):
98+
filled_args.append('%s=%s' % (arg, getattr(node._interface,
99+
'_%s' % arg)))
100+
args = ', '.join(filled_args)
101+
else:
102+
args = ''
103+
if isinstance(node, MapNode):
104+
nodedef = '%s = MapNode(%s(%s), iterfield=%s, name="%s")' % (name,
105+
klass.__class__.__name__,
106+
args,
107+
node.iterfield,
108+
name)
109+
else:
110+
nodedef = '%s = Node(%s(%s), name="%s")' % (name,
111+
klass.__class__.__name__,
112+
args,
113+
name)
114+
lines = [importline, comment, nodedef]
115+
116+
if include_config:
117+
lines = [importline, "from collections import OrderedDict", comment, nodedef]
118+
lines.append('%s.config = %s' % (name, node.config))
119+
120+
if node.iterables is not None:
121+
lines.append('%s.iterables = %s' % (name, node.iterables))
122+
lines.extend(_write_inputs(node))
123+
124+
return lines
125+
53126

54127
class WorkflowBase(object):
55128
""" Define common attributes and functions for workflows and nodes
@@ -139,8 +212,8 @@ class Workflow(WorkflowBase):
139212
"""Controls the setup and execution of a pipeline of processes
140213
"""
141214

142-
def __init__(self, **kwargs):
143-
super(Workflow, self).__init__(**kwargs)
215+
def __init__(self, *args, **kwargs):
216+
super(Workflow, self).__init__(* args, **kwargs)
144217
self._graph = nx.DiGraph()
145218

146219
# PUBLIC API
@@ -470,6 +543,82 @@ def write_hierarchical_dotfile(self, dotfilename=None, colored=True,
470543
else:
471544
logger.info(dotstr)
472545

546+
def export(self, filename=None, prefix="output", format="python", include_config=False):
547+
"""Export object into a different format
548+
549+
Parameters
550+
----------
551+
filename: string
552+
file to save the code to; overrides prefix
553+
554+
prefix: string
555+
prefix to use for output file
556+
557+
format: string
558+
one of "python"
559+
560+
include_config: boolean
561+
whether to include node and workflow config values
562+
"""
563+
formats = ["python"]
564+
if format not in formats:
565+
raise ValueError('format must be one of: %s' % '|'.join(formats))
566+
flatgraph = self._create_flat_graph()
567+
nodes = nx.topological_sort(flatgraph)
568+
569+
lines = ['# Workflow']
570+
importlines = ['from nipype.pipeline.engine import Workflow, Node, MapNode']
571+
functions = {}
572+
if format == "python":
573+
connect_template = '%s.connect(%%s, %%s, %%s, "%%s")' % self.name
574+
connect_template2 = '%s.connect(%%s, "%%s", %%s, "%%s")' % self.name
575+
wfdef = '%s = Workflow("%s")' % (self.name, self.name)
576+
lines.append(wfdef)
577+
if include_config:
578+
lines.append('%s.config = %s' % (self.name, self.config))
579+
for idx, node in enumerate(nodes):
580+
nodename = node.fullname.replace('.', '_')
581+
# write nodes
582+
nodelines = format_node(node, format='python', include_config=include_config)
583+
for line in nodelines:
584+
if line.startswith('from'):
585+
if line not in importlines:
586+
importlines.append(line)
587+
else:
588+
lines.append(line)
589+
# write connections
590+
for u, _, d in flatgraph.in_edges_iter(nbunch=node,
591+
data=True):
592+
for cd in d['connect']:
593+
if isinstance(cd[0], tuple):
594+
args = list(cd[0])
595+
if args[1] in functions:
596+
funcname = functions[args[1]]
597+
else:
598+
func = create_function_from_source(args[1])
599+
funcname = [name for name in func.func_globals if name != '__builtins__'][0]
600+
functions[args[1]] = funcname
601+
args[1] = funcname
602+
args = tuple([arg for arg in args if arg])
603+
line = connect_template % (u.fullname.replace('.','_'), args,
604+
nodename, cd[1])
605+
line = line.replace("'%s'" % funcname, funcname)
606+
lines.append(line)
607+
else:
608+
lines.append(connect_template2 % (u.fullname.replace('.','_'), cd[0],
609+
nodename, cd[1]))
610+
functionlines = ['# Functions']
611+
for function in functions:
612+
functionlines.append(cPickle.loads(function).rstrip())
613+
all_lines = importlines + functionlines + lines
614+
615+
if not filename:
616+
filename = '%s%s.py' % (prefix, self.name)
617+
with open(filename, 'wt') as fp:
618+
#fp.writelines('\n'.join([line.replace('\n', '\\n') for line in all_lines]))
619+
fp.writelines('\n'.join(all_lines))
620+
return all_lines
621+
473622
def run(self, plugin=None, plugin_args=None, updatehash=False):
474623
""" Execute the workflow
475624
@@ -1032,6 +1181,7 @@ def help(self):
10321181
""" Print interface help"""
10331182
self._interface.help()
10341183

1184+
10351185
def hash_exists(self, updatehash=False):
10361186
# Get a dictionary with hashed filenames and a hashvalue
10371187
# of the dictionary itself.

0 commit comments

Comments
 (0)