diff --git a/nirodents/workflows/brainextraction.py b/nirodents/workflows/brainextraction.py index d60d00d..f680cc8 100644 --- a/nirodents/workflows/brainextraction.py +++ b/nirodents/workflows/brainextraction.py @@ -52,7 +52,7 @@ def init_rodent_brain_extraction_wf( Run an extra step to refine the brain mask using a brain-tissue segmentation with Atropos. """ - wf = pe.Workflow(name) + if omp_nthreads is None or omp_nthreads < 1: omp_nthreads = cpu_count() @@ -61,8 +61,8 @@ def init_rodent_brain_extraction_wf( name='inputnode') # Find images in templateFlow - tpl_target_path = get_template(in_template, resolution=debug + 1, suffix="T2star" if bids_suffix == "T2w" else "T1w") - tpl_regmask_path = get_template(in_template, resolution=debug + 1, atlas='v3', desc='brain', suffix='mask') + tpl_target_path = get_template(in_template, resolution=debug + 1, suffix=tpl_suffix) + tpl_regmask_path = get_template(in_template, resolution=debug + 1, atlas=None, desc='brain', suffix='mask') if tpl_regmask_path: inputnode.inputs.in_mask = str(tpl_regmask_path) tpl_tissue_labels = get_template(in_template,resolution=debug + 1, desc='cerebrum', suffix='dseg') @@ -106,6 +106,10 @@ def init_rodent_brain_extraction_wf( mrg_tmpl = pe.Node(niu.Merge(2), name='mrg_tmpl') # mrg_tmpl.inputs.in1 = tpl_target_path + # Create integration nodes to allow compatibility between pipelines + integrate_1 = pe.Node(niu.IdentityInterface(fields=["in_file"]), name='integrate_1') + integrate_2 = pe.Node(niu.IdentityInterface(fields=["in_file"]), name='integrate_2') + # Initialize transforms with antsAI init_aff = pe.Node(AI( metric=('Mattes', 32, 'Regular', 0.5), #0.25 @@ -119,8 +123,8 @@ def init_rodent_brain_extraction_wf( n_procs=omp_nthreads) # Initial warping of template mask to subject space - warp_mask = pe.Node(ApplyTransforms( - interpolation='Linear', invert_transform_flags=True), name='warp_mask') + warp_mask_1 = pe.Node(ApplyTransforms( + interpolation='Linear', invert_transform_flags=True), name='warp_mask_1') # Set up initial spatial normalization init_settings_file = f'data/brainextraction_{init_normalization_quality}_{bids_suffix}.json' @@ -144,9 +148,9 @@ def init_rodent_brain_extraction_wf( mrg_init_transforms = pe.Node(niu.Merge(2), name='mrg_init_transforms') # Use more precise transforms to warp mask to subject space - warp_mask_final = pe.Node(ApplyTransforms( + warp_mask_2 = pe.Node(ApplyTransforms( interpolation='Linear', invert_transform_flags=[False, True]), - name='warp_mask_final') + name='warp_mask_2') # morphological closing of warped mask close_mask = pe.Node(MaskTool(outputtype='NIFTI_GZ', dilate_inputs='5 -5', fill_holes=True), @@ -160,21 +164,23 @@ def init_rodent_brain_extraction_wf( # Normalise skull-stripped image to brain template final_settings_file = f'data/brainextraction_{final_normalization_quality}_{bids_suffix}.json' - final_norm = pe.Node(Registration(from_file=pkgr_fn( + refine_norm = pe.Node(Registration(from_file=pkgr_fn( 'nirodents', final_settings_file)), - name='final_norm', + name='refine_norm', n_procs=omp_nthreads, mem_gb=mem_gb) - final_norm.inputs.float = use_float + refine_norm.inputs.float = use_float split_final_transforms = pe.Node(niu.Split(splits=[1, 1]), name='split_final_transforms') mrg_final_transforms = pe.Node(niu.Merge(2), name='mrg_final_transforms') - warp_seg_mask = pe.Node(ApplyTransforms( + warp_mask_out = pe.Node(ApplyTransforms( interpolation='Linear', invert_transform_flags=[False, True]), - name='warp_seg_mask') + name='warp_mask_out') if tpl_brain_mask: - warp_seg_mask.inputs.input_image = tpl_brain_mask + warp_mask_out.inputs.input_image = tpl_brain_mask + else: + warp_mask_out.inputs.input_image = tpl_regmask_path warp_seg_labels = pe.Node(ApplyTransforms( interpolation='Linear', invert_transform_flags=[False, True]), @@ -190,29 +196,20 @@ def init_rodent_brain_extraction_wf( sinker = pe.Node(DataSink(), name='sinker') + #workflow definitions + #target image specific workflows + tar_prep = pe.Workflow('tar_prep') if bids_suffix.lower() == 't2w': - wf.connect([ - # resampling, truncation, initial N4, and creation of laplacian + tar_prep.connect([ + # truncation, resampling, and initial N4 (inputnode, trunc, [('in_files', 'op1')]), (trunc, res_target, [(('output_image', _pop), 'in_file')]), (res_target, inu_n4, [('out_file', 'input_image')]), - - # dilation of input mask - (inputnode, dil_mask, [('in_mask', 'in_file')]), - - # ants AI inputs - (inu_n4, init_aff, [(('output_image', _pop), 'moving_image')]), - (dil_mask, init_aff, [('out_file', 'fixed_image_mask')]), - (res_tmpl, init_aff, [('out_file', 'fixed_image')]), - - # warp mask to individual space - (dil_mask, warp_mask, [('out_file', 'input_image')]), - (trunc, warp_mask, [(('output_image', _pop), 'reference_image')]), - (init_aff, warp_mask, [('output_transform', 'transforms')]), + (inu_n4, integrate_1, [(('output_image', _pop), 'in_file')]), # masked N4 correction (trunc, inu_n4_final, [(('output_image', _pop), 'input_image')]), - (warp_mask, inu_n4_final, [('output_image', 'weight_image')]), + (inu_n4_final, integrate_2, [(('output_image', _pop), 'in_file')]), # merge laplacian and original images (inu_n4_final, lap_target, [(('output_image', _pop), 'op1')]), @@ -220,123 +217,93 @@ def init_rodent_brain_extraction_wf( (norm_lap_target, mrg_target, [('output_image', 'in2')]), (inu_n4_final, res_target2, [(('output_image', _pop), 'in_file')]), (res_target2, mrg_target, [('out_file', 'in1')]), - - (res_tmpl, mrg_tmpl, [('out_file', 'in1')]), - (lap_tmpl, norm_lap_tmpl, [('output_image', 'op1')]), - (norm_lap_tmpl, mrg_tmpl, [('output_image', 'in2')]), - - # normalisation inputs - (init_aff, init_norm, [('output_transform', 'initial_moving_transform')]), - (warp_mask, init_norm, [('output_image', 'moving_image_masks')]), - (dil_mask, init_norm, [('out_file', 'fixed_image_masks')]), - (mrg_tmpl, init_norm, [('out', 'fixed_image')]), - (mrg_target, init_norm, [('out', 'moving_image')]), - - # organise normalisation outputs to warp mask - (init_norm, split_init_transforms, [('reverse_transforms', 'inlist')]), - (split_init_transforms, mrg_init_transforms, [('out2', 'in1')]), - (split_init_transforms, mrg_init_transforms, [('out1', 'in2')]), - - (mrg_init_transforms, warp_mask_final, [('out', 'transforms')]), - (inu_n4_final, warp_mask_final, [(('output_image', _pop), 'reference_image')]), - (dil_mask, warp_mask_final, [('out_file', 'input_image')]), - (warp_mask_final, close_mask, [('output_image', 'in_file')]), - (close_mask, sinker, [('out_file', 'derivatives.@out_mask')]), - - # mask brains - (inu_n4_final, skullstrip_tar, [(('output_image', _pop), 'in_file')]), - (close_mask, skullstrip_tar, [('out_file', 'in_mask')]), - (inputnode, skullstrip_tpl, [('in_mask', 'in_mask')]), - - # final_normalisation - (skullstrip_tpl, final_norm, [('out_file', 'fixed_image')]), - (skullstrip_tar, final_norm, [('out_file', 'moving_image')]), - - # Warp mask and labels to subject-space - (final_norm, split_final_transforms, [('reverse_transforms', 'inlist')]), - (split_final_transforms, mrg_final_transforms, [('out2', 'in1')]), - (split_final_transforms, mrg_final_transforms, [('out1', 'in2')]), - - (mrg_final_transforms, warp_seg_mask, [('out', 'transforms')]), - (skullstrip_tar, warp_seg_mask, [('out_file', 'reference_image')]), - (mrg_final_transforms, warp_seg_labels, [('out', 'transforms')]), - (skullstrip_tar, warp_seg_labels, [('out_file', 'reference_image')]), - - # Segmentation - (skullstrip_tar, segment, [('out_file', 'intensity_images')]), - (warp_seg_labels, segment, [('output_image', 'prior_image')]), - (warp_seg_mask, segment, [('output_image', 'mask_image')]) ]) - return wf - elif bids_suffix == 't1w': - wf.connect([ - # resampling and creation of laplacians + tar_prep.connect([ + # resampling and laplacian; no truncation or N4 (inputnode, res_target, [('in_files', 'in_file')]), (inputnode, lap_target, [('in_files', 'op1')]), (lap_target, norm_lap_target, [('output_image', 'op1')]), (norm_lap_target, mrg_target, [('output_image', 'in2')]), (res_target, mrg_target, [('out_file', 'in1')]), + (res_target, integrate_1, [('out_file', 'in_file')]), + (inputnode, integrate_2, [('in_files', 'in_file')]) + ]) - (res_tmpl, mrg_tmpl, [('out_file', 'in1')]), - (lap_tmpl, norm_lap_tmpl, [('output_image', 'op1')]), - (norm_lap_tmpl, mrg_tmpl, [('output_image', 'in2')]), - - #dilation of input mask - (inputnode, dil_mask, [('in_mask', 'in_file')]), - - # ants AI inputs - (res_tmpl, init_aff, [('out_file', 'fixed_image')]), - (res_target, init_aff, [('out_file', 'moving_image')]), - (dil_mask, init_aff, [('out_file', 'fixed_image_mask')]), - - # warp mask to individual space - (dil_mask, warp_mask, [('out_file', 'input_image')]), - (inputnode, warp_mask, [('in_files', 'reference_image')]), - (init_aff, warp_mask, [('output_transform', 'transforms')]), - - # normalisation inputs - (mrg_tmpl, init_norm, [('out', 'fixed_image')]), - (mrg_target, init_norm, [('out', 'moving_image')]), - (dil_mask, init_norm, [('out_file', 'fixed_image_masks')]), - (warp_mask, init_norm, [('output_image', 'moving_image_masks')]), - (init_aff, init_norm, [('output_transform', 'initial_moving_transform')]), - - #organise normalisation outputs to warp mask - (init_norm, split_init_transforms, [('reverse_transforms', 'inlist')]), - (split_init_transforms, mrg_init_transforms, [('out2', 'in1')]), - (split_init_transforms, mrg_init_transforms, [('out1', 'in2')]), - - (mrg_init_transforms, warp_mask_final, [('out', 'transforms')]), - (inputnode, warp_mask_final, [('in_files', 'reference_image')]), - (dil_mask, warp_mask_final, [('out_file', 'input_image')]), - (warp_mask_final, close_mask, [('output_image', 'in_file')]), - - # mask brains - (inu_n4_final, skullstrip_tar, [(('output_image', _pop), 'in_file')]), - (close_mask, skullstrip_tar, [('out_file', 'in_mask')]), - (inputnode, skullstrip_tpl, [('in_mask', 'in_mask')]), - - # final_normalisation - (skullstrip_tpl, final_norm, [('out_file', 'fixed_image')]), - (skullstrip_tar, final_norm, [('out_file', 'moving_image')]), - - # Warp mask and labels to subject-space - (final_norm, split_final_transforms, [('reverse_transforms', 'inlist')]), - (split_final_transforms, mrg_final_transforms, [('out2', 'in1')]), - (split_final_transforms, mrg_final_transforms, [('out1', 'in2')]), - - (mrg_final_transforms, warp_seg_mask, [('out', 'transforms')]), - (skullstrip_tar, warp_seg_mask, [('out_file', 'reference_image')]), + #main workflow + wf = pe.Workflow(name) + wf.connect([ + # template prep: dilation of input mask, resampling template, laplacian creation + (inputnode, dil_mask, [('in_mask', 'in_file')]), + (res_tmpl, mrg_tmpl, [('out_file', 'in1')]), + (lap_tmpl, norm_lap_tmpl, [('output_image', 'op1')]), + (norm_lap_tmpl, mrg_tmpl, [('output_image', 'in2')]), + + # ants AI inputs + (tar_prep, init_aff, [('integrate_1.out_file', 'moving_image')]), + (dil_mask, init_aff, [('out_file', 'fixed_image_mask')]), + (res_tmpl, init_aff, [('out_file', 'fixed_image')]), + + # warp mask to individual space + (dil_mask, warp_mask_1, [('out_file', 'input_image')]), + (init_aff, warp_mask_1, [('output_transform', 'transforms')]), + (inputnode, warp_mask_1, [('in_files', 'reference_image')]), + + # normalisation inputs + (init_aff, init_norm, [('output_transform', 'initial_moving_transform')]), + (warp_mask_1, init_norm, [('output_image', 'moving_image_masks')]), + (dil_mask, init_norm, [('out_file', 'fixed_image_masks')]), + (mrg_tmpl, init_norm, [('out', 'fixed_image')]), + (tar_prep, init_norm, [('mrg_target.out', 'moving_image')]), + + #organise initial normalisation transforms for warps + (init_norm, split_init_transforms, [('reverse_transforms', 'inlist')]), + (split_init_transforms, mrg_init_transforms, [('out2', 'in1')]), + (split_init_transforms, mrg_init_transforms, [('out1', 'in2')]), + + # warp mask with initial normalisation transforms + (tar_prep, warp_mask_2, [('integrate_2.out_file', 'reference_image')]), + (dil_mask, warp_mask_2, [('out_file', 'input_image')]), + (mrg_init_transforms, warp_mask_2, [('out', 'transforms')]), + (warp_mask_2, close_mask, [('output_image', 'in_file')]), + + # mask brains for refined normalisation + (tar_prep, skullstrip_tar, [('integrate_2.out_file', 'in_file')]), + (close_mask, skullstrip_tar, [('out_file', 'in_mask')]), + (inputnode, skullstrip_tpl, [('in_mask', 'in_mask')]), + + # refined normalisation + (skullstrip_tpl, refine_norm, [('out_file', 'fixed_image')]), + (skullstrip_tar, refine_norm, [('out_file', 'moving_image')]), + + #organise refined normalisation transforms for warps + (refine_norm, split_final_transforms, [('reverse_transforms', 'inlist')]), + (split_final_transforms, mrg_final_transforms, [('out2', 'in1')]), + (split_final_transforms, mrg_final_transforms, [('out1', 'in2')]), + + #warp mask to subject space and write out + (mrg_final_transforms, warp_mask_out, [('out', 'transforms')]), + (skullstrip_tar, warp_mask_out, [('out_file', 'reference_image')]), + (warp_mask_out, sinker, [('output_image', 'derivatives.@out_mask')]), + ]) + # add second target prep stage if necessary + if bids_suffix.lower() == 't2w': + wf.connect([(warp_mask_1, tar_prep, [('output_image', 'inu_n4_final.weight_image')])]) + + # add segmentation if necessary + if atropos_model: + wf.connect([ + # Warp labels to subject-space (mrg_final_transforms, warp_seg_labels, [('out', 'transforms')]), (skullstrip_tar, warp_seg_labels, [('out_file', 'reference_image')]), # Segmentation (skullstrip_tar, segment, [('out_file', 'intensity_images')]), (warp_seg_labels, segment, [('output_image', 'prior_image')]), - (warp_seg_mask, segment, [('output_image', 'mask_image')]), + (warp_mask_out, segment, [('output_image', 'mask_image')]) ]) - return wf + + return wf def _pop(in_files): if isinstance(in_files, (list, tuple)):