Skip to content

Commit 9df8d8a

Browse files
author
Shoshana Berleant
committed
calculate linear regression individually be default
1 parent 5acb6dc commit 9df8d8a

File tree

2 files changed

+26
-20
lines changed

2 files changed

+26
-20
lines changed

nipype/interfaces/nilearn.py

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -75,15 +75,12 @@ class SignalExtraction(BaseInterface):
7575
output_spec = SignalExtractionOutputSpec
7676

7777
def _run_interface(self, runtime):
78-
masker, global_masker = self._process_inputs()
78+
maskers = self._process_inputs()
7979

80-
region_signals = masker.fit_transform(self.inputs.in_file)
81-
82-
if global_masker != None:
83-
self.inputs.class_labels.insert(0, 'global')
84-
global_masker.fit()
85-
global_signal= global_masker.transform(self.inputs.in_file)
86-
region_signals = np.hstack((global_signal, region_signals))
80+
signals = []
81+
for masker in maskers:
82+
signals.append(masker.fit_transform(self.inputs.in_file))
83+
region_signals = np.hstack(signals)
8784

8885
output = np.vstack((self.inputs.class_labels, region_signals.astype(str)))
8986

@@ -92,21 +89,26 @@ def _run_interface(self, runtime):
9289
return runtime
9390

9491
def _process_inputs(self):
95-
''' validate and process inputs into useful form '''
92+
''' validate and process inputs into useful form.
93+
Returns a list of nilearn maskers and the list of corresponding label names.'''
9694
import nilearn.input_data as nl
9795
import nilearn.image as nli
9896

99-
label_datas = [nb.load(nifti) for nifti in self.inputs.label_files]
100-
label_data = nli.concat_imgs(label_datas)
97+
label_data = nli.concat_imgs(self.inputs.label_files)
98+
maskers = []
10199

102100
# determine form of label files, choose appropriate nilearn masker
103-
if len(label_datas) == 1 and np.amax(label_data.get_data()) > 1: # 3d label file
101+
if np.amax(label_data.get_data()) > 1: # 3d label file
104102
n_labels = np.amax(label_data.get_data())
105-
masker = nl.NiftiLabelsMasker(label_data)
106-
else: # one 4d file
103+
maskers.append(nl.NiftiLabelsMasker(label_data))
104+
else: # 4d labels
107105
n_labels = label_data.get_data().shape[3]
108-
masker = nl.NiftiMapsMasker(label_data)
109-
masker.set_params(detrend=self.inputs.detrend)
106+
if self.inputs.incl_shared_variance: # 4d labels, independent computation
107+
for img in nli.iter_img(label_data):
108+
sortof_4d_img = nb.Nifti1Image(img.get_data()[:, :, :, np.newaxis], np.eye(4))
109+
maskers.append(nl.NiftiMapsMasker(sortof_4d_img))
110+
else: # 4d labels, one computation fitting all
111+
maskers.append(nl.NiftiMapsMasker(label_data))
110112

111113
# check label list size
112114
if len(self.inputs.class_labels) != n_labels:
@@ -116,14 +118,18 @@ def _process_inputs(self):
116118
n_labels,
117119
self.inputs.label_files))
118120

119-
global_masker = None
120121
if self.inputs.include_global:
121122
global_label_data = label_data.get_data().clip(0, 1).sum(axis=3)
122123
global_label_data = global_label_data[:, :, :, np.newaxis] # add back 4th dimension
123124
global_label_data = nb.Nifti1Image(global_label_data, np.eye(4))
124125
global_masker = nl.NiftiMapsMasker(global_label_data, detrend=self.inputs.detrend)
126+
maskers.insert(0, global_masker)
127+
self.inputs.class_labels.insert(0, 'global')
128+
129+
for masker in maskers:
130+
masker.set_params(detrend=self.inputs.detrend)
125131

126-
return masker, global_masker
132+
return maskers
127133

128134
def _list_outputs(self):
129135
outputs = self._outputs().get()

nipype/interfaces/tests/test_nilearn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,9 @@ def test_signal_extr_shared(self):
104104
wanted_row = []
105105
for reg in range(self.fake_4d_label_data.shape[3]):
106106
region = self.fake_4d_label_data[:, :, :, reg].flatten()
107-
wanted_row.append(np.average(volume, weights=region))
108-
wanted.append(wanted_row)
107+
wanted_row.append((volume*region).sum()/(region*region).sum())
109108

109+
wanted.append(wanted_row)
110110
# run & assert
111111
self._test_4d_label(wanted, self.fake_4d_label_data)
112112

0 commit comments

Comments
 (0)