Skip to content

Commit 7286093

Browse files
author
Shoshana Berleant
committed
include global signal
1 parent 77f68cf commit 7286093

File tree

2 files changed

+30
-12
lines changed

2 files changed

+30
-12
lines changed

nipype/algorithms/stats.py

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -62,15 +62,22 @@ class SignalExtraction(BaseInterface):
6262
>>> segments = ['CSF', 'gray', 'white']
6363
>>> seinterface.inputs.class_labels = segments
6464
>>> seinterface.inputs.detrend = True
65+
>>> seinterface.inputs.include_global = True
6566
'''
6667
input_spec = SignalExtractionInputSpec
6768
output_spec = SignalExtractionOutputSpec
6869

6970
def _run_interface(self, runtime):
70-
masker = self._process_inputs()
71+
masker, global_masker = self._process_inputs()
7172

7273
region_signals = masker.fit_transform(self.inputs.in_file)
7374

75+
if global_masker:
76+
self.inputs.class_labels.insert(0, 'global')
77+
global_masker.fit()
78+
global_signal= global_masker.transform(self.inputs.in_file)
79+
region_signals = np.hstack((global_signal, region_signals))
80+
7481
output = np.vstack((self.inputs.class_labels, region_signals.astype(str)))
7582

7683
# save output
@@ -87,18 +94,19 @@ def _process_inputs(self):
8794
masker = nl.NiftiMapsMasker(self.inputs.label_files)
8895
n_labels = len(self.inputs.label_files)
8996
else: # list of size one, containing either a 3d or a 4d file
90-
label_data = nb.load(self.inputs.label_files[0])
91-
if len(label_data.shape) == 4: # 4d file
92-
masker = nl.NiftiMapsMasker(label_data)
93-
n_labels = label_data.shape[3]
97+
self.label_data = nb.load(self.inputs.label_files[0])
98+
if len(self.label_data.shape) == 4: # 4d file
99+
masker = nl.NiftiMapsMasker(self.label_data)
100+
n_labels = self.label_data.shape[3]
94101
else: # 3d file
95-
if np.amax(label_data.get_data()) > 1: # 3d label file
96-
masker = nl.NiftiLabelsMasker(label_data)
102+
if np.amax(self.label_data.get_data()) > 1: # 3d label file
103+
masker = nl.NiftiLabelsMasker(self.label_data)
97104
# assuming consecutive positive integers for regions
98-
n_labels = np.amax(label_data.get_data())
105+
n_labels = np.amax(self.label_data.get_data())
99106
else: # most probably a single probability map for one label
100-
masker = nl.NiftiMapsMasker(label_data)
107+
masker = nl.NiftiMapsMasker(self.label_data)
101108
n_labels = 1
109+
masker.set_params(detrend=self.inputs.detrend)
102110

103111
# check label list size
104112
if len(self.inputs.class_labels) != n_labels:
@@ -108,8 +116,19 @@ def _process_inputs(self):
108116
n_labels,
109117
self.inputs.label_files))
110118

111-
masker.set_params(detrend=self.inputs.detrend)
112-
return masker
119+
if self.inputs.include_global:
120+
all_ones_mask = nb.Nifti1Image(np.ones(self._label_data_shape()), np.eye(4))
121+
global_masker = nl.NiftiLabelsMasker(all_ones_mask, detrend=self.inputs.detrend)
122+
else:
123+
global_masker = False
124+
125+
return masker, global_masker
126+
127+
def _label_data_shape(self):
128+
if self.label_data:
129+
return self.label_data.shape
130+
else:
131+
return nb.load(self.inputs.label_files[0]).shape
113132

114133
def _list_outputs(self):
115134
outputs = self._outputs().get()

nipype/algorithms/tests/test_stats.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@ def test_signal_extraction_4d(self):
6565
[5.19565217391, -3.66304347826, -1.51630434783],
6666
[-12.0, 3., 0.5]], self.fake_4d_label_data)
6767

68-
@skipif(True)
6968
@skipif(no_nilearn)
7069
def test_signal_extraction_include_global(self):
7170
# wanted

0 commit comments

Comments
 (0)