Skip to content

Commit b9efa6f

Browse files
authored
Merge pull request #419 from mgxd/enh/spatial-norm
FEAT: Option to normalize CSF prior to template registration
2 parents 9476035 + 8981735 commit b9efa6f

File tree

7 files changed

+161
-26
lines changed

7 files changed

+161
-26
lines changed

nibabies/cli/parser.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -753,6 +753,11 @@ def _str_none(val):
753753
default=16,
754754
help='Frame to start head motion estimation on BOLD.',
755755
)
756+
g_baby.add_argument(
757+
'--norm-csf',
758+
action='store_true',
759+
help='Replace low intensity voxels in CSF mask with average',
760+
)
756761
return parser
757762

758763

nibabies/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -578,6 +578,8 @@ class workflow(_Config):
578578
"""Run FreeSurfer ``recon-all`` with the ``-logitudinal`` flag."""
579579
medial_surface_nan = None
580580
"""Fill medial surface with :abbr:`NaNs (not-a-number)` when sampling."""
581+
norm_csf = False
582+
"""Replace low intensity voxels in CSF mask with average."""
581583
project_goodvoxels = False
582584
"""Exclude voxels with locally high coefficient of variation from sampling."""
583585
regressors_all_comps = None

nibabies/workflows/anatomical/fit.py

Lines changed: 45 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from nibabies import config
3737
from nibabies.workflows.anatomical.brain_extraction import init_infant_brain_extraction_wf
3838
from nibabies.workflows.anatomical.outputs import init_anat_reports_wf
39-
from nibabies.workflows.anatomical.preproc import init_anat_preproc_wf
39+
from nibabies.workflows.anatomical.preproc import init_anat_preproc_wf, init_csf_norm_wf
4040
from nibabies.workflows.anatomical.registration import init_coregistration_wf
4141
from nibabies.workflows.anatomical.segmentation import init_segmentation_wf
4242
from nibabies.workflows.anatomical.surfaces import init_mcribs_dhcp_wf
@@ -184,6 +184,14 @@ def init_infant_anat_fit_wf(
184184
name='anat_buffer',
185185
)
186186

