3636from nibabies import config
3737from nibabies .workflows .anatomical .brain_extraction import init_infant_brain_extraction_wf
3838from nibabies .workflows .anatomical .outputs import init_anat_reports_wf
39- from nibabies .workflows .anatomical .preproc import init_anat_preproc_wf
39+ from nibabies .workflows .anatomical .preproc import init_anat_preproc_wf , init_csf_norm_wf
4040from nibabies .workflows .anatomical .registration import init_coregistration_wf
4141from nibabies .workflows .anatomical .segmentation import init_segmentation_wf
4242from nibabies .workflows .anatomical .surfaces import init_mcribs_dhcp_wf
@@ -184,6 +184,14 @@ def init_infant_anat_fit_wf(
184184 name = 'anat_buffer' ,
185185 )
186186
187+ # Additional buffer if CSF normalization is used
188+ anat_preproc_buffer = pe .Node (
189+ niu .IdentityInterface (fields = ['anat_preproc' ]),
190+ name = 'anat_preproc_buffer' ,
191+ )
192+ if not config .workflow .norm_csf :
193+ workflow .connect (anat_buffer , 'anat_preproc' , anat_preproc_buffer , 'anat_preproc' )
194+
187195 if reference_anat == 'T1w' :
188196 LOGGER .info ('ANAT: Using T1w as the reference anatomical' )
189197 workflow .connect ([
@@ -248,7 +256,7 @@ def init_infant_anat_fit_wf(
248256 msm_buffer = pe .Node (niu .IdentityInterface (fields = ['sphere_reg_msm' ]), name = 'msm_buffer' )
249257
250258 workflow .connect ([
251- (anat_buffer , outputnode , [
259+ (anat_preproc_buffer , outputnode , [
252260 ('anat_preproc' , 'anat_preproc' ),
253261 ]),
254262 (refined_buffer , outputnode , [
@@ -637,24 +645,6 @@ def init_infant_anat_fit_wf(
637645 (binarize_t2w , t2w_buffer , [('out_file' , 't2w_mask' )]),
638646 ]) # fmt:skip
639647 else :
640- # Check whether we can convert a previously computed T2w mask
641- # or need to run the atlas based brain extraction
642-
643- # if t1w_mask:
644- # LOGGER.info('ANAT T1w mask will be transformed into T2w space')
645- # transform_t1w_mask = pe.Node(
646- # ApplyTransforms(interpolation='MultiLabel'),
647- # name='transform_t1w_mask',
648- # )
649-
650- # workflow.connect([
651- # (t1w_buffer, transform_t1w_mask, [('t1w_mask', 'input_image')]),
652- # (coreg_buffer, transform_t1w_mask, [('t1w2t2w_xfm', 'transforms')]),
653- # (transform_t1w_mask, apply_t2w_mask, [('output_image', 'in_mask')]),
654- # (t2w_buffer, apply_t1w_mask, [('t2w_preproc', 'in_file')]),
655- # # TODO: Unsure about this connection^
656- # ]) # fmt:skip
657- # else:
658648 LOGGER .info ('ANAT Atlas-based brain mask will be calculated on the T2w' )
659649 brain_extraction_wf = init_infant_brain_extraction_wf (
660650 omp_nthreads = omp_nthreads ,
@@ -898,6 +888,15 @@ def init_infant_anat_fit_wf(
898888 anat2std_buffer .inputs .in1 = [xfm ['forward' ] for xfm in found_xfms .values ()]
899889 std2anat_buffer .inputs .in1 = [xfm ['reverse' ] for xfm in found_xfms .values ()]
900890
891+ if config .workflow .norm_csf :
892+ csf_norm_wf = init_csf_norm_wf ()
893+
894+ workflow .connect ([
895+ (anat_buffer , csf_norm_wf , [('anat_preproc' , 'inputnode.anat_preproc' )]),
896+ (seg_buffer , csf_norm_wf , [('anat_tpms' , 'inputnode.anat_tpms' )]),
897+ (csf_norm_wf , anat_preproc_buffer , [('outputnode.anat_preproc' , 'anat_preproc' )]),
898+ ]) # fmt:skip
899+
901900 if templates :
902901 LOGGER .info (f'ANAT Stage 5: Preparing normalization workflow for { templates } ' )
903902 register_template_wf = init_register_template_wf (
@@ -913,7 +912,9 @@ def init_infant_anat_fit_wf(
913912
914913 workflow .connect ([
915914 (inputnode , register_template_wf , [('roi' , 'inputnode.lesion_mask' )]),
916- (anat_buffer , register_template_wf , [('anat_preproc' , 'inputnode.moving_image' )]),
915+ (anat_preproc_buffer , register_template_wf , [
916+ ('anat_preproc' , 'inputnode.moving_image' ),
917+ ]),
917918 (refined_buffer , register_template_wf , [('anat_mask' , 'inputnode.moving_mask' )]),
918919 (sourcefile_buffer , ds_template_registration_wf , [
919920 ('anat_source_files' , 'inputnode.source_files' )
@@ -1106,7 +1107,7 @@ def init_infant_anat_fit_wf(
11061107 (seg_buffer , refinement_wf , [
11071108 ('ants_segs' , 'inputnode.ants_segs' ), # TODO: Verify this is the same as dseg
11081109 ]),
1109- (anat_buffer , applyrefined , [('anat_preproc' , 'in_file' )]),
1110+ (anat_preproc_buffer , applyrefined , [('anat_preproc' , 'in_file' )]),
11101111 (refinement_wf , applyrefined , [('outputnode.out_brainmask' , 'in_mask' )]),
11111112 (refinement_wf , refined_buffer , [('outputnode.out_brainmask' , 'anat_mask' )]),
11121113 (applyrefined , refined_buffer , [('out_file' , 'anat_brain' )]),
@@ -1384,6 +1385,14 @@ def init_infant_single_anat_fit_wf(
13841385 name = 'anat_buffer' ,
13851386 )
13861387
1388+ # Additional buffer if CSF normalization is used
1389+ anat_preproc_buffer = pe .Node (
1390+ niu .IdentityInterface (fields = ['anat_preproc' ]),
1391+ name = 'anat_preproc_buffer' ,
1392+ )
1393+ if not config .workflow .norm_csf :
1394+ workflow .connect (anat_buffer , 'anat_preproc' , anat_preproc_buffer , 'anat_preproc' )
1395+
13871396 aseg_buffer = pe .Node (
13881397 niu .IdentityInterface (fields = ['anat_aseg' ]),
13891398 name = 'aseg_buffer' ,
@@ -1423,7 +1432,7 @@ def init_infant_single_anat_fit_wf(
14231432 msm_buffer = pe .Node (niu .IdentityInterface (fields = ['sphere_reg_msm' ]), name = 'msm_buffer' )
14241433
14251434 workflow .connect ([
1426- (anat_buffer , outputnode , [
1435+ (anat_preproc_buffer , outputnode , [
14271436 ('anat_preproc' , 'anat_preproc' ),
14281437 ]),
14291438 (refined_buffer , outputnode , [
@@ -1724,6 +1733,15 @@ def init_infant_single_anat_fit_wf(
17241733 anat2std_buffer .inputs .in1 = [xfm ['forward' ] for xfm in found_xfms .values ()]
17251734 std2anat_buffer .inputs .in1 = [xfm ['reverse' ] for xfm in found_xfms .values ()]
17261735
1736+ if config .workflow .norm_csf :
1737+ csf_norm_wf = init_csf_norm_wf ()
1738+
1739+ workflow .connect ([
1740+ (anat_buffer , csf_norm_wf , [('anat_preproc' , 'inputnode.anat_preproc' )]),
1741+ (seg_buffer , csf_norm_wf , [('anat_tpms' , 'inputnode.anat_tpms' )]),
1742+ (csf_norm_wf , anat_preproc_buffer , [('outputnode.anat_preproc' , 'anat_preproc' )]),
1743+ ]) # fmt:skip
1744+
17271745 if templates :
17281746 LOGGER .info (f'ANAT Stage 4: Preparing normalization workflow for { templates } ' )
17291747 register_template_wf = init_register_template_wf (
@@ -1739,7 +1757,9 @@ def init_infant_single_anat_fit_wf(
17391757
17401758 workflow .connect ([
17411759 (inputnode , register_template_wf , [('roi' , 'inputnode.lesion_mask' )]),
1742- (anat_buffer , register_template_wf , [('anat_preproc' , 'inputnode.moving_image' )]),
1760+ (anat_preproc_buffer , register_template_wf , [
1761+ ('anat_preproc' , 'inputnode.moving_image' ),
1762+ ]),
17431763 (refined_buffer , register_template_wf , [('anat_mask' , 'inputnode.moving_mask' )]),
17441764 (sourcefile_buffer , ds_template_registration_wf , [
17451765 ('anat_source_files' , 'inputnode.source_files' )
@@ -1921,7 +1941,7 @@ def init_infant_single_anat_fit_wf(
19211941 (seg_buffer , refinement_wf , [
19221942 ('ants_segs' , 'inputnode.ants_segs' ),
19231943 ]),
1924- (anat_buffer , applyrefined , [('anat_preproc' , 'in_file' )]),
1944+ (anat_preproc_buffer , applyrefined , [('anat_preproc' , 'in_file' )]),
19251945 (refinement_wf , applyrefined , [('outputnode.out_brainmask' , 'in_mask' )]),
19261946 (refinement_wf , refined_buffer , [('outputnode.out_brainmask' , 'anat_mask' )]),
19271947 (applyrefined , refined_buffer , [('out_file' , 'anat_brain' )]),
0 commit comments