Skip to content

Commit 658debc

Browse files
eilidhmacnicoloesteban
authored andcommitted
MAIN: workflow tidying
1 parent 24c667a commit 658debc

File tree

1 file changed

+96
-129
lines changed

1 file changed

+96
-129
lines changed

nirodents/workflows/brainextraction.py

Lines changed: 96 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def init_rodent_brain_extraction_wf(
5252
Run an extra step to refine the brain mask using a brain-tissue segmentation with Atropos.
5353
5454
"""
55-
wf = pe.Workflow(name)
55+
5656

5757
if omp_nthreads is None or omp_nthreads < 1:
5858
omp_nthreads = cpu_count()
@@ -106,6 +106,10 @@ def init_rodent_brain_extraction_wf(
106106
mrg_tmpl = pe.Node(niu.Merge(2), name='mrg_tmpl')
107107
# mrg_tmpl.inputs.in1 = tpl_target_path
108108

109+
# Create integration nodes to allow compatibility between pipelines
110+
integrate_1 = pe.Node(niu.IdentityInterface(fields=["in_file"]), name='integrate_1')
111+
integrate_2 = pe.Node(niu.IdentityInterface(fields=["in_file"]), name='integrate_2')
112+
109113
# Initialize transforms with antsAI
110114
init_aff = pe.Node(AI(
111115
metric=('Mattes', 32, 'Regular', 0.5), #0.25
@@ -119,8 +123,8 @@ def init_rodent_brain_extraction_wf(
119123
n_procs=omp_nthreads)
120124

121125
# Initial warping of template mask to subject space
122-
warp_mask = pe.Node(ApplyTransforms(
123-
interpolation='Linear', invert_transform_flags=True), name='warp_mask')
126+
warp_mask_1 = pe.Node(ApplyTransforms(
127+
interpolation='Linear', invert_transform_flags=True), name='warp_mask_1')
124128

125129
# Set up initial spatial normalization
126130
init_settings_file = f'data/brainextraction_{init_normalization_quality}_{bids_suffix}.json'
@@ -144,9 +148,9 @@ def init_rodent_brain_extraction_wf(
144148
mrg_init_transforms = pe.Node(niu.Merge(2), name='mrg_init_transforms')
145149

146150
# Use more precise transforms to warp mask to subject space
147-
warp_mask_final = pe.Node(ApplyTransforms(
151+
warp_mask_2 = pe.Node(ApplyTransforms(
148152
interpolation='Linear', invert_transform_flags=[False, True]),
149-
name='warp_mask_final')
153+
name='warp_mask_2')
150154

151155
# morphological closing of warped mask
152156
close_mask = pe.Node(MaskTool(outputtype='NIFTI_GZ', dilate_inputs='5 -5', fill_holes=True),
@@ -160,21 +164,23 @@ def init_rodent_brain_extraction_wf(
160164

161165
# Normalise skull-stripped image to brain template
162166
final_settings_file = f'data/brainextraction_{final_normalization_quality}_{bids_suffix}.json'
163-
final_norm = pe.Node(Registration(from_file=pkgr_fn(
167+
refine_norm = pe.Node(Registration(from_file=pkgr_fn(
164168
'nirodents', final_settings_file)),
165-
name='final_norm',
169+
name='refine_norm',
166170
n_procs=omp_nthreads,
167171
mem_gb=mem_gb)
168-
final_norm.inputs.float = use_float
172+
refine_norm.inputs.float = use_float
169173

170174
split_final_transforms = pe.Node(niu.Split(splits=[1, 1]), name='split_final_transforms')
171175
mrg_final_transforms = pe.Node(niu.Merge(2), name='mrg_final_transforms')
172176

173-
warp_seg_mask = pe.Node(ApplyTransforms(
177+
warp_mask_out = pe.Node(ApplyTransforms(
174178
interpolation='Linear', invert_transform_flags=[False, True]),
175-
name='warp_seg_mask')
179+
name='warp_mask_out')
176180
if tpl_brain_mask:
177-
warp_seg_mask.inputs.input_image = tpl_brain_mask
181+
warp_mask_out.inputs.input_image = tpl_brain_mask
182+
else:
183+
warp_mask_out.inputs.input_image = tpl_regmask_path
178184

179185
warp_seg_labels = pe.Node(ApplyTransforms(
180186
interpolation='Linear', invert_transform_flags=[False, True]),
@@ -190,153 +196,114 @@ def init_rodent_brain_extraction_wf(
190196

191197
sinker = pe.Node(DataSink(), name='sinker')
192198

199+
#workflow definitions
200+
#target image specific workflows
201+
tar_prep = pe.Workflow('tar_prep')
193202
if bids_suffix.lower() == 't2w':
194-
wf.connect([
195-
# resampling, truncation, initial N4, and creation of laplacian
203+
tar_prep.connect([
204+
# truncation, resampling, and initial N4
196205
(inputnode, trunc, [('in_files', 'op1')]),
197206
(trunc, res_target, [(('output_image', _pop), 'in_file')]),
198207
(res_target, inu_n4, [('out_file', 'input_image')]),
199-
200-
# dilation of input mask
201-
(inputnode, dil_mask, [('in_mask', 'in_file')]),
202-
203-
# ants AI inputs
204-
(inu_n4, init_aff, [(('output_image', _pop), 'moving_image')]),
205-
(dil_mask, init_aff, [('out_file', 'fixed_image_mask')]),
206-
(res_tmpl, init_aff, [('out_file', 'fixed_image')]),
207-
208-
# warp mask to individual space
209-
(dil_mask, warp_mask, [('out_file', 'input_image')]),
210-
(trunc, warp_mask, [(('output_image', _pop), 'reference_image')]),
211-
(init_aff, warp_mask, [('output_transform', 'transforms')]),
208+
(inu_n4, integrate_1, [(('output_image', _pop), 'in_file')]),
212209

213210
# masked N4 correction
214211
(trunc, inu_n4_final, [(('output_image', _pop), 'input_image')]),
215-
(warp_mask, inu_n4_final, [('output_image', 'weight_image')]),
212+
(inu_n4_final, integrate_2, [(('output_image', _pop), 'in_file')]),
216213

217214
# merge laplacian and original images
218215
(inu_n4_final, lap_target, [(('output_image', _pop), 'op1')]),
219216
(lap_target, norm_lap_target, [('output_image', 'op1')]),
220217
(norm_lap_target, mrg_target, [('output_image', 'in2')]),
221218
(inu_n4_final, res_target2, [(('output_image', _pop), 'in_file')]),
222219
(res_target2, mrg_target, [('out_file', 'in1')]),
223-
224-
(res_tmpl, mrg_tmpl, [('out_file', 'in1')]),
225-
(lap_tmpl, norm_lap_tmpl, [('output_image', 'op1')]),
226-
(norm_lap_tmpl, mrg_tmpl, [('output_image', 'in2')]),
227-
228-
# normalisation inputs
229-
(init_aff, init_norm, [('output_transform', 'initial_moving_transform')]),
230-
(warp_mask, init_norm, [('output_image', 'moving_image_masks')]),
231-
(dil_mask, init_norm, [('out_file', 'fixed_image_masks')]),
232-
(mrg_tmpl, init_norm, [('out', 'fixed_image')]),
233-
(mrg_target, init_norm, [('out', 'moving_image')]),
234-
235-
# organise normalisation outputs to warp mask
236-
(init_norm, split_init_transforms, [('reverse_transforms', 'inlist')]),
237-
(split_init_transforms, mrg_init_transforms, [('out2', 'in1')]),
238-
(split_init_transforms, mrg_init_transforms, [('out1', 'in2')]),
239-
240-
(mrg_init_transforms, warp_mask_final, [('out', 'transforms')]),
241-
(inu_n4_final, warp_mask_final, [(('output_image', _pop), 'reference_image')]),
242-
(dil_mask, warp_mask_final, [('out_file', 'input_image')]),
243-
(warp_mask_final, close_mask, [('output_image', 'in_file')]),
244-
(close_mask, sinker, [('out_file', 'derivatives.@out_mask')]),
245-
246-
# mask brains
247-
(inu_n4_final, skullstrip_tar, [(('output_image', _pop), 'in_file')]),
248-
(close_mask, skullstrip_tar, [('out_file', 'in_mask')]),
249-
(inputnode, skullstrip_tpl, [('in_mask', 'in_mask')]),
250-
251-
# final_normalisation
252-
(skullstrip_tpl, final_norm, [('out_file', 'fixed_image')]),
253-
(skullstrip_tar, final_norm, [('out_file', 'moving_image')]),
254-
255-
# Warp mask and labels to subject-space
256-
(final_norm, split_final_transforms, [('reverse_transforms', 'inlist')]),
257-
(split_final_transforms, mrg_final_transforms, [('out2', 'in1')]),
258-
(split_final_transforms, mrg_final_transforms, [('out1', 'in2')]),
259-
260-
(mrg_final_transforms, warp_seg_mask, [('out', 'transforms')]),
261-
(skullstrip_tar, warp_seg_mask, [('out_file', 'reference_image')]),
262-
(mrg_final_transforms, warp_seg_labels, [('out', 'transforms')]),
263-
(skullstrip_tar, warp_seg_labels, [('out_file', 'reference_image')]),
264-
265-
# Segmentation
266-
(skullstrip_tar, segment, [('out_file', 'intensity_images')]),
267-
(warp_seg_labels, segment, [('output_image', 'prior_image')]),
268-
(warp_seg_mask, segment, [('output_image', 'mask_image')])
269220
])
270-
return wf
271-
272221
elif bids_suffix == 't1w':
273-
wf.connect([
274-
# resampling and creation of laplacians
222+
tar_prep.connect([
223+
# resampling and laplacian; no truncation or N4
275224
(inputnode, res_target, [('in_files', 'in_file')]),
276225
(inputnode, lap_target, [('in_files', 'op1')]),
277226
(lap_target, norm_lap_target, [('output_image', 'op1')]),
278227
(norm_lap_target, mrg_target, [('output_image', 'in2')]),
279228
(res_target, mrg_target, [('out_file', 'in1')]),
229+
(res_target, integrate_1, [('out_file', 'in_file')]),
230+
(inputnode, integrate_2, [('in_files', 'in_file')])
231+
])
280232

281-
(res_tmpl, mrg_tmpl, [('out_file', 'in1')]),
282-
(lap_tmpl, norm_lap_tmpl, [('output_image', 'op1')]),
283-
(norm_lap_tmpl, mrg_tmpl, [('output_image', 'in2')]),
284-
285-
#dilation of input mask
286-
(inputnode, dil_mask, [('in_mask', 'in_file')]),
287-
288-
# ants AI inputs
289-
(res_tmpl, init_aff, [('out_file', 'fixed_image')]),
290-
(res_target, init_aff, [('out_file', 'moving_image')]),
291-
(dil_mask, init_aff, [('out_file', 'fixed_image_mask')]),
292-
293-
# warp mask to individual space
294-
(dil_mask, warp_mask, [('out_file', 'input_image')]),
295-
(inputnode, warp_mask, [('in_files', 'reference_image')]),
296-
(init_aff, warp_mask, [('output_transform', 'transforms')]),
297-
298-
# normalisation inputs
299-
(mrg_tmpl, init_norm, [('out', 'fixed_image')]),
300-
(mrg_target, init_norm, [('out', 'moving_image')]),
301-
(dil_mask, init_norm, [('out_file', 'fixed_image_masks')]),
302-
(warp_mask, init_norm, [('output_image', 'moving_image_masks')]),
303-
(init_aff, init_norm, [('output_transform', 'initial_moving_transform')]),
304-
305-
#organise normalisation outputs to warp mask
306-
(init_norm, split_init_transforms, [('reverse_transforms', 'inlist')]),
307-
(split_init_transforms, mrg_init_transforms, [('out2', 'in1')]),
308-
(split_init_transforms, mrg_init_transforms, [('out1', 'in2')]),
309-
310-
(mrg_init_transforms, warp_mask_final, [('out', 'transforms')]),
311-
(inputnode, warp_mask_final, [('in_files', 'reference_image')]),
312-
(dil_mask, warp_mask_final, [('out_file', 'input_image')]),
313-
(warp_mask_final, close_mask, [('output_image', 'in_file')]),
314-
315-
# mask brains
316-
(inu_n4_final, skullstrip_tar, [(('output_image', _pop), 'in_file')]),
317-
(close_mask, skullstrip_tar, [('out_file', 'in_mask')]),
318-
(inputnode, skullstrip_tpl, [('in_mask', 'in_mask')]),
319-
320-
# final_normalisation
321-
(skullstrip_tpl, final_norm, [('out_file', 'fixed_image')]),
322-
(skullstrip_tar, final_norm, [('out_file', 'moving_image')]),
323-
324-
# Warp mask and labels to subject-space
325-
(final_norm, split_final_transforms, [('reverse_transforms', 'inlist')]),
326-
(split_final_transforms, mrg_final_transforms, [('out2', 'in1')]),
327-
(split_final_transforms, mrg_final_transforms, [('out1', 'in2')]),
328-
329-
(mrg_final_transforms, warp_seg_mask, [('out', 'transforms')]),
330-
(skullstrip_tar, warp_seg_mask, [('out_file', 'reference_image')]),
233+
#main workflow
234+
wf = pe.Workflow(name)
235+
wf.connect([
236+
# template prep: dilation of input mask, resampling template, laplacian creation
237+
(inputnode, dil_mask, [('in_mask', 'in_file')]),
238+
(res_tmpl, mrg_tmpl, [('out_file', 'in1')]),
239+
(lap_tmpl, norm_lap_tmpl, [('output_image', 'op1')]),
240+
(norm_lap_tmpl, mrg_tmpl, [('output_image', 'in2')]),
241+
242+
# ants AI inputs
243+
(tar_prep, init_aff, [('integrate_1.out_file', 'moving_image')]),
244+
(dil_mask, init_aff, [('out_file', 'fixed_image_mask')]),
245+
(res_tmpl, init_aff, [('out_file', 'fixed_image')]),
246+
247+
# warp mask to individual space
248+
(dil_mask, warp_mask_1, [('out_file', 'input_image')]),
249+
(init_aff, warp_mask_1, [('output_transform', 'transforms')]),
250+
(inputnode, warp_mask_1, [('in_files', 'reference_image')]),
251+
252+
# normalisation inputs
253+
(init_aff, init_norm, [('output_transform', 'initial_moving_transform')]),
254+
(warp_mask_1, init_norm, [('output_image', 'moving_image_masks')]),
255+
(dil_mask, init_norm, [('out_file', 'fixed_image_masks')]),
256+
(mrg_tmpl, init_norm, [('out', 'fixed_image')]),
257+
(tar_prep, init_norm, [('mrg_target.out', 'moving_image')]),
258+
259+
#organise initial normalisation transforms for warps
260+
(init_norm, split_init_transforms, [('reverse_transforms', 'inlist')]),
261+
(split_init_transforms, mrg_init_transforms, [('out2', 'in1')]),
262+
(split_init_transforms, mrg_init_transforms, [('out1', 'in2')]),
263+
264+
# warp mask with initial normalisation transforms
265+
(tar_prep, warp_mask_2, [('integrate_2.out_file', 'reference_image')]),
266+
(dil_mask, warp_mask_2, [('out_file', 'input_image')]),
267+
(mrg_init_transforms, warp_mask_2, [('out', 'transforms')]),
268+
(warp_mask_2, close_mask, [('output_image', 'in_file')]),
269+
270+
# mask brains for refined normalisation
271+
(tar_prep, skullstrip_tar, [('integrate_2.out_file', 'in_file')]),
272+
(close_mask, skullstrip_tar, [('out_file', 'in_mask')]),
273+
(inputnode, skullstrip_tpl, [('in_mask', 'in_mask')]),
274+
275+
# refined normalisation
276+
(skullstrip_tpl, refine_norm, [('out_file', 'fixed_image')]),
277+
(skullstrip_tar, refine_norm, [('out_file', 'moving_image')]),
278+
279+
#organise refined normalisation transforms for warps
280+
(refine_norm, split_final_transforms, [('reverse_transforms', 'inlist')]),
281+
(split_final_transforms, mrg_final_transforms, [('out2', 'in1')]),
282+
(split_final_transforms, mrg_final_transforms, [('out1', 'in2')]),
283+
284+
#warp mask to subject space and write out
285+
(mrg_final_transforms, warp_mask_out, [('out', 'transforms')]),
286+
(skullstrip_tar, warp_mask_out, [('out_file', 'reference_image')]),
287+
(warp_mask_out, sinker, [('output_image', 'derivatives.@out_mask')]),
288+
])
289+
# add second target prep stage if necessary
290+
if bids_suffix.lower() == 't2w':
291+
wf.connect([(warp_mask_1, tar_prep, [('output_image', 'inu_n4_final.weight_image')])])
292+
293+
# add segmentation if necessary
294+
if atropos_model:
295+
wf.connect([
296+
# Warp labels to subject-space
331297
(mrg_final_transforms, warp_seg_labels, [('out', 'transforms')]),
332298
(skullstrip_tar, warp_seg_labels, [('out_file', 'reference_image')]),
333299

334300
# Segmentation
335301
(skullstrip_tar, segment, [('out_file', 'intensity_images')]),
336302
(warp_seg_labels, segment, [('output_image', 'prior_image')]),
337-
(warp_seg_mask, segment, [('output_image', 'mask_image')]),
303+
(warp_mask_out, segment, [('output_image', 'mask_image')])
338304
])
339-
return wf
305+
306+
return wf
340307

341308
def _pop(in_files):
342309
if isinstance(in_files, (list, tuple)):

0 commit comments

Comments
 (0)