Skip to content

Commit df7c093

Browse files
committed
fix: update getmask to dilate while binarizing
1 parent c835089 commit df7c093

File tree

1 file changed

+45
-49
lines changed
  • nipype/workflows/smri/freesurfer

1 file changed

+45
-49
lines changed

nipype/workflows/smri/freesurfer/utils.py

Lines changed: 45 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,15 @@
1313
from nipype.workflows.misc.utils import region_list_from_volume, id_list_from_lookup_table
1414
import os, os.path as op
1515

16+
17+
def get_aparc_aseg(files):
18+
"""Return the aparc+aseg.mgz file"""
19+
for name in files:
20+
if 'aparc+aseg' in name:
21+
return name
22+
raise ValueError('aparc+aseg.mgz not found')
23+
24+
1625
def create_getmask_flow(name='getmask', dilate_mask=True):
1726
"""Registers a source file to freesurfer space and create a brain mask in
1827
source space
@@ -71,46 +80,40 @@ def create_getmask_flow(name='getmask', dilate_mask=True):
7180
"""
7281
Define all the nodes of the workflow:
7382
74-
fssource: used to retrieve aseg.mgz
75-
threshold : binarize aseg
76-
register : coregister source file to freesurfer space
77-
voltransform: convert binarized aparc+aseg to source file space
78-
83+
fssource: used to retrieve aseg.mgz
84+
threshold : binarize aseg
85+
register : coregister source file to freesurfer space
86+
voltransform: convert binarized aseg to source file space
7987
"""
8088

8189
fssource = pe.Node(nio.FreeSurferSource(),
82-
name = 'fssource')
90+
name = 'fssource')
8391
threshold = pe.Node(fs.Binarize(min=0.5, out_type='nii'),
84-
name='threshold')
92+
name='threshold')
8593
register = pe.MapNode(fs.BBRegister(init='fsl'),
86-
iterfield=['source_file'],
87-
name='register')
94+
iterfield=['source_file'],
95+
name='register')
8896
voltransform = pe.MapNode(fs.ApplyVolTransform(inverse=True),
89-
iterfield=['source_file', 'reg_file'],
90-
name='transform')
97+
iterfield=['source_file', 'reg_file'],
98+
name='transform')
9199

92100
"""
93101
Connect the nodes
94102
"""
95-
def get_aparc_aseg(files):
96-
for name in files:
97-
if 'aparc+aseg' in name:
98-
return name
99-
raise ValueError('aparc+aseg.mgz not found')
100103

101104
getmask.connect([
102-
(inputnode, fssource, [('subject_id','subject_id'),
103-
('subjects_dir','subjects_dir')]),
104-
(inputnode, register, [('source_file', 'source_file'),
105-
('subject_id', 'subject_id'),
106-
('subjects_dir', 'subjects_dir'),
107-
('contrast_type', 'contrast_type')]),
108-
(inputnode, voltransform, [('subjects_dir', 'subjects_dir'),
109-
('source_file', 'source_file')]),
110-
(fssource, threshold, [(('aparc_aseg', get_aparc_aseg), 'in_file')]),
111-
(register, voltransform, [('out_reg_file','reg_file')]),
112-
(threshold, voltransform, [('binary_file','target_file')])
113-
])
105+
(inputnode, fssource, [('subject_id','subject_id'),
106+
('subjects_dir','subjects_dir')]),
107+
(inputnode, register, [('source_file', 'source_file'),
108+
('subject_id', 'subject_id'),
109+
('subjects_dir', 'subjects_dir'),
110+
('contrast_type', 'contrast_type')]),
111+
(inputnode, voltransform, [('subjects_dir', 'subjects_dir'),
112+
('source_file', 'source_file')]),
113+
(fssource, threshold, [(('aparc_aseg', get_aparc_aseg), 'in_file')]),
114+
(register, voltransform, [('out_reg_file','reg_file')]),
115+
(threshold, voltransform, [('binary_file','target_file')])
116+
])
114117

115118

116119
"""
@@ -121,35 +124,28 @@ def get_aparc_aseg(files):
121124
"""
122125

123126
threshold2 = pe.MapNode(fs.Binarize(min=0.5, out_type='nii'),
124-
iterfield=['in_file'],
125-
name='threshold2')
127+
iterfield=['in_file'],
128+
name='threshold2')
126129
if dilate_mask:
127-
dilate = pe.MapNode(fsl.maths.DilateImage(operation='max'),
128-
iterfield=['in_file'],
129-
name='dilate')
130-
getmask.connect([
131-
(voltransform, dilate, [('transformed_file', 'in_file')]),
132-
(dilate, threshold2, [('out_file', 'in_file')]),
133-
])
134-
else:
135-
getmask.connect([
136-
(voltransform, threshold2, [('transformed_file', 'in_file')])
137-
])
130+
threshold2.inputs.dilate = 1
131+
getmask.connect([
132+
(voltransform, threshold2, [('transformed_file', 'in_file')])
133+
])
138134

139135
"""
140136
Setup an outputnode that defines relevant inputs of the workflow.
141137
"""
142138

143139
outputnode = pe.Node(niu.IdentityInterface(fields=["mask_file",
144-
"reg_file",
145-
"reg_cost"
146-
]),
147-
name="outputspec")
140+
"reg_file",
141+
"reg_cost"
142+
]),
143+
name="outputspec")
148144
getmask.connect([
149-
(register, outputnode, [("out_reg_file", "reg_file")]),
150-
(register, outputnode, [("min_cost_file", "reg_cost")]),
151-
(threshold2, outputnode, [("binary_file", "mask_file")]),
152-
])
145+
(register, outputnode, [("out_reg_file", "reg_file")]),
146+
(register, outputnode, [("min_cost_file", "reg_cost")]),
147+
(threshold2, outputnode, [("binary_file", "mask_file")]),
148+
])
153149
return getmask
154150

155151
def create_get_stats_flow(name='getstats', withreg=False):

0 commit comments

Comments
 (0)