187+
# Additional buffer if CSF normalization is used
188+
anat_preproc_buffer = pe.Node(
189+
niu.IdentityInterface(fields=['anat_preproc']),
190+
name='anat_preproc_buffer',
191+
)
192+
if not config.workflow.norm_csf:
193+
workflow.connect(anat_buffer, 'anat_preproc', anat_preproc_buffer, 'anat_preproc')
194+
187195
if reference_anat == 'T1w':
188196
LOGGER.info('ANAT: Using T1w as the reference anatomical')
189197
workflow.connect([
@@ -248,7 +256,7 @@ def init_infant_anat_fit_wf(
248256
msm_buffer = pe.Node(niu.IdentityInterface(fields=['sphere_reg_msm']), name='msm_buffer')
249257

250258
workflow.connect([
251-
(anat_buffer, outputnode, [
259+
(anat_preproc_buffer, outputnode, [
252260
('anat_preproc', 'anat_preproc'),
253261
]),
254262
(refined_buffer, outputnode, [
@@ -637,24 +645,6 @@ def init_infant_anat_fit_wf(
637645
(binarize_t2w, t2w_buffer, [('out_file', 't2w_mask')]),
638646
]) # fmt:skip
639647
else:
640-
# Check whether we can convert a previously computed T2w mask
641-
# or need to run the atlas based brain extraction
642-
643-
# if t1w_mask:
644-
# LOGGER.info('ANAT T1w mask will be transformed into T2w space')
645-
# transform_t1w_mask = pe.Node(
646-
# ApplyTransforms(interpolation='MultiLabel'),
647-
# name='transform_t1w_mask',
648-
# )
649-
650-
# workflow.connect([
651-
# (t1w_buffer, transform_t1w_mask, [('t1w_mask', 'input_image')]),
652-
# (coreg_buffer, transform_t1w_mask, [('t1w2t2w_xfm', 'transforms')]),
653-
# (transform_t1w_mask, apply_t2w_mask, [('output_image', 'in_mask')]),
654-
# (t2w_buffer, apply_t1w_mask, [('t2w_preproc', 'in_file')]),
655-
# # TODO: Unsure about this connection^
656-
# ]) # fmt:skip
657-
# else:
658648
LOGGER.info('ANAT Atlas-based brain mask will be calculated on the T2w')
659649
brain_extraction_wf = init_infant_brain_extraction_wf(
660650
omp_nthreads=omp_nthreads,
@@ -898,6 +888,15 @@ def init_infant_anat_fit_wf(
898888
anat2std_buffer.inputs.in1 = [xfm['forward'] for xfm in found_xfms.values()]
899889
std2anat_buffer.inputs.in1 = [xfm['reverse'] for xfm in found_xfms.values()]
900890

891+
if config.workflow.norm_csf:
892+
csf_norm_wf = init_csf_norm_wf()
893+
894+
workflow.connect([
895+
(anat_buffer, csf_norm_wf, [('anat_preproc', 'inputnode.anat_preproc')]),
896+
(seg_buffer, csf_norm_wf, [('anat_tpms', 'inputnode.anat_tpms')]),
897+
(csf_norm_wf, anat_preproc_buffer, [('outputnode.anat_preproc', 'anat_preproc')]),
898+
]) # fmt:skip
899+
901900
if templates:
902901
LOGGER.info(f'ANAT Stage 5: Preparing normalization workflow for {templates}')
903902
register_template_wf = init_register_template_wf(
@@ -913,7 +912,9 @@ def init_infant_anat_fit_wf(
913912

914913
workflow.connect([
915914
(inputnode, register_template_wf, [('roi', 'inputnode.lesion_mask')]),
916-
(anat_buffer, register_template_wf, [('anat_preproc', 'inputnode.moving_image')]),
915+
(anat_preproc_buffer, register_template_wf, [
916+
('anat_preproc', 'inputnode.moving_image'),
917+
]),
917918
(refined_buffer, register_template_wf, [('anat_mask', 'inputnode.moving_mask')]),
918919
(sourcefile_buffer, ds_template_registration_wf, [
919920
('anat_source_files', 'inputnode.source_files')
@@ -1106,7 +1107,7 @@ def init_infant_anat_fit_wf(
11061107
(seg_buffer, refinement_wf, [
11071108
('ants_segs', 'inputnode.ants_segs'), # TODO: Verify this is the same as dseg
11081109
]),
1109-
(anat_buffer, applyrefined, [('anat_preproc', 'in_file')]),
1110+
(anat_preproc_buffer, applyrefined, [('anat_preproc', 'in_file')]),
11101111
(refinement_wf, applyrefined, [('outputnode.out_brainmask', 'in_mask')]),
11111112
(refinement_wf, refined_buffer, [('outputnode.out_brainmask', 'anat_mask')]),
11121113
(applyrefined, refined_buffer, [('out_file', 'anat_brain')]),
@@ -1384,6 +1385,14 @@ def init_infant_single_anat_fit_wf(
13841385
name='anat_buffer',
13851386
)
13861387

1388+
# Additional buffer if CSF normalization is used
1389+
anat_preproc_buffer = pe.Node(
1390+
niu.IdentityInterface(fields=['anat_preproc']),
1391+
name='anat_preproc_buffer',
1392+
)
1393+
if not config.workflow.norm_csf:
1394+
workflow.connect(anat_buffer, 'anat_preproc', anat_preproc_buffer, 'anat_preproc')
1395+
13871396
aseg_buffer = pe.Node(
13881397
niu.IdentityInterface(fields=['anat_aseg']),
13891398
name='aseg_buffer',
@@ -1423,7 +1432,7 @@ def init_infant_single_anat_fit_wf(
14231432
msm_buffer = pe.Node(niu.IdentityInterface(fields=['sphere_reg_msm']), name='msm_buffer')
14241433

14251434
workflow.connect([
1426-
(anat_buffer, outputnode, [
1435+
(anat_preproc_buffer, outputnode, [
14271436
('anat_preproc', 'anat_preproc'),
14281437
]),
14291438
(refined_buffer, outputnode, [
@@ -1724,6 +1733,15 @@ def init_infant_single_anat_fit_wf(
17241733
anat2std_buffer.inputs.in1 = [xfm['forward'] for xfm in found_xfms.values()]
17251734
std2anat_buffer.inputs.in1 = [xfm['reverse'] for xfm in found_xfms.values()]
17261735

1736+
if config.workflow.norm_csf:
1737+
csf_norm_wf = init_csf_norm_wf()
1738+
1739+
workflow.connect([
1740+
(anat_buffer, csf_norm_wf, [('anat_preproc', 'inputnode.anat_preproc')]),
1741+
(seg_buffer, csf_norm_wf, [('anat_tpms', 'inputnode.anat_tpms')]),
1742+
(csf_norm_wf, anat_preproc_buffer, [('outputnode.anat_preproc', 'anat_preproc')]),
1743+
]) # fmt:skip
1744+
17271745
if templates:
17281746
LOGGER.info(f'ANAT Stage 4: Preparing normalization workflow for {templates}')
17291747
register_template_wf = init_register_template_wf(
@@ -1739,7 +1757,9 @@ def init_infant_single_anat_fit_wf(
17391757

17401758
workflow.connect([
17411759
(inputnode, register_template_wf, [('roi', 'inputnode.lesion_mask')]),
1742-
(anat_buffer, register_template_wf, [('anat_preproc', 'inputnode.moving_image')]),
1760+
(anat_preproc_buffer, register_template_wf, [
1761+
('anat_preproc', 'inputnode.moving_image'),
1762+
]),
17431763
(refined_buffer, register_template_wf, [('anat_mask', 'inputnode.moving_mask')]),
17441764
(sourcefile_buffer, ds_template_registration_wf, [
17451765
('anat_source_files', 'inputnode.source_files')
@@ -1921,7 +1941,7 @@ def init_infant_single_anat_fit_wf(
19211941
(seg_buffer, refinement_wf, [
19221942
('ants_segs', 'inputnode.ants_segs'),
19231943
]),
1924-
(anat_buffer, applyrefined, [('anat_preproc', 'in_file')]),
1944+
(anat_preproc_buffer, applyrefined, [('anat_preproc', 'in_file')]),
19251945
(refinement_wf, applyrefined, [('outputnode.out_brainmask', 'in_mask')]),
19261946
(refinement_wf, refined_buffer, [('outputnode.out_brainmask', 'anat_mask')]),
19271947
(applyrefined, refined_buffer, [('out_file', 'anat_brain')]),

nibabies/workflows/anatomical/preproc.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,3 +66,56 @@ def init_anat_preproc_wf(
6666
(final_clip, outputnode, [('out_file', 'anat_preproc')]),
6767
]) # fmt:skip
6868
return wf
69+
70+
71+
def init_csf_norm_wf(name: str = 'csf_norm_wf') -> LiterateWorkflow:
72+
"""Replace low intensity voxels within the CSF mask with the median value."""
73+
74+
workflow = LiterateWorkflow(name=name)
75+
workflow.__desc__ = (
76+
'The CSF mask was used to normalize the anatomical template by the median of voxels '
77+
'within the mask.'
78+
)
79+
inputnode = pe.Node(
80+
niu.IdentityInterface(fields=['anat_preproc', 'anat_tpms']),
81+
name='inputnode',
82+
)
83+
outputnode = pe.Node(niu.IdentityInterface(fields=['anat_preproc']), name='outputnode')
84+
85+
# select CSF from BIDS-ordered list (GM, WM, CSF)
86+
select_csf = pe.Node(niu.Select(index=2), name='select_csf')
87+
norm_csf = pe.Node(niu.Function(function=_normalize_roi), name='norm_csf')
88+
89+
workflow.connect([
90+
(inputnode, select_csf, [('anat_tpms', 'inlist')]),
91+
(select_csf, norm_csf, [('out', 'mask_file')]),
92+
(inputnode, norm_csf, [('anat_preproc', 'in_file')]),
93+
(norm_csf, outputnode, [('out', 'anat_preproc')]),
94+
]) # fmt:skip
95+
96+
return workflow
97+
98+
99+
def _normalize_roi(in_file, mask_file, threshold=0.2, out_file=None):
100+
"""Normalize low intensity voxels that fall within a given mask."""
101+
import nibabel as nb
102+
import numpy as np
103+
104+
img = nb.load(in_file)
105+
img_data = np.asanyarray(img.dataobj)
106+
mask_img = nb.load(mask_file)
107+
# binary mask
108+
bin_mask = np.asanyarray(mask_img.dataobj) > threshold
109+
mask_data = bin_mask * img_data
110+
masked_data = mask_data[mask_data > 0]
111+
112+
median = np.median(masked_data).astype(masked_data.dtype)
113+
normed_data = np.maximum(img_data, bin_mask * median)
114+
115+
oimg = img.__class__(normed_data, img.affine, img.header)
116+
if not out_file:
117+
from nipype.utils.filemanip import fname_presuffix
118+
119+
out_file = fname_presuffix(in_file, suffix='normed')
120+
oimg.to_filename(out_file)
121+
return out_file

nibabies/workflows/anatomical/tests/__init__.py

Whitespace-only changes.
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import typing as ty
2+
from pathlib import Path
3+
4+
import nibabel as nb
5+
import numpy as np
6+
import pytest
7+
8+
from nibabies.workflows.anatomical.preproc import _normalize_roi, init_csf_norm_wf
9+
10+
EXPECTED_CSF_NORM = np.array([[[10, 73], [73, 29]], [[77, 80], [6, 16]]], dtype='uint8')
11+
12+
13+
@pytest.fixture
14+
def csf_norm_data(tmp_path) -> ty.Generator[tuple[Path, list[Path]], None, None]:
15+
np.random.seed(10)
16+
17+
in_file = tmp_path / 'input.nii.gz'
18+
data = np.random.randint(1, 101, size=(2, 2, 2), dtype='uint8')
19+
img = nb.Nifti1Image(data, np.eye(4))
20+
img.to_filename(in_file)
21+
22+
masks = []
23+
for tpm in ('gm', 'wm', 'csf'):
24+
name = tmp_path / f'{tpm}.nii.gz'
25+
binmask = data > np.random.randint(10, 90)
26+
masked = (binmask * 1).astype('uint8')
27+
mask = nb.Nifti1Image(masked, img.affine)
28+
mask.to_filename(name)
29+
masks.append(name)
30+
31+
yield in_file, masks
32+
33+
in_file.unlink()
34+
for m in masks:
35+
m.unlink()
36+
37+
38+
def test_csf_norm_wf(tmp_path, csf_norm_data):
39+
anat, tpms = csf_norm_data
40+
wf = init_csf_norm_wf()
41+
wf.base_dir = tmp_path
42+
43+
wf.inputs.inputnode.anat_preproc = anat
44+
wf.inputs.inputnode.anat_tpms = tpms
45+
46+
# verify workflow runs
47+
wf.run()
48+
49+
# verify function works as expected
50+
outfile = _normalize_roi(anat, tpms[2])
51+
assert np.array_equal(
52+
np.asanyarray(nb.load(outfile).dataobj),
53+
EXPECTED_CSF_NORM,
54+
)
55+
Path(outfile).unlink()

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ dependencies = [
3434
"requests",
3535
"sdcflows >= 2.10.0",
3636
# "smriprep >= 0.16.1",
37-
"smriprep @ git+https://github.com/nipreps/smriprep.git@master",
37+
"smriprep @ git+https://github.com/nipreps/smriprep.git@dev-nibabies",
3838
"tedana >= 23.0.2",
3939
"templateflow >= 24.2.0",
4040
"toml",

0 commit comments

Comments
 (0)