|
16 | 16 | import gzip
|
17 | 17 | from copy import deepcopy
|
18 | 18 | import cPickle
|
| 19 | +import inspect |
19 | 20 | import os
|
20 | 21 | import shutil
|
21 | 22 | from shutil import rmtree
|
|
37 | 38 | Undefined, TraitedSpec, DynamicTraitedSpec,
|
38 | 39 | Bunch, InterfaceResult, md5, Interface,
|
39 | 40 | TraitDictObject, TraitListObject, isdefined)
|
40 |
| -from ..utils.misc import getsource |
| 41 | +from ..utils.misc import getsource, create_function_from_source |
41 | 42 | from ..utils.filemanip import (save_json, FileNotFoundError,
|
42 | 43 | filename_to_list, list_to_filename,
|
43 | 44 | copyfiles, fnames_presuffix, loadpkl,
|
|
48 | 49 | from .utils import (generate_expanded_graph, modify_paths,
|
49 | 50 | export_graph, make_output_dir,
|
50 | 51 | clean_working_directory, format_dot,
|
51 |
| - get_print_name, merge_dict, |
52 |
| - evaluate_connect_function) |
| 52 | + get_print_name, merge_dict, evaluate_connect_function) |
| 53 | + |
| 54 | + |
def _write_inputs(node):
    """Return python assignment lines for every defined input of *node*.

    Each line has the form ``<nodename>.inputs.<key> = <value>`` so the
    result can be pasted into a generated standalone script.

    Parameters
    ----------
    node : Node
        Pipeline node whose ``inputs`` trait spec is serialized.

    Returns
    -------
    list of str
        One assignment statement per defined input; undefined inputs are
        skipped.
    """
    lines = []
    for key, _ in node.inputs.items():
        # Re-read through getattr -- presumably to get the live trait value
        # rather than the one snapshotted by items(); TODO confirm.
        val = getattr(node.inputs, key)
        if not isdefined(val):
            continue
        # isinstance instead of type(val) == str: idiomatic, and also
        # covers str subclasses.
        if isinstance(val, str):
            # NOTE(review): a value containing a double quote would yield
            # broken generated code -- consider %r if that can occur.
            lines.append('%s.inputs.%s = "%s"' % (node.name, key, val))
        else:
            lines.append('%s.inputs.%s = %s' % (node.name, key, val))
    return lines
| 65 | + |
| 66 | + |
def format_node(node, format='python'):
    """Format a node in a given output syntax
    """
    name = node.name
    lines = []
    if format == 'python':
        interface = node._interface
        classname = interface.__class__.__name__
        # Required constructor arguments are the positional args (minus
        # self) that carry no default value.
        argspec = inspect.getargspec(interface.__init__)
        n_defaults = len(argspec.defaults) if argspec.defaults else 0
        required = argspec.args[1:len(argspec.args) - n_defaults]
        if required:
            # Values are read from the interface's matching private
            # attribute (``_<arg>``) -- assumed convention; confirm.
            args = ', '.join('%s=%s' % (arg, getattr(interface, '_%s' % arg))
                             for arg in required)
        else:
            args = ''
        if isinstance(node, MapNode):
            nodedef = ('%s = MapNode(%s(%s), iterfield=%s, name="%s")'
                       % (name, classname, args, node.iterfield, name))
        else:
            nodedef = ('%s = Node(%s(%s), name="%s")'
                       % (name, classname, args, name))
        lines = ['from %s import %s' % (interface.__module__, classname),
                 '# Node: %s' % node.fullname,
                 nodedef]
    # iterables and input assignments are emitted regardless of format.
    if node.iterables is not None:
        lines.append('%s.iterables = %s' % (name, node.iterables))
    lines.extend(_write_inputs(node))
    return lines
