Skip to content

Commit ec18eaf

Browse files
authored
Merge pull request #459 from mgxd/fix/preproc-hdr
ENH: Verify derivatives are compatible with anatomical reference
2 parents 1e6afd4 + d216490 commit ec18eaf

File tree

4 files changed

+165
-44
lines changed

4 files changed

+165
-44
lines changed

nibabies/workflows/anatomical/fit.py

Lines changed: 65 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,11 @@
3838
from nibabies.interfaces import DerivativesDataSink
3939
from nibabies.workflows.anatomical.brain_extraction import init_infant_brain_extraction_wf
4040
from nibabies.workflows.anatomical.outputs import init_anat_reports_wf, init_coreg_report_wf
41-
from nibabies.workflows.anatomical.preproc import init_anat_preproc_wf, init_csf_norm_wf
41+
from nibabies.workflows.anatomical.preproc import (
42+
init_anat_preproc_wf,
43+
init_conform_derivative_wf,
44+
init_csf_norm_wf,
45+
)
4246
from nibabies.workflows.anatomical.registration import (
4347
init_concat_registrations_wf,
4448
init_coregistration_wf,
@@ -172,11 +176,11 @@ def init_infant_anat_fit_wf(
172176

173177
# Stage 2 - Anatomicals
174178
t1w_buffer = pe.Node(
175-
niu.IdentityInterface(fields=['t1w_preproc', 't1w_maskt1w_brain']),
179+
niu.IdentityInterface(fields=['t1w_preproc', 't1w_mask', 't1w_brain']),
176180
name='t1w_buffer',
177181
)
178182
t2w_buffer = pe.Node(
179-
niu.IdentityInterface(fields=['t2w_preproc', 't2w_maskt2w_brain', 't2w_probmap']),
183+
niu.IdentityInterface(fields=['t2w_preproc', 't2w_mask', 't2w_brain', 't2w_probmap']),
180184
name='t2w_buffer',
181185
)
182186
anat_buffer = pe.Node(
@@ -323,6 +327,7 @@ def init_infant_anat_fit_wf(
323327

324328
t1w_preproc = precomputed.get('t1w_preproc')
325329
t2w_preproc = precomputed.get('t2w_preproc')
330+
anat_preproc = precomputed.get(f'{anat}_preproc')
326331

327332
# Stage 1: Conform & valid T1w/T2w images
328333
# Note: Since stage 1 & 2 are tightly knit together, it may be more intuitive
@@ -575,21 +580,29 @@ def init_infant_anat_fit_wf(
575580
'A pre-computed T1w brain mask was provided as input and used throughout the '
576581
'workflow.'
577582
)
578-
t1w_buffer.inputs.t1w_mask = t1w_mask
579-
apply_t1w_mask.inputs.in_mask = t1w_mask
580583
workflow.connect(apply_t1w_mask, 'out_file', t1w_buffer, 't1w_brain')
581584

582585
if not t1w_preproc:
586+
# Ensure compatibility with T1w template
587+
conform_t1w_mask_wf = init_conform_derivative_wf(
588+
in_file=t1w_mask, name='conform_t1w_mask_wf'
589+
)
590+
583591
LOGGER.info('ANAT Skipping skull-strip, INU-correction only')
584592
t1w_n4_wf = init_anat_preproc_wf(name='t1w_n4_wf')
585593
workflow.connect([
594+
(t1w_validate, conform_t1w_mask_wf, [('out_file', 'inputnode.ref_file')]),
595+
(conform_t1w_mask_wf, t1w_buffer, [('outputnode.out_file', 't1w_mask')]),
596+
(conform_t1w_mask_wf, apply_t1w_mask, [('outputnode.out_file', 'in_mask')]),
586597
(t1w_validate, t1w_n4_wf, [('out_file', 'inputnode.in_anat')]),
587598
(t1w_n4_wf, t1w_buffer, [('outputnode.anat_preproc', 't1w_preproc')]),
588599
(t1w_n4_wf, apply_t1w_mask, [('outputnode.anat_preproc', 'in_file')]),
589600
]) # fmt:skip
590601
else:
591602
LOGGER.info('ANAT Skipping T1w masking')
592603
workflow.connect(t1w_validate, 'out_file', apply_t1w_mask, 'in_file')
604+
t1w_buffer.inputs.t1w_mask = t1w_mask
605+
apply_t1w_mask.inputs.in_mask = t1w_mask
593606

594607
# T2w masking logic:
595608
#
@@ -701,21 +714,30 @@ def init_infant_anat_fit_wf(
701714
'A pre-computed T2w brain mask was provided as input and used throughout the '
702715
'workflow.'
703716
)
704-
t2w_buffer.inputs.t2w_mask = t2w_mask
705-
apply_t2w_mask.inputs.in_mask = t2w_mask
706717
workflow.connect(apply_t2w_mask, 'out_file', t2w_buffer, 't2w_brain')
707718

708719
if not t2w_preproc:
720+
# Ensure compatibility with T2w template
721+
conform_t2w_mask_wf = init_conform_derivative_wf(
722+
in_file=t2w_mask,
723+
name='conform_t2w_mask_wf',
724+
)
725+
709726
LOGGER.info('ANAT Skipping skull-strip, INU-correction only')
710727
t2w_n4_wf = init_anat_preproc_wf(name='t2w_n4_wf')
711728
workflow.connect([
729+
(t2w_validate, conform_t2w_mask_wf, [('out_file', 'inputnode.ref_file')]),
730+
(conform_t2w_mask_wf, t2w_buffer, [('outputnode.out_file', 't2w_mask')]),
731+
(conform_t2w_mask_wf, apply_t2w_mask, [('outputnode.out_file', 'in_mask')]),
712732
(t2w_validate, t2w_n4_wf, [('out_file', 'inputnode.in_anat')]),
713733
(t2w_n4_wf, t2w_buffer, [('outputnode.anat_preproc', 't2w_preproc')]),
714734
(t2w_n4_wf, apply_t2w_mask, [('outputnode.anat_preproc', 'in_file')]),
715735
]) # fmt:skip
716736
else:
717737
LOGGER.info('ANAT Skipping T2w masking')
718738
workflow.connect(t2w_validate, 'out_file', apply_t2w_mask, 'in_file')
739+
t2w_buffer.inputs.t2w_mask = t2w_mask
740+
apply_t2w_mask.inputs.in_mask = t2w_mask
719741

720742
# Stage 3: Coregistration
721743
t1w2t2w_xfm = precomputed.get('t1w2t2w_xfm')
@@ -819,7 +841,19 @@ def init_infant_anat_fit_wf(
819841

820842
if anat_aseg:
821843
LOGGER.info('ANAT Found precomputed anatomical segmentation')
822-
aseg_buffer.inputs.anat_aseg = anat_aseg
844+
# Ensure compatibility with anatomical template
845+
if not anat_preproc:
846+
conform_aseg_wf = init_conform_derivative_wf(
847+
in_file=anat_aseg,
848+
name='conform_aseg_wf',
849+
)
850+
851+
workflow.connect([
852+
(anat_buffer, conform_aseg_wf, [('anat_preproc', 'inputnode.ref_file')]),
853+
(conform_aseg_wf, aseg_buffer, [('outputnode.out_file', 'anat_aseg')]),
854+
]) # fmt:skip
855+
else:
856+
aseg_buffer.inputs.anat_aseg = anat_aseg
823857

824858
if not (anat_dseg and anat_tpms):
825859
LOGGER.info('ANAT Stage 4: Tissue segmentation')
@@ -1714,27 +1748,47 @@ def init_infant_single_anat_fit_wf(
17141748
else:
17151749
LOGGER.info(f'ANAT Found {reference_anat} brain mask')
17161750
desc += 'A pre-computed brain mask was provided as input and used throughout the workflow.'
1717-
anat_buffer.inputs.anat_mask = anat_mask
1718-
apply_mask.inputs.in_mask = anat_mask
17191751
workflow.connect(apply_mask, 'out_file', anat_buffer, 'anat_brain')
17201752

17211753
if not anat_preproc:
1754+
conform_anat_mask_wf = init_conform_derivative_wf(
1755+
in_file=anat_mask,
1756+
name='conform_anat_mask_wf',
1757+
)
1758+
17221759
LOGGER.info('ANAT Skipping skull-strip, INU-correction only')
17231760
anat_n4_wf = init_anat_preproc_wf(name='anat_n4_wf')
17241761
workflow.connect([
1762+
(anat_validate, conform_anat_mask_wf, [('out_file', 'inputnode.ref_file')]),
1763+
(conform_anat_mask_wf, anat_buffer, [('outputnode.out_file', 'anat_mask')]),
1764+
(conform_anat_mask_wf, apply_mask, [('outputnode.out_file', 'in_mask')]),
17251765
(anat_validate, anat_n4_wf, [('out_file', 'inputnode.in_anat')]),
17261766
(anat_n4_wf, anat_buffer, [('outputnode.anat_preproc', 'anat_preproc')]),
17271767
(anat_n4_wf, apply_mask, [('outputnode.anat_preproc', 'in_file')]),
17281768
]) # fmt:skip
17291769
else:
17301770
LOGGER.info(f'ANAT Skipping {reference_anat} masking')
17311771
workflow.connect(anat_validate, 'out_file', apply_mask, 'in_file')
1772+
anat_buffer.inputs.anat_mask = anat_mask
1773+
apply_mask.inputs.in_mask = anat_mask
17321774

17331775
# Stage 3: Segmentation
17341776
seg_method = 'jlf' if config.execution.segmentation_atlases_dir else 'fast'
17351777
if anat_aseg:
17361778
LOGGER.info('ANAT Found precomputed anatomical segmentation')
1737-
aseg_buffer.inputs.anat_aseg = anat_aseg
1779+
# Ensure compatibility with anatomical template
1780+
if not anat_preproc:
1781+
conform_aseg_wf = init_conform_derivative_wf(
1782+
in_file=anat_aseg,
1783+
name='conform_aseg_wf',
1784+
)
1785+
1786+
workflow.connect([
1787+
(anat_buffer, conform_aseg_wf, [('anat_preproc', 'inputnode.ref_file')]),
1788+
(conform_aseg_wf, aseg_buffer, [('outputnode.out_file', 'anat_aseg')]),
1789+
]) # fmt:skip
1790+
else:
1791+
aseg_buffer.inputs.anat_aseg = anat_aseg
17381792

17391793
if not (anat_dseg and anat_tpms):
17401794
LOGGER.info('ANAT Stage 3: Tissue segmentation')

nibabies/workflows/anatomical/preproc.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,41 @@ def init_csf_norm_wf(name: str = 'csf_norm_wf') -> Workflow:
9898
return workflow
9999

100100

101+
def init_conform_derivative_wf(
102+
*, in_file: str = None, name: str = 'conform_derivative_wf'
103+
) -> pe.Workflow:
104+
"""
105+
Ensure derivatives share the same space as anatomical references.
106+
107+
This workflow is used when a derivative is provided without a reference.
108+
"""
109+
from niworkflows.interfaces.header import MatchHeader
110+
from niworkflows.interfaces.images import Conform, TemplateDimensions
111+
112+
workflow = pe.Workflow(name=name)
113+
inputnode = pe.Node(niu.IdentityInterface(fields=['in_file', 'ref_file']), name='inputnode')
114+
inputnode.inputs.in_file = in_file
115+
outputnode = pe.Node(niu.IdentityInterface(fields=['out_file']), name='outputnode')
116+
117+
ref_dims = pe.Node(TemplateDimensions(), name='ref_dims')
118+
conform = pe.Node(Conform(), name='conform')
119+
# Avoid mismatch tolerance from input
120+
match_header = pe.Node(MatchHeader(), name='match_header')
121+
122+
workflow.connect([
123+
(inputnode, ref_dims, [('ref_file', 'anat_list')]),
124+
(ref_dims, conform, [
125+
('target_zooms', 'target_zooms'),
126+
('target_shape', 'target_shape'),
127+
]),
128+
(inputnode, conform, [('in_file', 'in_file')]),
129+
(conform, match_header, [('out_file', 'in_file')]),
130+
(inputnode, match_header, [('ref_file', 'reference')]),
131+
(match_header, outputnode, [('out_file', 'out_file')]),
132+
]) # fmt:skip
133+
return workflow
134+
135+
101136
def _normalize_roi(in_file, mask_file, threshold=0.2, out_file=None):
102137
"""Normalize low intensity voxels that fall within a given mask."""
103138
import nibabel as nb

nibabies/workflows/anatomical/surfaces.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from niworkflows.interfaces.freesurfer import (
1616
PatchedRobustRegister as RobustRegister,
1717
)
18-
from niworkflows.interfaces.header import MatchHeader
1918
from niworkflows.interfaces.morphology import BinaryDilation
2019
from niworkflows.interfaces.patches import FreeSurferSource
2120
from smriprep.interfaces.freesurfer import MakeMidthickness
@@ -128,9 +127,6 @@ def init_mcribs_surface_recon_wf(
128127
mask_dil = pe.Node(BinaryDilation(radius=3), name='mask_dil')
129128
mask_las = pe.Node(ReorientImage(target_orientation='LAS'), name='mask_las')
130129

131-
# N4 has low tolerance for mismatch between input / mask
132-
match_header = pe.Node(MatchHeader(), name='match_header')
133-
134130
# N4BiasCorrection occurs in MCRIBTissueSegMCRIBS (which is skipped)
135131
# Run it (with mask to rescale intensities) prior injection
136132
n4_mcribs = pe.Node(
@@ -182,9 +178,7 @@ def init_mcribs_surface_recon_wf(
182178
('subjects_dir', 'subjects_dir'),
183179
('subject_id', 'subject_id')]),
184180
(t2w_las, n4_mcribs, [('out_file', 'input_image')]),
185-
(mask_las, match_header, [('out_file', 'in_file')]),
186-
(t2w_las, match_header, [('out_file', 'reference')]),
187-
(match_header, n4_mcribs, [('out_file', 'mask_image')]),
181+
(mask_las, n4_mcribs, [('out_file', 'mask_image')]),
188182
(n4_mcribs, mcribs_recon, [('output_image', 't2w_file')]),
189183
(seg_las, mcribs_recon, [('out_file', 'segmentation_file')]),
190184
(inputnode, mcribs_postrecon, [
Lines changed: 64 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,55 +1,93 @@
1-
import typing as ty
21
from pathlib import Path
32

43
import nibabel as nb
54
import numpy as np
65
import pytest
6+
from nibabel.orientations import aff2axcodes
77

8-
from nibabies.workflows.anatomical.preproc import _normalize_roi, init_csf_norm_wf
8+
from nibabies.workflows.anatomical.preproc import (
9+
_normalize_roi,
10+
init_conform_derivative_wf,
11+
init_csf_norm_wf,
12+
)
913

10-
EXPECTED_CSF_NORM = np.array([[[10, 73], [73, 29]], [[77, 80], [6, 16]]], dtype='uint8')
14+
EXPECTED_CSF_NORM = np.array([[[49, 75], [23, 75]], [[77, 80], [33, 3]]], dtype='uint8')
1115

1216

1317
@pytest.fixture
14-
def csf_norm_data(tmp_path) -> ty.Generator[tuple[Path, list[Path]], None, None]:
15-
np.random.seed(10)
16-
17-
in_file = tmp_path / 'input.nii.gz'
18-
data = np.random.randint(1, 101, size=(2, 2, 2), dtype='uint8')
18+
def anat_file(tmp_path):
19+
data = np.array([[[49, 73], [23, 73]], [[77, 80], [33, 3]]], dtype='uint8')
1920
img = nb.Nifti1Image(data, np.eye(4))
20-
img.to_filename(in_file)
21+
out = tmp_path / 'input.nii.gz'
22+
img.to_filename(out)
23+
return out
24+
2125

22-
masks = []
23-
for tpm in ('gm', 'wm', 'csf'):
26+
def test_csf_norm_wf(tmp_path, anat_file):
27+
tpms = []
28+
for tpm, thresh in (('gm', 25), ('wm', 75), ('csf', 50)):
2429
name = tmp_path / f'{tpm}.nii.gz'
25-
binmask = data > np.random.randint(10, 90)
30+
anat_img = nb.load(anat_file)
31+
anat_data = np.asanyarray(nb.load(anat_file).dataobj)
32+
33+
binmask = anat_data > thresh
2634
masked = (binmask * 1).astype('uint8')
27-
mask = nb.Nifti1Image(masked, img.affine)
35+
mask = nb.Nifti1Image(masked, anat_img.affine)
2836
mask.to_filename(name)
29-
masks.append(name)
30-
31-
yield in_file, masks
37+
tpms.append(name)
3238

33-
in_file.unlink()
34-
for m in masks:
35-
m.unlink()
36-
37-
38-
def test_csf_norm_wf(tmp_path, csf_norm_data):
39-
anat, tpms = csf_norm_data
4039
wf = init_csf_norm_wf()
4140
wf.base_dir = tmp_path
42-
43-
wf.inputs.inputnode.anat_preproc = anat
41+
wf.inputs.inputnode.anat_preproc = anat_file
4442
wf.inputs.inputnode.anat_tpms = tpms
4543

4644
# verify workflow runs
4745
wf.run()
4846

4947
# verify function works as expected
50-
outfile = _normalize_roi(anat, tpms[2])
48+
outfile = _normalize_roi(anat_file, tpms[2])
5149
assert np.array_equal(
5250
np.asanyarray(nb.load(outfile).dataobj),
5351
EXPECTED_CSF_NORM,
5452
)
5553
Path(outfile).unlink()
54+
55+
56+
@pytest.mark.parametrize(
57+
('affine_mismatch', 'ornt_mismatch'),
58+
[
59+
(False, False),
60+
(True, False),
61+
(False, True),
62+
(True, True),
63+
],
64+
)
65+
def test_conform_derivative_wf(tmp_path, anat_file, affine_mismatch, ornt_mismatch):
66+
deriv = tmp_path / 'mask.nii.gz'
67+
ref_img = nb.load(anat_file)
68+
aff = ref_img.affine.copy()
69+
if affine_mismatch:
70+
# Alter affine slightly
71+
aff[:3, :3] += 0.01
72+
assert not np.array_equal(aff, ref_img.affine)
73+
74+
img = ref_img.__class__(ref_img.dataobj, affine=aff)
75+
if ornt_mismatch:
76+
from niworkflows.interfaces.nibabel import reorient_image
77+
78+
img = reorient_image(img, target_ornt='LPI')
79+
assert aff2axcodes(img.affine) != aff2axcodes(ref_img.affine)
80+
81+
img.to_filename(deriv)
82+
wf = init_conform_derivative_wf(in_file=deriv)
83+
wf.base_dir = tmp_path
84+
wf.inputs.inputnode.ref_file = anat_file
85+
86+
wf.run()
87+
88+
output = list((tmp_path / 'conform_derivative_wf' / 'match_header').glob('*.nii.gz'))
89+
assert output
90+
out_file = output[0]
91+
out_img = nb.load(out_file)
92+
assert np.array_equal(out_img.affine, ref_img.affine)
93+
assert aff2axcodes(out_img.affine) == aff2axcodes(ref_img.affine)

0 commit comments

Comments
 (0)