@@ -75,15 +75,12 @@ class SignalExtraction(BaseInterface):
75
75
output_spec = SignalExtractionOutputSpec
76
76
77
77
def _run_interface (self , runtime ):
78
- masker , global_masker = self ._process_inputs ()
78
+ maskers = self ._process_inputs ()
79
79
80
- region_signals = masker .fit_transform (self .inputs .in_file )
81
-
82
- if global_masker != None :
83
- self .inputs .class_labels .insert (0 , 'global' )
84
- global_masker .fit ()
85
- global_signal = global_masker .transform (self .inputs .in_file )
86
- region_signals = np .hstack ((global_signal , region_signals ))
80
+ signals = []
81
+ for masker in maskers :
82
+ signals .append (masker .fit_transform (self .inputs .in_file ))
83
+ region_signals = np .hstack (signals )
87
84
88
85
output = np .vstack ((self .inputs .class_labels , region_signals .astype (str )))
89
86
@@ -92,21 +89,26 @@ def _run_interface(self, runtime):
92
89
return runtime
93
90
94
91
def _process_inputs (self ):
95
- ''' validate and process inputs into useful form '''
92
+ ''' validate and process inputs into useful form.
93
+ Returns a list of nilearn maskers and the list of corresponding label names.'''
96
94
import nilearn .input_data as nl
97
95
import nilearn .image as nli
98
96
99
- label_datas = [ nb . load ( nifti ) for nifti in self .inputs .label_files ]
100
- label_data = nli . concat_imgs ( label_datas )
97
+ label_data = nli . concat_imgs ( self .inputs .label_files )
98
+ maskers = []
101
99
102
100
# determine form of label files, choose appropriate nilearn masker
103
- if len ( label_datas ) == 1 and np .amax (label_data .get_data ()) > 1 : # 3d label file
101
+ if np .amax (label_data .get_data ()) > 1 : # 3d label file
104
102
n_labels = np .amax (label_data .get_data ())
105
- masker = nl .NiftiLabelsMasker (label_data )
106
- else : # one 4d file
103
+ maskers . append ( nl .NiftiLabelsMasker (label_data ) )
104
+ else : # 4d labels
107
105
n_labels = label_data .get_data ().shape [3 ]
108
- masker = nl .NiftiMapsMasker (label_data )
109
- masker .set_params (detrend = self .inputs .detrend )
106
+ if self .inputs .incl_shared_variance : # 4d labels, independent computation
107
+ for img in nli .iter_img (label_data ):
108
+ sortof_4d_img = nb .Nifti1Image (img .get_data ()[:, :, :, np .newaxis ], np .eye (4 ))
109
+ maskers .append (nl .NiftiMapsMasker (sortof_4d_img ))
110
+ else : # 4d labels, one computation fitting all
111
+ maskers .append (nl .NiftiMapsMasker (label_data ))
110
112
111
113
# check label list size
112
114
if len (self .inputs .class_labels ) != n_labels :
@@ -116,14 +118,18 @@ def _process_inputs(self):
116
118
n_labels ,
117
119
self .inputs .label_files ))
118
120
119
- global_masker = None
120
121
if self .inputs .include_global :
121
122
global_label_data = label_data .get_data ().clip (0 , 1 ).sum (axis = 3 )
122
123
global_label_data = global_label_data [:, :, :, np .newaxis ] # add back 4th dimension
123
124
global_label_data = nb .Nifti1Image (global_label_data , np .eye (4 ))
124
125
global_masker = nl .NiftiMapsMasker (global_label_data , detrend = self .inputs .detrend )
126
+ maskers .insert (0 , global_masker )
127
+ self .inputs .class_labels .insert (0 , 'global' )
128
+
129
+ for masker in maskers :
130
+ masker .set_params (detrend = self .inputs .detrend )
125
131
126
- return masker , global_masker
132
+ return maskers
127
133
128
134
def _list_outputs (self ):
129
135
outputs = self ._outputs ().get ()
0 commit comments