Skip to content

Commit 3d4da90

Browse files
committed
FIX: Add buffer for CSF normalization
1 parent 8548ad3 commit 3d4da90

File tree

1 file changed

+42
-10
lines changed
  • nibabies/workflows/anatomical

1 file changed

+42
-10
lines changed

nibabies/workflows/anatomical/fit.py

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from nibabies import config
3737
from nibabies.workflows.anatomical.brain_extraction import init_infant_brain_extraction_wf
3838
from nibabies.workflows.anatomical.outputs import init_anat_reports_wf
39-
from nibabies.workflows.anatomical.preproc import init_anat_preproc_wf
39+
from nibabies.workflows.anatomical.preproc import init_anat_preproc_wf, init_csf_norm_wf
4040
from nibabies.workflows.anatomical.registration import init_coregistration_wf
4141
from nibabies.workflows.anatomical.segmentation import init_segmentation_wf
4242
from nibabies.workflows.anatomical.surfaces import init_mcribs_dhcp_wf
@@ -184,11 +184,13 @@ def init_infant_anat_fit_wf(
184184
name='anat_buffer',
185185
)
186186

187-
# Additional CSF normalization, if necessary
188-
anat_norm_buffer = pe.Node(
187+
# Additional buffer if CSF normalization is used
188+
anat_preproc_buffer = pe.Node(
189189
niu.IdentityInterface(fields=['anat_preproc']),
190-
name='anat_norm_buffer',
190+
name='anat_preproc_buffer',
191191
)
192+
if not config.workflow.norm_csf:
193+
workflow.connect(anat_buffer, 'anat_preproc', anat_preproc_buffer, 'anat_preproc')
192194

193195
if reference_anat == 'T1w':
194196
LOGGER.info('ANAT: Using T1w as the reference anatomical')
@@ -254,7 +256,7 @@ def init_infant_anat_fit_wf(
254256
msm_buffer = pe.Node(niu.IdentityInterface(fields=['sphere_reg_msm']), name='msm_buffer')
255257

256258
workflow.connect([
257-
(anat_buffer, outputnode, [
259+
(anat_preproc_buffer, outputnode, [
258260
('anat_preproc', 'anat_preproc'),
259261
]),
260262
(refined_buffer, outputnode, [
@@ -886,6 +888,15 @@ def init_infant_anat_fit_wf(
886888
anat2std_buffer.inputs.in1 = [xfm['forward'] for xfm in found_xfms.values()]
887889
std2anat_buffer.inputs.in1 = [xfm['reverse'] for xfm in found_xfms.values()]
888890

891+
if config.workflow.norm_csf:
892+
csf_norm_wf = init_csf_norm_wf()
893+
894+
workflow.connect([
895+
(anat_buffer, csf_norm_wf, [('anat_preproc', 'inputnode.anat_preproc')]),
896+
(seg_buffer, csf_norm_wf, [('anat_tpms', 'inputnode.anat_tpms')]),
897+
(csf_norm_wf, anat_preproc_buffer, [('outputnode.anat_preproc', 'anat_preproc')]),
898+
]) # fmt:skip
899+
889900
if templates:
890901
LOGGER.info(f'ANAT Stage 5: Preparing normalization workflow for {templates}')
891902
register_template_wf = init_register_template_wf(
@@ -901,7 +912,9 @@ def init_infant_anat_fit_wf(
901912

902913
workflow.connect([
903914
(inputnode, register_template_wf, [('roi', 'inputnode.lesion_mask')]),
904-
(anat_buffer, register_template_wf, [('anat_preproc', 'inputnode.moving_image')]),
915+
(anat_preproc_buffer, register_template_wf, [
916+
('anat_preproc', 'inputnode.moving_image'),
917+
]),
905918
(refined_buffer, register_template_wf, [('anat_mask', 'inputnode.moving_mask')]),
906919
(sourcefile_buffer, ds_template_registration_wf, [
907920
('anat_source_files', 'inputnode.source_files')
@@ -1094,7 +1107,7 @@ def init_infant_anat_fit_wf(
10941107
(seg_buffer, refinement_wf, [
10951108
('ants_segs', 'inputnode.ants_segs'), # TODO: Verify this is the same as dseg
10961109
]),
1097-
(anat_buffer, applyrefined, [('anat_preproc', 'in_file')]),
1110+
(anat_preproc_buffer, applyrefined, [('anat_preproc', 'in_file')]),
10981111
(refinement_wf, applyrefined, [('outputnode.out_brainmask', 'in_mask')]),
10991112
(refinement_wf, refined_buffer, [('outputnode.out_brainmask', 'anat_mask')]),
11001113
(applyrefined, refined_buffer, [('out_file', 'anat_brain')]),
@@ -1372,6 +1385,14 @@ def init_infant_single_anat_fit_wf(
13721385
name='anat_buffer',
13731386
)
13741387

1388+
# Additional buffer if CSF normalization is used
1389+
anat_preproc_buffer = pe.Node(
1390+
niu.IdentityInterface(fields=['anat_preproc']),
1391+
name='anat_preproc_buffer',
1392+
)
1393+
if not config.workflow.norm_csf:
1394+
workflow.connect(anat_buffer, 'anat_preproc', anat_preproc_buffer, 'anat_preproc')
1395+
13751396
aseg_buffer = pe.Node(
13761397
niu.IdentityInterface(fields=['anat_aseg']),
13771398
name='aseg_buffer',
@@ -1411,7 +1432,7 @@ def init_infant_single_anat_fit_wf(
14111432
msm_buffer = pe.Node(niu.IdentityInterface(fields=['sphere_reg_msm']), name='msm_buffer')
14121433

14131434
workflow.connect([
1414-
(anat_buffer, outputnode, [
1435+
(anat_preproc_buffer, outputnode, [
14151436
('anat_preproc', 'anat_preproc'),
14161437
]),
14171438
(refined_buffer, outputnode, [
@@ -1712,6 +1733,15 @@ def init_infant_single_anat_fit_wf(
17121733
anat2std_buffer.inputs.in1 = [xfm['forward'] for xfm in found_xfms.values()]
17131734
std2anat_buffer.inputs.in1 = [xfm['reverse'] for xfm in found_xfms.values()]
17141735

1736+
if config.workflow.norm_csf:
1737+
csf_norm_wf = init_csf_norm_wf()
1738+
1739+
workflow.connect([
1740+
(anat_buffer, csf_norm_wf, [('anat_preproc', 'inputnode.anat_preproc')]),
1741+
(seg_buffer, csf_norm_wf, [('anat_tpms', 'inputnode.anat_tpms')]),
1742+
(csf_norm_wf, anat_preproc_buffer, [('outputnode.anat_preproc', 'anat_preproc')]),
1743+
]) # fmt:skip
1744+
17151745
if templates:
17161746
LOGGER.info(f'ANAT Stage 4: Preparing normalization workflow for {templates}')
17171747
register_template_wf = init_register_template_wf(
@@ -1727,7 +1757,9 @@ def init_infant_single_anat_fit_wf(
17271757

17281758
workflow.connect([
17291759
(inputnode, register_template_wf, [('roi', 'inputnode.lesion_mask')]),
1730-
(anat_buffer, register_template_wf, [('anat_preproc', 'inputnode.moving_image')]),
1760+
(anat_preproc_buffer, register_template_wf, [
1761+
('anat_preproc', 'inputnode.moving_image'),
1762+
]),
17311763
(refined_buffer, register_template_wf, [('anat_mask', 'inputnode.moving_mask')]),
17321764
(sourcefile_buffer, ds_template_registration_wf, [
17331765
('anat_source_files', 'inputnode.source_files')
@@ -1909,7 +1941,7 @@ def init_infant_single_anat_fit_wf(
19091941
(seg_buffer, refinement_wf, [
19101942
('ants_segs', 'inputnode.ants_segs'),
19111943
]),
1912-
(anat_buffer, applyrefined, [('anat_preproc', 'in_file')]),
1944+
(anat_preproc_buffer, applyrefined, [('anat_preproc', 'in_file')]),
19131945
(refinement_wf, applyrefined, [('outputnode.out_brainmask', 'in_mask')]),
19141946
(refinement_wf, refined_buffer, [('outputnode.out_brainmask', 'anat_mask')]),
19151947
(applyrefined, refined_buffer, [('out_file', 'anat_brain')]),

0 commit comments

Comments
 (0)