Skip to content

Commit bc48a74

Browse files
author
Shoshana Berleant
committed
refactor/allow lists of 3d label files
1 parent c1b3a12 commit bc48a74

File tree

2 files changed

+58
-52
lines changed

2 files changed

+58
-52
lines changed

nipype/algorithms/stats.py

Lines changed: 48 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,16 @@
1919

2020
from .. import logging
2121
from ..interfaces.base import (traits, TraitedSpec, BaseInterface,
22-
BaseInterfaceInputSpec, File)
22+
BaseInterfaceInputSpec, File, InputMultiPath)
2323
IFLOG = logging.getLogger('interface')
2424

2525
class SignalExtractionInputSpec(BaseInterfaceInputSpec):
2626
in_file = File(exists=True, mandatory=True, desc='4-D fMRI nii file')
27-
label_file = File(exists=True, mandatory=True,
28-
desc='a 3-D label image, with 0 denoting background, or '
29-
'a 4-D file of probability maps.')
27+
label_files = InputMultiPath(File(exists=True), mandatory=True,
28+
desc='a 3-D label image, with 0 denoting '
29+
'background, or a list of 3-D probability '
30+
'maps (one per label) or the equivalent 4D '
31+
'file.')
3032
class_labels = traits.List(mandatory=True,
3133
desc='Human-readable labels for each segment '
3234
'in the label file, in order. The length of '
@@ -37,10 +39,6 @@ class SignalExtractionInputSpec(BaseInterfaceInputSpec):
3739
out_file = File('signals.tsv', usedefault=True, exists=False,
3840
mandatory=False, desc='The name of the file to output to. '
3941
'signals.tsv by default')
40-
stat = traits.Enum(('mean',), mandatory=False, default='mean',
41-
usedefault=True,
42-
desc='The stat you wish to calculate on each segment. '
43-
'The default is finding the mean')
4442
detrend = traits.Bool(False, usedefault=True, mandatory=False,
4543
desc='If True, perform detrending using nilearn.')
4644

@@ -56,51 +54,59 @@ class SignalExtraction(BaseInterface):
5654
5755
>>> seinterface = SignalExtraction()
5856
>>> seinterface.inputs.in_file = 'functional.nii'
59-
>>> seinterface.inputs.in_file = 'segmentation0.nii.gz'
57+
>>> seinterface.inputs.label_files = 'segmentation0.nii.gz'
6058
>>> seinterface.inputs.out_file = 'means.tsv'
6159
>>> segments = ['CSF', 'gray', 'white']
6260
>>> seinterface.inputs.class_labels = segments
63-
>>> seinterface.inputs.stat = 'mean'
61+
>>> seinterface.inputs.detrend = True
6462
'''
6563
input_spec = SignalExtractionInputSpec
6664
output_spec = SignalExtractionOutputSpec
6765

6866
def _run_interface(self, runtime):
69-
import nilearn.input_data as nl
67+
masker = self._process_inputs()
68+
69+
region_signals = masker.fit_transform(self.inputs.in_file)
70+
71+
output = np.vstack((self.inputs.class_labels, region_signals.astype(str)))
7072

71-
ins = self.inputs
72-
labels = nb.load(ins.label_file)
73-
74-
if ins.stat == 'mean': # always true for now
75-
if len(labels.get_data().shape) == 3:
76-
region_signals = self._3d_label_handler(nl, labels)
77-
else:
78-
region_signals = self._4d_label_handler(nl, labels)
79-
num_labels_found = region_signals.shape[1]
80-
if len(ins.class_labels) != num_labels_found:
81-
raise ValueError('The length of class_labels {} does not '
82-
'match the number of regions {} found in '
83-
'label_file {}'.format(ins.class_labels,
84-
num_labels_found,
85-
ins.label_file))
86-
87-
output = np.vstack((ins.class_labels, region_signals.astype(str)))
88-
89-
# save output
90-
np.savetxt(ins.out_file, output, fmt=b'%s', delimiter='\t')
73+
# save output
74+
np.savetxt(self.inputs.out_file, output, fmt=b'%s', delimiter='\t')
9175
return runtime
9276

93-
def _3d_label_handler(self, nl, labels):
94-
nlmasker = nl.NiftiLabelsMasker(labels, detrend=self.inputs.detrend)
95-
nlmasker.fit()
96-
region_signals = nlmasker.transform_single_imgs(self.inputs.in_file)
97-
return region_signals
98-
99-
def _4d_label_handler(self, nl, labels):
100-
nmmasker = nl.NiftiMapsMasker(labels, detrend=self.inputs.detrend)
101-
nmmasker.fit()
102-
region_signals = nmmasker.transform_single_imgs(self.inputs.in_file)
103-
return region_signals
77+
def _process_inputs(self):
78+
''' validate and process inputs into useful form '''
79+
80+
import nilearn.input_data as nl
81+
82+
# determine form of label files, choose appropriate nilearn masker
83+
if len(self.inputs.label_files) > 1: # list of 3D nifti images
84+
masker = nl.NiftiMapsMasker(self.inputs.label_files)
85+
n_labels = len(self.inputs.label_files)
86+
else: # list of size one, containing either a 3d or a 4d file
87+
label_data = nb.load(self.inputs.label_files[0])
88+
if len(label_data.shape) == 4: # 4d file
89+
masker = nl.NiftiMapsMasker(label_data)
90+
n_labels = label_data.shape[3]
91+
else: # 3d file
92+
if np.amax(label_data) > 1: # 3d label file
93+
masker = nl.NiftiLabelsMasker(label_data)
94+
# assuming consecutive positive integers for regions
95+
n_labels = np.amax(label_data.get_data())
96+
else: # most probably a single probability map for one label
97+
masker = nl.NiftiMapsMasker(label_data)
98+
n_labels = 1
99+
100+
# check label list size
101+
if len(self.inputs.class_labels) != n_labels:
102+
raise ValueError('The length of class_labels {} does not '
103+
'match the number of regions {} found in '
104+
'label_files {}'.format(self.inputs.class_labels,
105+
n_labels,
106+
self.inputs.label_files))
107+
108+
masker.set_params(detrend=self.inputs.detrend)
109+
return masker
104110

105111
def _list_outputs(self):
106112
outputs = self._outputs().get()

nipype/algorithms/tests/test_stats.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class TestSignalExtraction(unittest.TestCase):
2222

2323
filenames = {
2424
'in_file': 'fmri.nii',
25-
'label_file': 'labels.nii',
25+
'label_files': 'labels.nii',
2626
'4d_label_file': '4dlabels.nii',
2727
'out_file': 'signals.tsv'
2828
}
@@ -34,13 +34,13 @@ def setUp(self):
3434
os.chdir(self.temp_dir)
3535

3636
utils.save_toy_nii(self.fake_fmri_data, self.filenames['in_file'])
37-
utils.save_toy_nii(self.fake_label_data, self.filenames['label_file'])
37+
utils.save_toy_nii(self.fake_label_data, self.filenames['label_files'])
3838

3939
@skipif(no_nilearn)
4040
def test_signal_extraction(self):
4141
# run
4242
stats.SignalExtraction(in_file=self.filenames['in_file'],
43-
label_file=self.filenames['label_file'],
43+
label_files=self.filenames['label_files'],
4444
class_labels=self.labels).run()
4545
# assert
4646
self.assert_expected_output(self.base_wanted)
@@ -50,28 +50,28 @@ def test_signal_extraction(self):
5050
def test_signal_extraction_bad_label_list(self):
5151
# run
5252
stats.SignalExtraction(in_file=self.filenames['in_file'],
53-
label_file=self.filenames['label_file'],
53+
label_files=self.filenames['label_files'],
5454
class_labels=['bad']).run()
5555

5656
@skipif(no_nilearn)
5757
def test_signal_extraction_equiv_4d(self):
5858
self._test_4d_label(self.base_wanted, self.fake_equiv_4d_label_data)
5959

6060
@skipif(no_nilearn)
61-
def test_signal_extraction_4d_(self):
61+
def test_signal_extraction_4d(self):
6262
self._test_4d_label([[-5.0652173913, -5.44565217391, 5.50543478261],
63-
[0, -2, .5],
64-
[-.3333333, -1, 2.5],
65-
[0, -2, .5],
66-
[-1.3333333, -5, 1]], self.fake_4d_label_data)
63+
[-7.02173913043, 11.1847826087, -4.33152173913],
64+
[-19.0869565217, 21.2391304348, -4.57608695652],
65+
[5.19565217391, -3.66304347826, -1.51630434783],
66+
[-12.0, 3., 0.5]], self.fake_4d_label_data)
6767

6868
def _test_4d_label(self, wanted, fake_labels):
6969
# setup
7070
utils.save_toy_nii(fake_labels, self.filenames['4d_label_file'])
7171

7272
# run
7373
stats.SignalExtraction(in_file=self.filenames['in_file'],
74-
label_file=self.filenames['4d_label_file'],
74+
label_files=self.filenames['4d_label_file'],
7575
class_labels=self.labels).run()
7676

7777
self.assert_expected_output(wanted)

0 commit comments

Comments
 (0)