| 106 | + |
53 | 107 |
|
54 | 108 | class WorkflowBase(object):
|
55 | 109 | """ Define common attributes and functions for workflows and nodes
|
@@ -488,17 +542,52 @@ def export(self, prefix="output", format="python"):
|
488 | 542 | flatgraph = self._create_flat_graph()
|
489 | 543 | nodes = nx.topological_sort(flatgraph)
|
490 | 544 |
|
| 545 | + lines = ['# Workflow'] |
| 546 | + importlines = ['from nipype.pipeline.engine import Workflow, Node, MapNode'] |
| 547 | + functions = {} |
491 | 548 | if format == "python":
|
492 |
| - with open('%s.py', 'wt') as fp: |
493 |
| - nodenames = [] |
494 |
| - for idx, node in enumerate(nodes): |
495 |
| - # write nodes |
496 |
| - nodestr, nodename = node.format(format=python) |
497 |
| - fp.writelines(nodestr) |
498 |
| - # write connections |
499 |
| - for prevnode in flatgraph.predecessors(node): |
500 |
| - # write connection |
501 |
| - pass |
| 549 | + connect_template = '%s.connect(%%s, %%s, %%s, "%%s")' % self.name |
| 550 | + connect_template2 = '%s.connect(%%s, "%%s", %%s, "%%s")' % self.name |
| 551 | + wfdef = '%s = Workflow("%s")' % (self.name, self.name) |
| 552 | + lines.append(wfdef) |
| 553 | + for idx, node in enumerate(nodes): |
| 554 | + nodename = node.name |
| 555 | + # write nodes |
| 556 | + nodelines = format_node(node, format='python') |
| 557 | + for line in nodelines: |
| 558 | + if line.startswith('from'): |
| 559 | + if line not in importlines: |
| 560 | + importlines.append(line) |
| 561 | + else: |
| 562 | + lines.append(line) |
| 563 | + # write connections |
| 564 | + for u, _, d in flatgraph.in_edges_iter(nbunch=node, |
| 565 | + data=True): |
| 566 | + for cd in d['connect']: |
| 567 | + if isinstance(cd[0], tuple): |
| 568 | + args = list(cd[0]) |
| 569 | + if args[1] in functions: |
| 570 | + funcname = functions[args[1]] |
| 571 | + else: |
| 572 | + func = create_function_from_source(args[1]) |
| 573 | + funcname = [name for name in func.func_globals if name != '__builtins__'][0] |
| 574 | + functions[args[1]] = funcname |
| 575 | + args[1] = funcname |
| 576 | + args = tuple([arg for arg in args if arg]) |
| 577 | + line = connect_template % (u.name, args, |
| 578 | + nodename, cd[1]) |
| 579 | + line = line.replace("'%s'" % funcname, funcname) |
| 580 | + lines.append(line) |
| 581 | + else: |
| 582 | + lines.append(connect_template2 % (u.name, cd[0], |
| 583 | + nodename, cd[1])) |
| 584 | + functionlines = ['# Functions'] |
| 585 | + for function in functions: |
| 586 | + functionlines.append(cPickle.loads(function).rstrip()) |
| 587 | + all_lines = importlines + functionlines + lines |
| 588 | + with open('%s%s.py' % (prefix, self.name), 'wt') as fp: |
| 589 | + fp.writelines('\n'.join([line.replace('\n', '\\n') for line in all_lines])) |
| 590 | + return all_lines |
502 | 591 |
|
503 | 592 | def run(self, plugin=None, plugin_args=None, updatehash=False):
|
504 | 593 | """ Execute the workflow
|
@@ -1062,11 +1151,6 @@ def help(self):
|
1062 | 1151 | """ Print interface help"""
|
1063 | 1152 | self._interface.help()
|
1064 | 1153 |
|
1065 |
| - def format(format=python): |
1066 |
| - """Format a node in a given output syntax |
1067 |
| - """ |
1068 |
| - |
1069 |
| - |
1070 | 1154 |
|
1071 | 1155 | def hash_exists(self, updatehash=False):
|
1072 | 1156 | # Get a dictionary with hashed filenames and a hashvalue
|
|
0 commit comments