Skip to content

Commit fb0d6b9

Browse files
author
Shoshana Berleant
committed
nilearn method
1 parent b11c5f5 commit fb0d6b9

File tree

2 files changed

+68
-30
lines changed

2 files changed

+68
-30
lines changed

nipype/algorithms/stats.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
absolute_import)
1616

1717
import numpy as np
18+
import nibabel as nb
1819

1920
from .. import logging
2021
from ..interfaces.base import (traits, TraitedSpec, BaseInterface,
@@ -68,9 +69,13 @@ def _run_interface(self, runtime):
6869
import nilearn.input_data as nl
6970

7071
ins = self.inputs
72+
labels = nb.load(ins.label_file)
7173

7274
if ins.stat == 'mean': # always true for now
73-
region_signals = self._3d_label_handler(nl, ins)
75+
if len(labels.get_data().shape) == 3:
76+
region_signals = self._3d_label_handler(nl, labels)
77+
else:
78+
region_signals = self._4d_label_handler(nl, labels)
7479
num_labels_found = region_signals.shape[1]
7580
if len(ins.class_labels) != num_labels_found:
7681
raise ValueError('The length of class_labels {} does not '
@@ -85,11 +90,16 @@ def _run_interface(self, runtime):
8590
np.savetxt(ins.out_file, output, fmt=b'%s', delimiter='\t')
8691
return runtime
8792

88-
def _3d_label_handler(self, nl, ins):
89-
nlmasker = nl.NiftiLabelsMasker(ins.label_file,
90-
detrend=ins.detrend)
93+
def _3d_label_handler(self, nl, labels):
94+
nlmasker = nl.NiftiLabelsMasker(labels, detrend=self.inputs.detrend)
9195
nlmasker.fit()
92-
region_signals = nlmasker.transform_single_imgs(ins.in_file)
96+
region_signals = nlmasker.transform_single_imgs(self.inputs.in_file)
97+
return region_signals
98+
99+
def _4d_label_handler(self, nl, labels):
100+
nmmasker = nl.NiftiMapsMasker(labels, detrend=self.inputs.detrend)
101+
nmmasker.fit()
102+
region_signals = nmmasker.transform_single_imgs(self.inputs.in_file)
93103
return region_signals
94104

95105
def _list_outputs(self):

nipype/algorithms/tests/test_stats.py

Lines changed: 53 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ class TestSignalExtraction(unittest.TestCase):
2626
'4d_label_file': '4dlabels.nii',
2727
'out_file': 'signals.tsv'
2828
}
29+
labels = ['csf', 'gray', 'white']
2930

3031
def setUp(self):
3132
self.orig_dir = os.getcwd()
@@ -37,44 +38,57 @@ def setUp(self):
3738

3839
@skipif(no_nilearn)
3940
def test_signal_extraction(self):
40-
# setup
41-
wanted = [[-2.33333, 2, .5],
42-
[0, -2, .5],
43-
[-.3333333, -1, 2.5],
44-
[0, -2, .5],
45-
[-1.3333333, -5, 1]]
46-
num_timepoints_wanted = self.fake_fmri_data.shape[3]
4741
# run
48-
49-
labels_wanted = ['csf', 'gray', 'white']
5042
stats.SignalExtraction(in_file=self.filenames['in_file'],
5143
label_file=self.filenames['label_file'],
52-
class_labels=labels_wanted).run()
44+
class_labels=self.labels).run()
5345
# assert
54-
with open(self.filenames['out_file'], 'r') as output:
55-
got = [line.split() for line in output]
56-
labels_got = got.pop(0) # remove header
57-
assert_equal(labels_got, labels_wanted)
58-
assert_equal(len(got), num_timepoints_wanted)
59-
# convert from string to float
60-
got = [[float(num) for num in row] for row in got]
61-
for i, time in enumerate(got):
62-
assert_equal(len(labels_wanted), len(time))
63-
for j, segment in enumerate(time):
64-
assert_almost_equal(segment, wanted[i][j], decimal=1)
46+
self.assert_expected_output(self.base_wanted)
6547

6648
@skipif(no_nilearn)
6749
@raises(ValueError)
68-
def test_signal_extraction_bad_class_labels(self):
50+
def test_signal_extraction_bad_label_list(self):
6951
# run
7052
stats.SignalExtraction(in_file=self.filenames['in_file'],
7153
label_file=self.filenames['label_file'],
7254
class_labels=['bad']).run()
7355

7456
@skipif(no_nilearn)
75-
def test_signal_extraction_4d_label(self):
57+
def test_signal_extraction_equiv_4d(self):
58+
self._test_4d_label(self.base_wanted, self.fake_equiv_4d_label_data)
59+
60+
def test_signal_extraction_4d_(self):
61+
self._test_4d_label([[-5.0652173913, -5.44565217391, 5.50543478261],
62+
[0, -2, .5],
63+
[-.3333333, -1, 2.5],
64+
[0, -2, .5],
65+
[-1.3333333, -5, 1]], self.fake_4d_label_data)
66+
67+
def _test_4d_label(self, wanted, fake_labels):
7668
# setup
77-
utils.save_toy_nii(self.fake_4d_label_data, self.filenames['4d_label_file'])
69+
utils.save_toy_nii(fake_labels, self.filenames['4d_label_file'])
70+
71+
# run
72+
stats.SignalExtraction(in_file=self.filenames['in_file'],
73+
label_file=self.filenames['4d_label_file'],
74+
class_labels=self.labels).run()
75+
76+
self.assert_expected_output(wanted)
77+
78+
def assert_expected_output(self, wanted):
79+
with open(self.filenames['out_file'], 'r') as output:
80+
got = [line.split() for line in output]
81+
labels_got = got.pop(0) # remove header
82+
assert_equal(labels_got, self.labels)
83+
assert_equal(len(got), self.fake_fmri_data.shape[3],
84+
'num rows and num volumes')
85+
# convert from string to float
86+
got = [[float(num) for num in row] for row in got]
87+
for i, time in enumerate(got):
88+
assert_equal(len(self.labels), len(time))
89+
for j, segment in enumerate(time):
90+
assert_almost_equal(segment, wanted[i][j], decimal=1)
91+
7892

7993
def tearDown(self):
8094
os.chdir(self.orig_dir)
@@ -99,13 +113,27 @@ def tearDown(self):
99113
[[2, 0],
100114
[1, 3]]])
101115

116+
fake_equiv_4d_label_data = np.array([[[[1., 0., 0.],
117+
[0., 0., 0.]],
118+
[[0., 0., 1.],
119+
[1., 0., 0.]]],
120+
[[[0., 1., 0.],
121+
[0., 0., 0.]],
122+
[[1., 0., 0.],
123+
[0., 0., 1.]]]])
124+
125+
base_wanted = [[-2.33333, 2, .5],
126+
[0, -2, .5],
127+
[-.3333333, -1, 2.5],
128+
[0, -2, .5],
129+
[-1.3333333, -5, 1]]
130+
102131
fake_4d_label_data = np.array([[[[0.2, 0.3, 0.5],
103132
[0.1, 0.1, 0.8]],
104133

105134
[[0.1, 0.3, 0.6],
106135
[0.3, 0.4, 0.3]]],
107136

108-
109137
[[[0.2, 0.2, 0.6],
110138
[0., 0.3, 0.7]],
111139

0 commit comments

Comments
 (0)