Skip to content

Commit b737310

Browse files
committed
ENH: Auto-derive input_names in Function
1 parent d387481 commit b737310

File tree

2 files changed

+21
-5
lines changed

2 files changed

+21
-5
lines changed

nipype/interfaces/utility/tests/test_wrappers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def gen_random_array(size):
2424
def increment_array(in_array):
2525
return in_array + 1
2626

27-
f2 = pe.MapNode(utility.Function(input_names=['in_array'], output_names=['out_array'], function=increment_array), name='increment_array', iterfield=['in_array'])
27+
f2 = pe.MapNode(utility.Function(function=increment_array), name='increment_array', iterfield=['in_array'])
2828

2929
wf.connect(f1, 'random_array', f2, 'in_array')
3030
wf.run()

nipype/interfaces/utility/wrappers.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@ class Function(IOBase):
5858
input_spec = FunctionInputSpec
5959
output_spec = DynamicTraitedSpec
6060

61-
def __init__(self, input_names, output_names, function=None, imports=None,
62-
**inputs):
61+
def __init__(self, input_names=None, output_names='out', function=None,
62+
imports=None, **inputs):
6363
"""
6464
6565
Parameters
@@ -88,10 +88,18 @@ def __init__(self, input_names, output_names, function=None, imports=None,
8888
raise Exception('Interface Function does not accept '
8989
'function objects defined interactively '
9090
'in a python session')
91-
elif isinstance(function, (str, bytes)):
91+
else:
92+
if inputs is None:
93+
fninfo = function.func_code
94+
elif isinstance(function, string_types):
9295
self.inputs.function_str = function
96+
if inputs is None:
97+
fninfo = create_function_from_source(
98+
function, imports).func_code
9399
else:
94100
raise Exception('Unknown type of function')
101+
if inputs is None:
102+
inputs = fninfo.co_varnames[:fninfo.co_argcount]
95103
self.inputs.on_trait_change(self._set_function_string,
96104
'function_str')
97105
self._input_names = filename_to_list(input_names)
@@ -106,10 +114,18 @@ def _set_function_string(self, obj, name, old, new):
106114
if name == 'function_str':
107115
if hasattr(new, '__call__'):
108116
function_source = getsource(new)
109-
elif isinstance(new, (str, bytes)):
117+
fninfo = new.func_code
118+
elif isinstance(new, string_types):
110119
function_source = new
120+
fninfo = create_function_from_source(
121+
new, self.imports).func_code
111122
self.inputs.trait_set(trait_change_notify=False,
112123
**{'%s' % name: function_source})
124+
# Update input traits
125+
input_names = fninfo.co_varnames[:fninfo.co_argcount]
126+
new_names = set(input_names) - set(self._input_names)
127+
add_traits(self.inputs, list(new_names))
128+
self._input_names = new_names
113129

114130
def _add_output_traits(self, base):
115131
undefined_traits = {}

0 commit comments

Comments
 (0)