Skip to content

Commit e9a5bb8

Browse files
committed
Merge remote-tracking branch 'nipy/master' into mriscombine
2 parents 36fd9c1 + 5c280a4 commit e9a5bb8

File tree

2 files changed

+35
-6
lines changed

2 files changed

+35
-6
lines changed

nipype/interfaces/utility/tests/test_wrappers.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,12 @@
88
from nipype.interfaces import utility
99
import nipype.pipeline.engine as pe
1010

11+
concat_sort = """\
12+
def concat_sort(in_arrays):
13+
import numpy as np
14+
all_vals = np.concatenate([arr.flatten() for arr in in_arrays])
15+
return np.sort(all_vals)
16+
"""
1117

1218
def test_function(tmpdir):
1319
os.chdir(str(tmpdir))
@@ -24,9 +30,15 @@ def gen_random_array(size):
2430
def increment_array(in_array):
2531
return in_array + 1
2632

27-
f2 = pe.MapNode(utility.Function(input_names=['in_array'], output_names=['out_array'], function=increment_array), name='increment_array', iterfield=['in_array'])
33+
f2 = pe.MapNode(utility.Function(function=increment_array), name='increment_array', iterfield=['in_array'])
2834

2935
wf.connect(f1, 'random_array', f2, 'in_array')
36+
37+
f3 = pe.Node(
38+
utility.Function(function=concat_sort),
39+
name="concat_sort")
40+
41+
wf.connect(f2, 'out', f3, 'in_arrays')
3042
wf.run()
3143

3244

nipype/interfaces/utility/wrappers.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,18 +58,19 @@ class Function(IOBase):
5858
input_spec = FunctionInputSpec
5959
output_spec = DynamicTraitedSpec
6060

61-
def __init__(self, input_names, output_names, function=None, imports=None,
62-
**inputs):
61+
def __init__(self, input_names=None, output_names='out', function=None,
62+
imports=None, **inputs):
6363
"""
6464
6565
Parameters
6666
----------
6767
68-
input_names: single str or list
68+
input_names: single str or list or None
6969
names corresponding to function inputs
70+
if ``None``, derive input names from function argument names
7071
output_names: single str or list
71-
names corresponding to function outputs.
72-
has to match the number of outputs
72+
names corresponding to function outputs (default: 'out').
73+
if list of length > 1, has to match the number of outputs
7374
function : callable
7475
callable python object. must be able to execute in an
7576
isolated namespace (possibly in concert with the ``imports``
@@ -88,10 +89,18 @@ def __init__(self, input_names, output_names, function=None, imports=None,
8889
raise Exception('Interface Function does not accept '
8990
'function objects defined interactively '
9091
'in a python session')
92+
else:
93+
if input_names is None:
94+
fninfo = function.__code__
9195
elif isinstance(function, (str, bytes)):
9296
self.inputs.function_str = function
97+
if input_names is None:
98+
fninfo = create_function_from_source(
99+
function, imports).__code__
93100
else:
94101
raise Exception('Unknown type of function')
102+
if input_names is None:
103+
input_names = fninfo.co_varnames[:fninfo.co_argcount]
95104
self.inputs.on_trait_change(self._set_function_string,
96105
'function_str')
97106
self._input_names = filename_to_list(input_names)
@@ -106,10 +115,18 @@ def _set_function_string(self, obj, name, old, new):
106115
if name == 'function_str':
107116
if hasattr(new, '__call__'):
108117
function_source = getsource(new)
118+
fninfo = new.__code__
109119
elif isinstance(new, (str, bytes)):
110120
function_source = new
121+
fninfo = create_function_from_source(
122+
new, self.imports).__code__
111123
self.inputs.trait_set(trait_change_notify=False,
112124
**{'%s' % name: function_source})
125+
# Update input traits
126+
input_names = fninfo.co_varnames[:fninfo.co_argcount]
127+
new_names = set(input_names) - set(self._input_names)
128+
add_traits(self.inputs, list(new_names))
129+
self._input_names.extend(new_names)
113130

114131
def _add_output_traits(self, base):
115132
undefined_traits = {}

0 commit comments

Comments
 (0)