|
16 | 16 | import gzip
|
17 | 17 | from copy import deepcopy
|
18 | 18 | import cPickle
|
| 19 | +import inspect |
19 | 20 | import os
|
20 | 21 | import shutil
|
21 | 22 | from shutil import rmtree
|
|
37 | 38 | Undefined, TraitedSpec, DynamicTraitedSpec,
|
38 | 39 | Bunch, InterfaceResult, md5, Interface,
|
39 | 40 | TraitDictObject, TraitListObject, isdefined)
|
40 |
| -from ..utils.misc import getsource |
| 41 | +from ..utils.misc import getsource, create_function_from_source |
41 | 42 | from ..utils.filemanip import (save_json, FileNotFoundError,
|
42 | 43 | filename_to_list, list_to_filename,
|
43 | 44 | copyfiles, fnames_presuffix, loadpkl,
|
|
48 | 49 | from .utils import (generate_expanded_graph, modify_paths,
|
49 | 50 | export_graph, make_output_dir,
|
50 | 51 | clean_working_directory, format_dot,
|
51 |
| - get_print_name, merge_dict, |
52 |
| - evaluate_connect_function) |
| 52 | + get_print_name, merge_dict, evaluate_connect_function) |
| 53 | + |
| 54 | + |
| 55 | +def _write_inputs(node): |
| 56 | + lines = [] |
| 57 | + nodename = node.fullname.replace('.', '_') |
| 58 | + for key, _ in node.inputs.items(): |
| 59 | + val = getattr(node.inputs, key) |
| 60 | + if isdefined(val): |
| 61 | + if type(val) == str: |
| 62 | + try: |
| 63 | + func = create_function_from_source(val) |
| 64 | + except RuntimeError, e: |
| 65 | + lines.append("%s.inputs.%s = '%s'" % (nodename, key, val)) |
| 66 | + else: |
| 67 | + funcname = [name for name in func.func_globals if name != '__builtins__'][0] |
| 68 | + lines.append(cPickle.loads(val)) |
| 69 | + if funcname == nodename: |
| 70 | + lines[-1] = lines[-1].replace(' %s(' % funcname, |
| 71 | + ' %s_1(' % funcname) |
| 72 | + funcname = '%s_1' % funcname |
| 73 | + lines.append('from nipype.utils.misc import getsource') |
| 74 | + lines.append("%s.inputs.%s = getsource(%s)" % (nodename, |
| 75 | + key, |
| 76 | + funcname)) |
| 77 | + else: |
| 78 | + lines.append('%s.inputs.%s = %s' % (nodename, key, val)) |
| 79 | + return lines |
| 80 | + |
| 81 | + |
| 82 | +def format_node(node, format='python', include_config=False): |
| 83 | + """Format a node in a given output syntax |
| 84 | + """ |
| 85 | + lines = [] |
| 86 | + name = node.fullname.replace('.', '_') |
| 87 | + if format == 'python': |
| 88 | + klass = node._interface |
| 89 | + importline = 'from %s import %s' % (klass.__module__, |
| 90 | + klass.__class__.__name__) |
| 91 | + comment = '# Node: %s' % node.fullname |
| 92 | + spec = inspect.getargspec(node._interface.__init__) |
| 93 | + args = spec.args[1:] |
| 94 | + if args: |
| 95 | + filled_args = [] |
| 96 | + for arg in args: |
| 97 | + if hasattr(node._interface, '_%s' % arg): |
| 98 | + filled_args.append('%s=%s' % (arg, getattr(node._interface, |
| 99 | + '_%s' % arg))) |
| 100 | + args = ', '.join(filled_args) |
| 101 | + else: |
| 102 | + args = '' |
| 103 | + if isinstance(node, MapNode): |
| 104 | + nodedef = '%s = MapNode(%s(%s), iterfield=%s, name="%s")' % (name, |
| 105 | + klass.__class__.__name__, |
| 106 | + args, |
| 107 | + node.iterfield, |
| 108 | + name) |
| 109 | + else: |
| 110 | + nodedef = '%s = Node(%s(%s), name="%s")' % (name, |
| 111 | + klass.__class__.__name__, |
| 112 | + args, |
| 113 | + name) |
| 114 | + lines = [importline, comment, nodedef] |
| 115 | + |
| 116 | + if include_config: |
| 117 | + lines = [importline, "from collections import OrderedDict", comment, nodedef] |
| 118 | + lines.append('%s.config = %s' % (name, node.config)) |
| 119 | + |
| 120 | + if node.iterables is not None: |
| 121 | + lines.append('%s.iterables = %s' % (name, node.iterables)) |
| 122 | + lines.extend(_write_inputs(node)) |
| 123 | + |
| 124 | + return lines |
| 125 | + |
53 | 126 |
|
54 | 127 | class WorkflowBase(object):
|
55 | 128 | """ Define common attributes and functions for workflows and nodes
|
@@ -139,8 +212,8 @@ class Workflow(WorkflowBase):
|
139 | 212 | """Controls the setup and execution of a pipeline of processes
|
140 | 213 | """
|
141 | 214 |
|
142 |
| - def __init__(self, **kwargs): |
143 |
| - super(Workflow, self).__init__(**kwargs) |
| 215 | + def __init__(self, *args, **kwargs): |
| 216 | + super(Workflow, self).__init__(* args, **kwargs) |
144 | 217 | self._graph = nx.DiGraph()
|
145 | 218 |
|
146 | 219 | # PUBLIC API
|
@@ -470,6 +543,82 @@ def write_hierarchical_dotfile(self, dotfilename=None, colored=True,
|
470 | 543 | else:
|
471 | 544 | logger.info(dotstr)
|
472 | 545 |
|
| 546 | + def export(self, filename=None, prefix="output", format="python", include_config=False): |
| 547 | + """Export object into a different format |
| 548 | +
|
| 549 | + Parameters |
| 550 | + ---------- |
| 551 | + filename: string |
| 552 | + file to save the code to; overrides prefix |
| 553 | +
|
| 554 | + prefix: string |
| 555 | + prefix to use for output file |
| 556 | +
|
| 557 | + format: string |
| 558 | + one of "python" |
| 559 | + |
| 560 | + include_config: boolean |
| 561 | + whether to include node and workflow config values |
| 562 | + """ |
| 563 | + formats = ["python"] |
| 564 | + if format not in formats: |
| 565 | + raise ValueError('format must be one of: %s' % '|'.join(formats)) |
| 566 | + flatgraph = self._create_flat_graph() |
| 567 | + nodes = nx.topological_sort(flatgraph) |
| 568 | + |
| 569 | + lines = ['# Workflow'] |
| 570 | + importlines = ['from nipype.pipeline.engine import Workflow, Node, MapNode'] |
| 571 | + functions = {} |
| 572 | + if format == "python": |
| 573 | + connect_template = '%s.connect(%%s, %%s, %%s, "%%s")' % self.name |
| 574 | + connect_template2 = '%s.connect(%%s, "%%s", %%s, "%%s")' % self.name |
| 575 | + wfdef = '%s = Workflow("%s")' % (self.name, self.name) |
| 576 | + lines.append(wfdef) |
| 577 | + if include_config: |
| 578 | + lines.append('%s.config = %s' % (self.name, self.config)) |
| 579 | + for idx, node in enumerate(nodes): |
| 580 | + nodename = node.fullname.replace('.', '_') |
| 581 | + # write nodes |
| 582 | + nodelines = format_node(node, format='python', include_config=include_config) |
| 583 | + for line in nodelines: |
| 584 | + if line.startswith('from'): |
| 585 | + if line not in importlines: |
| 586 | + importlines.append(line) |
| 587 | + else: |
| 588 | + lines.append(line) |
| 589 | + # write connections |
| 590 | + for u, _, d in flatgraph.in_edges_iter(nbunch=node, |
| 591 | + data=True): |
| 592 | + for cd in d['connect']: |
| 593 | + if isinstance(cd[0], tuple): |
| 594 | + args = list(cd[0]) |
| 595 | + if args[1] in functions: |
| 596 | + funcname = functions[args[1]] |
| 597 | + else: |
| 598 | + func = create_function_from_source(args[1]) |
| 599 | + funcname = [name for name in func.func_globals if name != '__builtins__'][0] |
| 600 | + functions[args[1]] = funcname |
| 601 | + args[1] = funcname |
| 602 | + args = tuple([arg for arg in args if arg]) |
| 603 | + line = connect_template % (u.fullname.replace('.','_'), args, |
| 604 | + nodename, cd[1]) |
| 605 | + line = line.replace("'%s'" % funcname, funcname) |
| 606 | + lines.append(line) |
| 607 | + else: |
| 608 | + lines.append(connect_template2 % (u.fullname.replace('.','_'), cd[0], |
| 609 | + nodename, cd[1])) |
| 610 | + functionlines = ['# Functions'] |
| 611 | + for function in functions: |
| 612 | + functionlines.append(cPickle.loads(function).rstrip()) |
| 613 | + all_lines = importlines + functionlines + lines |
| 614 | + |
| 615 | + if not filename: |
| 616 | + filename = '%s%s.py' % (prefix, self.name) |
| 617 | + with open(filename, 'wt') as fp: |
| 618 | + #fp.writelines('\n'.join([line.replace('\n', '\\n') for line in all_lines])) |
| 619 | + fp.writelines('\n'.join(all_lines)) |
| 620 | + return all_lines |
| 621 | + |
473 | 622 | def run(self, plugin=None, plugin_args=None, updatehash=False):
|
474 | 623 | """ Execute the workflow
|
475 | 624 |
|
@@ -1032,6 +1181,7 @@ def help(self):
|
1032 | 1181 | """ Print interface help"""
|
1033 | 1182 | self._interface.help()
|
1034 | 1183 |
|
| 1184 | + |
1035 | 1185 | def hash_exists(self, updatehash=False):
|
1036 | 1186 | # Get a dictionary with hashed filenames and a hashvalue
|
1037 | 1187 | # of the dictionary itself.
|
|
0 commit comments