19
19
20
20
from .. import logging
21
21
from ..interfaces .base import (traits , TraitedSpec , BaseInterface ,
22
- BaseInterfaceInputSpec , File )
22
+ BaseInterfaceInputSpec , File , InputMultiPath )
23
23
IFLOG = logging .getLogger ('interface' )
24
24
25
25
class SignalExtractionInputSpec (BaseInterfaceInputSpec ):
26
26
in_file = File (exists = True , mandatory = True , desc = '4-D fMRI nii file' )
27
- label_file = File (exists = True , mandatory = True ,
28
- desc = 'a 3-D label image, with 0 denoting background, or '
29
- 'a 4-D file of probability maps.' )
27
+ label_files = InputMultiPath (File (exists = True ), mandatory = True ,
28
+ desc = 'a 3-D label image, with 0 denoting '
29
+ 'background, or a list of 3-D probability '
30
+ 'maps (one per label) or the equivalent 4D '
31
+ 'file.' )
30
32
class_labels = traits .List (mandatory = True ,
31
33
desc = 'Human-readable labels for each segment '
32
34
'in the label file, in order. The length of '
@@ -37,10 +39,6 @@ class SignalExtractionInputSpec(BaseInterfaceInputSpec):
37
39
out_file = File ('signals.tsv' , usedefault = True , exists = False ,
38
40
mandatory = False , desc = 'The name of the file to output to. '
39
41
'signals.tsv by default' )
40
- stat = traits .Enum (('mean' ,), mandatory = False , default = 'mean' ,
41
- usedefault = True ,
42
- desc = 'The stat you wish to calculate on each segment. '
43
- 'The default is finding the mean' )
44
42
detrend = traits .Bool (False , usedefault = True , mandatory = False ,
45
43
desc = 'If True, perform detrending using nilearn.' )
46
44
@@ -56,51 +54,59 @@ class SignalExtraction(BaseInterface):
56
54
57
55
>>> seinterface = SignalExtraction()
58
56
>>> seinterface.inputs.in_file = 'functional.nii'
59
- >>> seinterface.inputs.in_file = 'segmentation0.nii.gz'
57
+ >>> seinterface.inputs.label_files = 'segmentation0.nii.gz'
60
58
>>> seinterface.inputs.out_file = 'means.tsv'
61
59
>>> segments = ['CSF', 'gray', 'white']
62
60
>>> seinterface.inputs.class_labels = segments
63
- >>> seinterface.inputs.stat = 'mean'
61
+ >>> seinterface.inputs.detrend = True
64
62
'''
65
63
input_spec = SignalExtractionInputSpec
66
64
output_spec = SignalExtractionOutputSpec
67
65
68
66
def _run_interface (self , runtime ):
69
- import nilearn .input_data as nl
67
+ masker = self ._process_inputs ()
68
+
69
+ region_signals = masker .fit_transform (self .inputs .in_file )
70
+
71
+ output = np .vstack ((self .inputs .class_labels , region_signals .astype (str )))
70
72
71
- ins = self .inputs
72
- labels = nb .load (ins .label_file )
73
-
74
- if ins .stat == 'mean' : # always true for now
75
- if len (labels .get_data ().shape ) == 3 :
76
- region_signals = self ._3d_label_handler (nl , labels )
77
- else :
78
- region_signals = self ._4d_label_handler (nl , labels )
79
- num_labels_found = region_signals .shape [1 ]
80
- if len (ins .class_labels ) != num_labels_found :
81
- raise ValueError ('The length of class_labels {} does not '
82
- 'match the number of regions {} found in '
83
- 'label_file {}' .format (ins .class_labels ,
84
- num_labels_found ,
85
- ins .label_file ))
86
-
87
- output = np .vstack ((ins .class_labels , region_signals .astype (str )))
88
-
89
- # save output
90
- np .savetxt (ins .out_file , output , fmt = b'%s' , delimiter = '\t ' )
73
+ # save output
74
+ np .savetxt (self .inputs .out_file , output , fmt = b'%s' , delimiter = '\t ' )
91
75
return runtime
92
76
93
- def _3d_label_handler (self , nl , labels ):
94
- nlmasker = nl .NiftiLabelsMasker (labels , detrend = self .inputs .detrend )
95
- nlmasker .fit ()
96
- region_signals = nlmasker .transform_single_imgs (self .inputs .in_file )
97
- return region_signals
98
-
99
- def _4d_label_handler (self , nl , labels ):
100
- nmmasker = nl .NiftiMapsMasker (labels , detrend = self .inputs .detrend )
101
- nmmasker .fit ()
102
- region_signals = nmmasker .transform_single_imgs (self .inputs .in_file )
103
- return region_signals
77
+ def _process_inputs (self ):
78
+ ''' validate and process inputs into useful form '''
79
+
80
+ import nilearn .input_data as nl
81
+
82
+ # determine form of label files, choose appropriate nilearn masker
83
+ if len (self .inputs .label_files ) > 1 : # list of 3D nifti images
84
+ masker = nl .NiftiMapsMasker (self .inputs .label_files )
85
+ n_labels = len (self .inputs .label_files )
86
+ else : # list of size one, containing either a 3d or a 4d file
87
+ label_data = nb .load (self .inputs .label_files [0 ])
88
+ if len (label_data .shape ) == 4 : # 4d file
89
+ masker = nl .NiftiMapsMasker (label_data )
90
+ n_labels = label_data .shape [3 ]
91
+ else : # 3d file
92
+ if np .amax (label_data ) > 1 : # 3d label file
93
+ masker = nl .NiftiLabelsMasker (label_data )
94
+ # assuming consecutive positive integers for regions
95
+ n_labels = np .amax (label_data .get_data ())
96
+ else : # most probably a single probability map for one label
97
+ masker = nl .NiftiMapsMasker (label_data )
98
+ n_labels = 1
99
+
100
+ # check label list size
101
+ if len (self .inputs .class_labels ) != n_labels :
102
+ raise ValueError ('The length of class_labels {} does not '
103
+ 'match the number of regions {} found in '
104
+ 'label_files {}' .format (self .inputs .class_labels ,
105
+ n_labels ,
106
+ self .inputs .label_files ))
107
+
108
+ masker .set_params (detrend = self .inputs .detrend )
109
+ return masker
104
110
105
111
def _list_outputs (self ):
106
112
outputs = self ._outputs ().get ()
0 commit comments