Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
229 changes: 98 additions & 131 deletions nirodents/workflows/brainextraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def init_rodent_brain_extraction_wf(
Run an extra step to refine the brain mask using a brain-tissue segmentation with Atropos.

"""
wf = pe.Workflow(name)


if omp_nthreads is None or omp_nthreads < 1:
omp_nthreads = cpu_count()
Expand All @@ -61,8 +61,8 @@ def init_rodent_brain_extraction_wf(
name='inputnode')

# Find images in templateFlow
tpl_target_path = get_template(in_template, resolution=debug + 1, suffix="T2star" if bids_suffix == "T2w" else "T1w")
tpl_regmask_path = get_template(in_template, resolution=debug + 1, atlas='v3', desc='brain', suffix='mask')
tpl_target_path = get_template(in_template, resolution=debug + 1, suffix=tpl_suffix)
tpl_regmask_path = get_template(in_template, resolution=debug + 1, atlas=None, desc='brain', suffix='mask')
if tpl_regmask_path:
inputnode.inputs.in_mask = str(tpl_regmask_path)
tpl_tissue_labels = get_template(in_template,resolution=debug + 1, desc='cerebrum', suffix='dseg')
Expand Down Expand Up @@ -106,6 +106,10 @@ def init_rodent_brain_extraction_wf(
mrg_tmpl = pe.Node(niu.Merge(2), name='mrg_tmpl')
# mrg_tmpl.inputs.in1 = tpl_target_path

# Create integration nodes to allow compatibility between pipelines
integrate_1 = pe.Node(niu.IdentityInterface(fields=["in_file"]), name='integrate_1')
integrate_2 = pe.Node(niu.IdentityInterface(fields=["in_file"]), name='integrate_2')

# Initialize transforms with antsAI
init_aff = pe.Node(AI(
metric=('Mattes', 32, 'Regular', 0.5), #0.25
Expand All @@ -119,8 +123,8 @@ def init_rodent_brain_extraction_wf(
n_procs=omp_nthreads)

# Initial warping of template mask to subject space
warp_mask = pe.Node(ApplyTransforms(
interpolation='Linear', invert_transform_flags=True), name='warp_mask')
warp_mask_1 = pe.Node(ApplyTransforms(
interpolation='Linear', invert_transform_flags=True), name='warp_mask_1')

# Set up initial spatial normalization
init_settings_file = f'data/brainextraction_{init_normalization_quality}_{bids_suffix}.json'
Expand All @@ -144,9 +148,9 @@ def init_rodent_brain_extraction_wf(
mrg_init_transforms = pe.Node(niu.Merge(2), name='mrg_init_transforms')

# Use more precise transforms to warp mask to subject space
warp_mask_final = pe.Node(ApplyTransforms(
warp_mask_2 = pe.Node(ApplyTransforms(
interpolation='Linear', invert_transform_flags=[False, True]),
name='warp_mask_final')
name='warp_mask_2')

# morphological closing of warped mask
close_mask = pe.Node(MaskTool(outputtype='NIFTI_GZ', dilate_inputs='5 -5', fill_holes=True),
Expand All @@ -160,21 +164,23 @@ def init_rodent_brain_extraction_wf(

# Normalise skull-stripped image to brain template
final_settings_file = f'data/brainextraction_{final_normalization_quality}_{bids_suffix}.json'
final_norm = pe.Node(Registration(from_file=pkgr_fn(
refine_norm = pe.Node(Registration(from_file=pkgr_fn(
'nirodents', final_settings_file)),
name='final_norm',
name='refine_norm',
n_procs=omp_nthreads,
mem_gb=mem_gb)
final_norm.inputs.float = use_float
refine_norm.inputs.float = use_float

split_final_transforms = pe.Node(niu.Split(splits=[1, 1]), name='split_final_transforms')
mrg_final_transforms = pe.Node(niu.Merge(2), name='mrg_final_transforms')

warp_seg_mask = pe.Node(ApplyTransforms(
warp_mask_out = pe.Node(ApplyTransforms(
interpolation='Linear', invert_transform_flags=[False, True]),
name='warp_seg_mask')
name='warp_mask_out')
if tpl_brain_mask:
warp_seg_mask.inputs.input_image = tpl_brain_mask
warp_mask_out.inputs.input_image = tpl_brain_mask
else:
warp_mask_out.inputs.input_image = tpl_regmask_path

warp_seg_labels = pe.Node(ApplyTransforms(
interpolation='Linear', invert_transform_flags=[False, True]),
Expand All @@ -190,153 +196,114 @@ def init_rodent_brain_extraction_wf(

sinker = pe.Node(DataSink(), name='sinker')

#workflow definitions
#target image specific workflows
tar_prep = pe.Workflow('tar_prep')
if bids_suffix.lower() == 't2w':
wf.connect([
# resampling, truncation, initial N4, and creation of laplacian
tar_prep.connect([
# truncation, resampling, and initial N4
(inputnode, trunc, [('in_files', 'op1')]),
(trunc, res_target, [(('output_image', _pop), 'in_file')]),
(res_target, inu_n4, [('out_file', 'input_image')]),

# dilation of input mask
(inputnode, dil_mask, [('in_mask', 'in_file')]),

# ants AI inputs
(inu_n4, init_aff, [(('output_image', _pop), 'moving_image')]),
(dil_mask, init_aff, [('out_file', 'fixed_image_mask')]),
(res_tmpl, init_aff, [('out_file', 'fixed_image')]),

# warp mask to individual space
(dil_mask, warp_mask, [('out_file', 'input_image')]),
(trunc, warp_mask, [(('output_image', _pop), 'reference_image')]),
(init_aff, warp_mask, [('output_transform', 'transforms')]),
(inu_n4, integrate_1, [(('output_image', _pop), 'in_file')]),

# masked N4 correction
(trunc, inu_n4_final, [(('output_image', _pop), 'input_image')]),
(warp_mask, inu_n4_final, [('output_image', 'weight_image')]),
(inu_n4_final, integrate_2, [(('output_image', _pop), 'in_file')]),

# merge laplacian and original images
(inu_n4_final, lap_target, [(('output_image', _pop), 'op1')]),
(lap_target, norm_lap_target, [('output_image', 'op1')]),
(norm_lap_target, mrg_target, [('output_image', 'in2')]),
(inu_n4_final, res_target2, [(('output_image', _pop), 'in_file')]),
(res_target2, mrg_target, [('out_file', 'in1')]),

(res_tmpl, mrg_tmpl, [('out_file', 'in1')]),
(lap_tmpl, norm_lap_tmpl, [('output_image', 'op1')]),
(norm_lap_tmpl, mrg_tmpl, [('output_image', 'in2')]),

# normalisation inputs
(init_aff, init_norm, [('output_transform', 'initial_moving_transform')]),
(warp_mask, init_norm, [('output_image', 'moving_image_masks')]),
(dil_mask, init_norm, [('out_file', 'fixed_image_masks')]),
(mrg_tmpl, init_norm, [('out', 'fixed_image')]),
(mrg_target, init_norm, [('out', 'moving_image')]),

# organise normalisation outputs to warp mask
(init_norm, split_init_transforms, [('reverse_transforms', 'inlist')]),
(split_init_transforms, mrg_init_transforms, [('out2', 'in1')]),
(split_init_transforms, mrg_init_transforms, [('out1', 'in2')]),

(mrg_init_transforms, warp_mask_final, [('out', 'transforms')]),
(inu_n4_final, warp_mask_final, [(('output_image', _pop), 'reference_image')]),
(dil_mask, warp_mask_final, [('out_file', 'input_image')]),
(warp_mask_final, close_mask, [('output_image', 'in_file')]),
(close_mask, sinker, [('out_file', 'derivatives.@out_mask')]),

# mask brains
(inu_n4_final, skullstrip_tar, [(('output_image', _pop), 'in_file')]),
(close_mask, skullstrip_tar, [('out_file', 'in_mask')]),
(inputnode, skullstrip_tpl, [('in_mask', 'in_mask')]),

# final_normalisation
(skullstrip_tpl, final_norm, [('out_file', 'fixed_image')]),
(skullstrip_tar, final_norm, [('out_file', 'moving_image')]),

# Warp mask and labels to subject-space
(final_norm, split_final_transforms, [('reverse_transforms', 'inlist')]),
(split_final_transforms, mrg_final_transforms, [('out2', 'in1')]),
(split_final_transforms, mrg_final_transforms, [('out1', 'in2')]),

(mrg_final_transforms, warp_seg_mask, [('out', 'transforms')]),
(skullstrip_tar, warp_seg_mask, [('out_file', 'reference_image')]),
(mrg_final_transforms, warp_seg_labels, [('out', 'transforms')]),
(skullstrip_tar, warp_seg_labels, [('out_file', 'reference_image')]),

# Segmentation
(skullstrip_tar, segment, [('out_file', 'intensity_images')]),
(warp_seg_labels, segment, [('output_image', 'prior_image')]),
(warp_seg_mask, segment, [('output_image', 'mask_image')])
])
return wf

elif bids_suffix == 't1w':
wf.connect([
# resampling and creation of laplacians
tar_prep.connect([
# resampling and laplacian; no truncation or N4
(inputnode, res_target, [('in_files', 'in_file')]),
(inputnode, lap_target, [('in_files', 'op1')]),
(lap_target, norm_lap_target, [('output_image', 'op1')]),
(norm_lap_target, mrg_target, [('output_image', 'in2')]),
(res_target, mrg_target, [('out_file', 'in1')]),
(res_target, integrate_1, [('out_file', 'in_file')]),
(inputnode, integrate_2, [('in_files', 'in_file')])
])

(res_tmpl, mrg_tmpl, [('out_file', 'in1')]),
(lap_tmpl, norm_lap_tmpl, [('output_image', 'op1')]),
(norm_lap_tmpl, mrg_tmpl, [('output_image', 'in2')]),

#dilation of input mask
(inputnode, dil_mask, [('in_mask', 'in_file')]),

# ants AI inputs
(res_tmpl, init_aff, [('out_file', 'fixed_image')]),
(res_target, init_aff, [('out_file', 'moving_image')]),
(dil_mask, init_aff, [('out_file', 'fixed_image_mask')]),

# warp mask to individual space
(dil_mask, warp_mask, [('out_file', 'input_image')]),
(inputnode, warp_mask, [('in_files', 'reference_image')]),
(init_aff, warp_mask, [('output_transform', 'transforms')]),

# normalisation inputs
(mrg_tmpl, init_norm, [('out', 'fixed_image')]),
(mrg_target, init_norm, [('out', 'moving_image')]),
(dil_mask, init_norm, [('out_file', 'fixed_image_masks')]),
(warp_mask, init_norm, [('output_image', 'moving_image_masks')]),
(init_aff, init_norm, [('output_transform', 'initial_moving_transform')]),

#organise normalisation outputs to warp mask
(init_norm, split_init_transforms, [('reverse_transforms', 'inlist')]),
(split_init_transforms, mrg_init_transforms, [('out2', 'in1')]),
(split_init_transforms, mrg_init_transforms, [('out1', 'in2')]),

(mrg_init_transforms, warp_mask_final, [('out', 'transforms')]),
(inputnode, warp_mask_final, [('in_files', 'reference_image')]),
(dil_mask, warp_mask_final, [('out_file', 'input_image')]),
(warp_mask_final, close_mask, [('output_image', 'in_file')]),

# mask brains
(inu_n4_final, skullstrip_tar, [(('output_image', _pop), 'in_file')]),
(close_mask, skullstrip_tar, [('out_file', 'in_mask')]),
(inputnode, skullstrip_tpl, [('in_mask', 'in_mask')]),

# final_normalisation
(skullstrip_tpl, final_norm, [('out_file', 'fixed_image')]),
(skullstrip_tar, final_norm, [('out_file', 'moving_image')]),

# Warp mask and labels to subject-space
(final_norm, split_final_transforms, [('reverse_transforms', 'inlist')]),
(split_final_transforms, mrg_final_transforms, [('out2', 'in1')]),
(split_final_transforms, mrg_final_transforms, [('out1', 'in2')]),

(mrg_final_transforms, warp_seg_mask, [('out', 'transforms')]),
(skullstrip_tar, warp_seg_mask, [('out_file', 'reference_image')]),
#main workflow
wf = pe.Workflow(name)
wf.connect([
# template prep: dilation of input mask, resampling template, laplacian creation
(inputnode, dil_mask, [('in_mask', 'in_file')]),
(res_tmpl, mrg_tmpl, [('out_file', 'in1')]),
(lap_tmpl, norm_lap_tmpl, [('output_image', 'op1')]),
(norm_lap_tmpl, mrg_tmpl, [('output_image', 'in2')]),

# ants AI inputs
(tar_prep, init_aff, [('integrate_1.out_file', 'moving_image')]),
(dil_mask, init_aff, [('out_file', 'fixed_image_mask')]),
(res_tmpl, init_aff, [('out_file', 'fixed_image')]),

# warp mask to individual space
(dil_mask, warp_mask_1, [('out_file', 'input_image')]),
(init_aff, warp_mask_1, [('output_transform', 'transforms')]),
(inputnode, warp_mask_1, [('in_files', 'reference_image')]),

# normalisation inputs
(init_aff, init_norm, [('output_transform', 'initial_moving_transform')]),
(warp_mask_1, init_norm, [('output_image', 'moving_image_masks')]),
(dil_mask, init_norm, [('out_file', 'fixed_image_masks')]),
(mrg_tmpl, init_norm, [('out', 'fixed_image')]),
(tar_prep, init_norm, [('mrg_target.out', 'moving_image')]),

#organise initial normalisation transforms for warps
(init_norm, split_init_transforms, [('reverse_transforms', 'inlist')]),
(split_init_transforms, mrg_init_transforms, [('out2', 'in1')]),
(split_init_transforms, mrg_init_transforms, [('out1', 'in2')]),

# warp mask with initial normalisation transforms
(tar_prep, warp_mask_2, [('integrate_2.out_file', 'reference_image')]),
(dil_mask, warp_mask_2, [('out_file', 'input_image')]),
(mrg_init_transforms, warp_mask_2, [('out', 'transforms')]),
(warp_mask_2, close_mask, [('output_image', 'in_file')]),

# mask brains for refined normalisation
(tar_prep, skullstrip_tar, [('integrate_2.out_file', 'in_file')]),
(close_mask, skullstrip_tar, [('out_file', 'in_mask')]),
(inputnode, skullstrip_tpl, [('in_mask', 'in_mask')]),

# refined normalisation
(skullstrip_tpl, refine_norm, [('out_file', 'fixed_image')]),
(skullstrip_tar, refine_norm, [('out_file', 'moving_image')]),

#organise refined normalisation transforms for warps
(refine_norm, split_final_transforms, [('reverse_transforms', 'inlist')]),
(split_final_transforms, mrg_final_transforms, [('out2', 'in1')]),
(split_final_transforms, mrg_final_transforms, [('out1', 'in2')]),

#warp mask to subject space and write out
(mrg_final_transforms, warp_mask_out, [('out', 'transforms')]),
(skullstrip_tar, warp_mask_out, [('out_file', 'reference_image')]),
(warp_mask_out, sinker, [('output_image', 'derivatives.@out_mask')]),
])
# add second target prep stage if necessary
if bids_suffix.lower() == 't2w':
wf.connect([(warp_mask_1, tar_prep, [('output_image', 'inu_n4_final.weight_image')])])

# add segmentation if necessary
if atropos_model:
wf.connect([
# Warp labels to subject-space
(mrg_final_transforms, warp_seg_labels, [('out', 'transforms')]),
(skullstrip_tar, warp_seg_labels, [('out_file', 'reference_image')]),

# Segmentation
(skullstrip_tar, segment, [('out_file', 'intensity_images')]),
(warp_seg_labels, segment, [('output_image', 'prior_image')]),
(warp_seg_mask, segment, [('output_image', 'mask_image')]),
(warp_mask_out, segment, [('output_image', 'mask_image')])
])
return wf

return wf

def _pop(in_files):
if isinstance(in_files, (list, tuple)):
Expand Down