@@ -62,15 +62,22 @@ class SignalExtraction(BaseInterface):
62
62
>>> segments = ['CSF', 'gray', 'white']
63
63
>>> seinterface.inputs.class_labels = segments
64
64
>>> seinterface.inputs.detrend = True
65
+ >>> seinterface.inputs.include_global = True
65
66
'''
66
67
input_spec = SignalExtractionInputSpec
67
68
output_spec = SignalExtractionOutputSpec
68
69
69
70
def _run_interface (self , runtime ):
70
- masker = self ._process_inputs ()
71
+ masker , global_masker = self ._process_inputs ()
71
72
72
73
region_signals = masker .fit_transform (self .inputs .in_file )
73
74
75
+ if global_masker :
76
+ self .inputs .class_labels .insert (0 , 'global' )
77
+ global_masker .fit ()
78
+ global_signal = global_masker .transform (self .inputs .in_file )
79
+ region_signals = np .hstack ((global_signal , region_signals ))
80
+
74
81
output = np .vstack ((self .inputs .class_labels , region_signals .astype (str )))
75
82
76
83
# save output
@@ -87,18 +94,19 @@ def _process_inputs(self):
87
94
masker = nl .NiftiMapsMasker (self .inputs .label_files )
88
95
n_labels = len (self .inputs .label_files )
89
96
else : # list of size one, containing either a 3d or a 4d file
90
- label_data = nb .load (self .inputs .label_files [0 ])
91
- if len (label_data .shape ) == 4 : # 4d file
92
- masker = nl .NiftiMapsMasker (label_data )
93
- n_labels = label_data .shape [3 ]
97
+ self . label_data = nb .load (self .inputs .label_files [0 ])
98
+ if len (self . label_data .shape ) == 4 : # 4d file
99
+ masker = nl .NiftiMapsMasker (self . label_data )
100
+ n_labels = self . label_data .shape [3 ]
94
101
else : # 3d file
95
- if np .amax (label_data .get_data ()) > 1 : # 3d label file
96
- masker = nl .NiftiLabelsMasker (label_data )
102
+ if np .amax (self . label_data .get_data ()) > 1 : # 3d label file
103
+ masker = nl .NiftiLabelsMasker (self . label_data )
97
104
# assuming consecutive positive integers for regions
98
- n_labels = np .amax (label_data .get_data ())
105
+ n_labels = np .amax (self . label_data .get_data ())
99
106
else : # most probably a single probability map for one label
100
- masker = nl .NiftiMapsMasker (label_data )
107
+ masker = nl .NiftiMapsMasker (self . label_data )
101
108
n_labels = 1
109
+ masker .set_params (detrend = self .inputs .detrend )
102
110
103
111
# check label list size
104
112
if len (self .inputs .class_labels ) != n_labels :
@@ -108,8 +116,19 @@ def _process_inputs(self):
108
116
n_labels ,
109
117
self .inputs .label_files ))
110
118
111
- masker .set_params (detrend = self .inputs .detrend )
112
- return masker
119
+ if self .inputs .include_global :
120
+ all_ones_mask = nb .Nifti1Image (np .ones (self ._label_data_shape ()), np .eye (4 ))
121
+ global_masker = nl .NiftiLabelsMasker (all_ones_mask , detrend = self .inputs .detrend )
122
+ else :
123
+ global_masker = False
124
+
125
+ return masker , global_masker
126
+
127
+ def _label_data_shape (self ):
128
+ if self .label_data :
129
+ return self .label_data .shape
130
+ else :
131
+ return nb .load (self .inputs .label_files [0 ]).shape
113
132
114
133
def _list_outputs (self ):
115
134
outputs = self ._outputs ().get ()
0 commit comments