Skip to content

Commit 393ab5b

Browse files
authored
Merge pull request #347 from oesteban/enh/keyselect-interface
ENH: Add a ``KeySelect`` interface
2 parents ffc2555 + c37dda4 commit 393ab5b

File tree

1 file changed

+161
-0
lines changed

1 file changed

+161
-0
lines changed

niworkflows/interfaces/utility.py

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
2+
# vi: set ft=python sts=4 ts=4 sw=4 et:
3+
"""
4+
Interfaces being evaluated before upstreaming to nipype.interfaces.utility
5+
6+
"""
7+
from __future__ import absolute_import, division, print_function, unicode_literals
8+
9+
from nipype.interfaces.io import add_traits
10+
from nipype.interfaces.base import (
11+
InputMultiObject, Str, DynamicTraitedSpec, BaseInterface, isdefined
12+
)
13+
14+
15+
class KeySelectInputSpec(DynamicTraitedSpec):
16+
key = Str(mandatory=True, desc='selective key')
17+
keys = InputMultiObject(Str, mandatory=True, min=1, desc='index of keys')
18+
19+
20+
class KeySelectOutputSpec(DynamicTraitedSpec):
21+
key = Str(desc='propagates selected key')
22+
23+
24+
class KeySelect(BaseInterface):
25+
"""
26+
An interface that operates similarly to an OrderedDict
27+
28+
>>> ks = KeySelect(keys=['MNI152NLin6Asym', 'MNI152Lin', 'fsaverage'],
29+
... fields=['field1', 'field2', 'field3'])
30+
>>> ks.inputs.field1 = ['fsl', 'mni', 'freesurfer']
31+
>>> ks.inputs.field2 = ['volume', 'volume', 'surface']
32+
>>> ks.inputs.field3 = [True, False, False]
33+
>>> ks.inputs.key = 'MNI152Lin'
34+
>>> ks.run().outputs
35+
<BLANKLINE>
36+
field1 = mni
37+
field2 = volume
38+
field3 = False
39+
key = MNI152Lin
40+
<BLANKLINE>
41+
42+
>>> ks = KeySelect(fields=['field1', 'field2', 'field3'])
43+
>>> ks.inputs.keys=['MNI152NLin6Asym', 'MNI152Lin', 'fsaverage']
44+
>>> ks.inputs.field1 = ['fsl', 'mni', 'freesurfer']
45+
>>> ks.inputs.field2 = ['volume', 'volume', 'surface']
46+
>>> ks.inputs.field3 = [True, False, False]
47+
>>> ks.inputs.key = 'MNI152Lin'
48+
>>> ks.run().outputs
49+
<BLANKLINE>
50+
field1 = mni
51+
field2 = volume
52+
field3 = False
53+
key = MNI152Lin
54+
<BLANKLINE>
55+
56+
>>> ks.inputs.field1 = ['fsl', 'mni', 'freesurfer']
57+
>>> ks.inputs.field2 = ['volume', 'volume', 'surface']
58+
>>> ks.inputs.field3 = [True, False, False]
59+
>>> ks.inputs.key = 'fsaverage'
60+
>>> ks.run().outputs
61+
<BLANKLINE>
62+
field1 = freesurfer
63+
field2 = surface
64+
field3 = False
65+
key = fsaverage
66+
<BLANKLINE>
67+
68+
>>> ks.inputs.field1 = ['fsl', 'mni', 'freesurfer']
69+
>>> ks.inputs.field2 = ['volume', 'volume', 'surface']
70+
>>> ks.inputs.field3 = [True, False] # doctest: +IGNORE_EXCEPTION_DETAIL
71+
Traceback (most recent call last):
72+
ValueError: Trying to set an invalid value
73+
74+
>>> ks.inputs.key = 'MNINLin2009cAsym'
75+
Traceback (most recent call last):
76+
ValueError: Selected key "MNINLin2009cAsym" not found in the index
77+
78+
>>> ks = KeySelect(fields=['field1', 'field2', 'field3'])
79+
>>> ks.inputs.keys=['MNI152NLin6Asym']
80+
>>> ks.inputs.field1 = ['fsl']
81+
>>> ks.inputs.field2 = ['volume']
82+
>>> ks.inputs.field3 = [True]
83+
>>> ks.inputs.key = 'MNI152NLin6Asym'
84+
>>> ks.run().outputs
85+
<BLANKLINE>
86+
field1 = fsl
87+
field2 = volume
88+
field3 = True
89+
key = MNI152NLin6Asym
90+
<BLANKLINE>
91+
92+
"""
93+
input_spec = KeySelectInputSpec
94+
output_spec = KeySelectOutputSpec
95+
96+
def __init__(self, keys=None, fields=None, **inputs):
97+
# Call constructor
98+
super(KeySelect, self).__init__(**inputs)
99+
100+
# Handle and initiate fields
101+
if not fields:
102+
raise ValueError('A list or multiplexed fields must be provided at '
103+
'instantiation time.')
104+
if isinstance(fields, str):
105+
fields = [fields]
106+
107+
_invalid = set(self.input_spec.class_editable_traits()).intersection(fields)
108+
if _invalid:
109+
raise ValueError('Some fields are invalid (%s).' % ', '.join(_invalid))
110+
111+
self._fields = fields
112+
113+
# Attach events
114+
self.inputs.on_trait_change(self._check_len)
115+
if keys:
116+
self.inputs.keys = keys
117+
118+
# Add fields in self._fields
119+
add_traits(self.inputs, self._fields)
120+
121+
for in_field in set(self._fields).intersection(inputs.keys()):
122+
setattr(self.inputs, in_field, inputs[in_field])
123+
124+
def _check_len(self, name, new):
125+
if name == "keys":
126+
nitems = len(new)
127+
if len(set(new)) != nitems:
128+
raise ValueError('Found duplicated entries in the index of ordered keys')
129+
130+
if not isdefined(self.inputs.keys):
131+
return
132+
133+
if name == "key" and new not in self.inputs.keys:
134+
raise ValueError('Selected key "%s" not found in the index' % new)
135+
136+
if name in self._fields:
137+
if isinstance(new, str) or len(new) < 1:
138+
raise ValueError('Trying to set an invalid value (%s) for input "%s"' % (
139+
new, name))
140+
141+
if len(new) != len(self.inputs.keys):
142+
raise ValueError('Length of value (%s) for input field "%s" does not match '
143+
'the length of the indexing list.' % (new, name))
144+
145+
def _run_interface(self, runtime):
146+
return runtime
147+
148+
def _list_outputs(self):
149+
index = self.inputs.keys.index(self.inputs.key)
150+
151+
outputs = {k: getattr(self.inputs, k)[index]
152+
for k in self._fields}
153+
154+
outputs['key'] = self.inputs.key
155+
return outputs
156+
157+
def _outputs(self):
158+
base = super(KeySelect, self)._outputs()
159+
if self._fields:
160+
base = add_traits(base, self._fields)
161+
return base

0 commit comments

Comments
 (0)