Skip to content

Commit 3fcb646

Browse files
committed
enh: first pass at saving a flattened workflow to a python file
1 parent 9860cd7 commit 3fcb646

File tree

1 file changed

+102
-18
lines changed

1 file changed

+102
-18
lines changed

nipype/pipeline/engine.py

Lines changed: 102 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import gzip
1717
from copy import deepcopy
1818
import cPickle
19+
import inspect
1920
import os
2021
import shutil
2122
from shutil import rmtree
@@ -37,7 +38,7 @@
3738
Undefined, TraitedSpec, DynamicTraitedSpec,
3839
Bunch, InterfaceResult, md5, Interface,
3940
TraitDictObject, TraitListObject, isdefined)
40-
from ..utils.misc import getsource
41+
from ..utils.misc import getsource, create_function_from_source
4142
from ..utils.filemanip import (save_json, FileNotFoundError,
4243
filename_to_list, list_to_filename,
4344
copyfiles, fnames_presuffix, loadpkl,
@@ -48,8 +49,61 @@
4849
from .utils import (generate_expanded_graph, modify_paths,
4950
export_graph, make_output_dir,
5051
clean_working_directory, format_dot,
51-
get_print_name, merge_dict,
52-
evaluate_connect_function)
52+
get_print_name, merge_dict, evaluate_connect_function)
53+
54+
55+
def _write_inputs(node):
56+
lines = []
57+
for key, _ in node.inputs.items():
58+
val = getattr(node.inputs, key)
59+
if isdefined(val):
60+
if type(val) == str:
61+
lines.append('%s.inputs.%s = "%s"' % (node.name, key, val))
62+
else:
63+
lines.append('%s.inputs.%s = %s' % (node.name, key, val))
64+
return lines
65+
66+
67+
def format_node(node, format='python'):
68+
"""Format a node in a given output syntax
69+
"""
70+
lines = []
71+
name = node.name
72+
if format == 'python':
73+
klass = node._interface
74+
importline = 'from %s import %s' % (klass.__module__,
75+
klass.__class__.__name__)
76+
comment = '# Node: %s' % node.fullname
77+
spec = inspect.getargspec(node._interface.__init__)
78+
if spec.defaults:
79+
args = spec.args[1:-len(spec.defaults)]
80+
else:
81+
args = spec.args[1:]
82+
if args:
83+
filled_args = []
84+
for arg in args:
85+
filled_args.append('%s=%s' % (arg, getattr(node._interface,
86+
'_%s' % arg)))
87+
args = ', '.join(filled_args)
88+
else:
89+
args = ''
90+
if isinstance(node, MapNode):
91+
nodedef = '%s = MapNode(%s(%s), iterfield=%s, name="%s")' % (name,
92+
klass.__class__.__name__,
93+
args,
94+
node.iterfield,
95+
name)
96+
else:
97+
nodedef = '%s = Node(%s(%s), name="%s")' % (name,
98+
klass.__class__.__name__,
99+
args,
100+
name)
101+
lines = [importline, comment, nodedef]
102+
if node.iterables is not None:
103+
lines.append('%s.iterables = %s' % (name, node.iterables))
104+
lines.extend(_write_inputs(node))
105+
return lines
106+
53107

54108
class WorkflowBase(object):
55109
""" Define common attributes and functions for workflows and nodes
@@ -488,17 +542,52 @@ def export(self, prefix="output", format="python"):
488542
flatgraph = self._create_flat_graph()
489543
nodes = nx.topological_sort(flatgraph)
490544

545+
lines = ['# Workflow']
546+
importlines = ['from nipype.pipeline.engine import Workflow, Node, MapNode']
547+
functions = {}
491548
if format == "python":
492-
with open('%s.py', 'wt') as fp:
493-
nodenames = []
494-
for idx, node in enumerate(nodes):
495-
# write nodes
496-
nodestr, nodename = node.format(format=python)
497-
fp.writelines(nodestr)
498-
# write connections
499-
for prevnode in flatgraph.predecessors(node):
500-
# write connection
501-
pass
549+
connect_template = '%s.connect(%%s, %%s, %%s, "%%s")' % self.name
550+
connect_template2 = '%s.connect(%%s, "%%s", %%s, "%%s")' % self.name
551+
wfdef = '%s = Workflow("%s")' % (self.name, self.name)
552+
lines.append(wfdef)
553+
for idx, node in enumerate(nodes):
554+
nodename = node.name
555+
# write nodes
556+
nodelines = format_node(node, format='python')
557+
for line in nodelines:
558+
if line.startswith('from'):
559+
if line not in importlines:
560+
importlines.append(line)
561+
else:
562+
lines.append(line)
563+
# write connections
564+
for u, _, d in flatgraph.in_edges_iter(nbunch=node,
565+
data=True):
566+
for cd in d['connect']:
567+
if isinstance(cd[0], tuple):
568+
args = list(cd[0])
569+
if args[1] in functions:
570+
funcname = functions[args[1]]
571+
else:
572+
func = create_function_from_source(args[1])
573+
funcname = [name for name in func.func_globals if name != '__builtins__'][0]
574+
functions[args[1]] = funcname
575+
args[1] = funcname
576+
args = tuple([arg for arg in args if arg])
577+
line = connect_template % (u.name, args,
578+
nodename, cd[1])
579+
line = line.replace("'%s'" % funcname, funcname)
580+
lines.append(line)
581+
else:
582+
lines.append(connect_template2 % (u.name, cd[0],
583+
nodename, cd[1]))
584+
functionlines = ['# Functions']
585+
for function in functions:
586+
functionlines.append(cPickle.loads(function).rstrip())
587+
all_lines = importlines + functionlines + lines
588+
with open('%s%s.py' % (prefix, self.name), 'wt') as fp:
589+
fp.writelines('\n'.join([line.replace('\n', '\\n') for line in all_lines]))
590+
return all_lines
502591

503592
def run(self, plugin=None, plugin_args=None, updatehash=False):
504593
""" Execute the workflow
@@ -1062,11 +1151,6 @@ def help(self):
10621151
""" Print interface help"""
10631152
self._interface.help()
10641153

1065-
def format(format=python):
1066-
"""Format a node in a given output syntax
1067-
"""
1068-
1069-
10701154

10711155
def hash_exists(self, updatehash=False):
10721156
# Get a dictionary with hashed filenames and a hashvalue

0 commit comments

Comments
 (0)