Skip to content

Commit 97dd3aa

Browse files
author
Shoshana Berleant
committed
actual global signal calculation
1 parent ab1e5f1 commit 97dd3aa

File tree

2 files changed

+43
-41
lines changed

2 files changed

+43
-41
lines changed

nipype/interfaces/nilearn.py

Lines changed: 16 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def _run_interface(self, runtime):
7373

7474
region_signals = masker.fit_transform(self.inputs.in_file)
7575

76-
if global_masker:
76+
if global_masker != None:
7777
self.inputs.class_labels.insert(0, 'global')
7878
global_masker.fit()
7979
global_signal= global_masker.transform(self.inputs.in_file)
@@ -87,26 +87,19 @@ def _run_interface(self, runtime):
8787

8888
def _process_inputs(self):
8989
''' validate and process inputs into useful form '''
90-
9190
import nilearn.input_data as nl
91+
import nilearn.image as nli
92+
93+
label_datas = [nb.load(nifti) for nifti in self.inputs.label_files]
94+
label_data = nli.concat_imgs(label_datas)
9295

9396
# determine form of label files, choose appropriate nilearn masker
94-
if len(self.inputs.label_files) > 1: # list of 3D nifti images
95-
masker = nl.NiftiMapsMasker(self.inputs.label_files)
96-
n_labels = len(self.inputs.label_files)
97-
else: # list of size one, containing either a 3d or a 4d file
98-
self.label_data = nb.load(self.inputs.label_files[0])
99-
if len(self.label_data.shape) == 4: # 4d file
100-
masker = nl.NiftiMapsMasker(self.label_data)
101-
n_labels = self.label_data.shape[3]
102-
else: # 3d file
103-
if np.amax(self.label_data.get_data()) > 1: # 3d label file
104-
masker = nl.NiftiLabelsMasker(self.label_data)
105-
# assuming consecutive positive integers for regions
106-
n_labels = np.amax(self.label_data.get_data())
107-
else: # most probably a single probability map for one label
108-
masker = nl.NiftiMapsMasker(self.label_data)
109-
n_labels = 1
97+
if len(label_datas) == 1 and np.amax(label_data.get_data()) > 1: # 3d label file
98+
n_labels = np.amax(label_data.get_data())
99+
masker = nl.NiftiLabelsMasker(label_data)
100+
else: # one 4d file
101+
n_labels = label_data.get_data().shape[3]
102+
masker = nl.NiftiMapsMasker(label_data)
110103
masker.set_params(detrend=self.inputs.detrend)
111104

112105
# check label list size
@@ -117,20 +110,15 @@ def _process_inputs(self):
117110
n_labels,
118111
self.inputs.label_files))
119112

113+
global_masker = None
120114
if self.inputs.include_global:
121-
all_ones_mask = nb.Nifti1Image(np.ones(self._label_data_shape()), np.eye(4))
122-
global_masker = nl.NiftiLabelsMasker(all_ones_mask, detrend=self.inputs.detrend)
123-
else:
124-
global_masker = False
115+
global_label_data = label_data.get_data().clip(0, 1).sum(axis=3)
116+
global_label_data = global_label_data[:, :, :, np.newaxis] # add back 4th dimension
117+
global_label_data = nb.Nifti1Image(global_label_data, np.eye(4))
118+
global_masker = nl.NiftiMapsMasker(global_label_data, detrend=self.inputs.detrend)
125119

126120
return masker, global_masker
127121

128-
def _label_data_shape(self):
129-
if self.label_data:
130-
return self.label_data.shape
131-
else:
132-
return nb.load(self.inputs.label_files[0]).shape
133-
134122
def _list_outputs(self):
135123
outputs = self._outputs().get()
136124
outputs['out_file'] = self.inputs.out_file

nipype/interfaces/tests/test_nilearn.py

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ class TestSignalExtraction(unittest.TestCase):
2727
'out_file': 'signals.tsv'
2828
}
2929
labels = ['csf', 'gray', 'white']
30+
global_labels = ['global'] + labels
3031

3132
def setUp(self):
3233
self.orig_dir = os.getcwd()
@@ -60,20 +61,14 @@ def test_signal_extraction_equiv_4d(self):
6061

6162
@skipif(no_nilearn)
6263
def test_signal_extraction_4d(self):
63-
self._test_4d_label([[-5.0652173913, -5.44565217391, 5.50543478261],
64-
[-7.02173913043, 11.1847826087, -4.33152173913],
65-
[-19.0869565217, 21.2391304348, -4.57608695652],
66-
[5.19565217391, -3.66304347826, -1.51630434783],
67-
[-12.0, 3., 0.5]], self.fake_4d_label_data)
64+
self._test_4d_label(self.fourd_wanted, self.fake_4d_label_data)
6865

6966
@skipif(no_nilearn)
70-
def test_signal_extraction_include_global(self):
67+
def test_signal_extraction_global(self):
7168
# wanted
72-
wanted_global = [[3./8], [-3./8], [1./8], [-7./8], [-9./8]]
69+
wanted_global = [[-4./6], [-1./6], [3./6], [-1./6], [-7./6]]
7370
for i, vals in enumerate(self.base_wanted):
7471
wanted_global[i].extend(vals)
75-
wanted_labels = ['global']
76-
wanted_labels.extend(self.labels)
7772

7873
# run
7974
iface.SignalExtraction(in_file=self.filenames['in_file'],
@@ -82,18 +77,30 @@ def test_signal_extraction_include_global(self):
8277
include_global=True).run()
8378

8479
# assert
85-
self.assert_expected_output(wanted_labels, wanted_global)
80+
self.assert_expected_output(self.global_labels, wanted_global)
81+
82+
@skipif(no_nilearn)
83+
def test_signal_extraction_4d_global(self):
84+
# wanted
85+
wanted_global = [[3./8], [-3./8], [1./8], [-7./8], [-9./8]]
86+
for i, vals in enumerate(self.fourd_wanted):
87+
wanted_global[i].extend(vals)
8688

87-
def _test_4d_label(self, wanted, fake_labels):
89+
# run
90+
self._test_4d_label(wanted_global, self.fake_4d_label_data, include_global=True)
91+
92+
def _test_4d_label(self, wanted, fake_labels, include_global=False):
8893
# setup
8994
utils.save_toy_nii(fake_labels, self.filenames['4d_label_file'])
9095

9196
# run
9297
iface.SignalExtraction(in_file=self.filenames['in_file'],
9398
label_files=self.filenames['4d_label_file'],
94-
class_labels=self.labels).run()
99+
class_labels=self.labels,
100+
include_global=include_global).run()
95101

96-
self.assert_expected_output(self.labels, wanted)
102+
wanted_labels = self.global_labels if include_global else self.labels
103+
self.assert_expected_output(wanted_labels, wanted)
97104

98105
def assert_expected_output(self, labels, wanted):
99106
with open(self.filenames['out_file'], 'r') as output:
@@ -159,3 +166,10 @@ def tearDown(self):
159166

160167
[[0.3, 0.3, 0.4],
161168
[0.3, 0.4, 0.3]]]])
169+
170+
171+
fourd_wanted = [[-5.0652173913, -5.44565217391, 5.50543478261],
172+
[-7.02173913043, 11.1847826087, -4.33152173913],
173+
[-19.0869565217, 21.2391304348, -4.57608695652],
174+
[5.19565217391, -3.66304347826, -1.51630434783],
175+
[-12.0, 3., 0.5]]

0 commit comments

Comments
 (0)