diff --git a/nibabies/cli/parser.py b/nibabies/cli/parser.py index 9e0d0165..e7b1b470 100644 --- a/nibabies/cli/parser.py +++ b/nibabies/cli/parser.py @@ -753,6 +753,11 @@ def _str_none(val): default=16, help='Frame to start head motion estimation on BOLD.', ) + g_baby.add_argument( + '--norm-csf', + action='store_true', + help='Replace low intensity voxels in CSF mask with average', + ) return parser diff --git a/nibabies/config.py b/nibabies/config.py index 70f309cb..f3bb698f 100644 --- a/nibabies/config.py +++ b/nibabies/config.py @@ -578,6 +578,8 @@ class workflow(_Config): """Run FreeSurfer ``recon-all`` with the ``-logitudinal`` flag.""" medial_surface_nan = None """Fill medial surface with :abbr:`NaNs (not-a-number)` when sampling.""" + norm_csf = False + """Replace low intensity voxels in CSF mask with average.""" project_goodvoxels = False """Exclude voxels with locally high coefficient of variation from sampling.""" regressors_all_comps = None diff --git a/nibabies/workflows/anatomical/fit.py b/nibabies/workflows/anatomical/fit.py index 9ebf091b..d0fe552d 100644 --- a/nibabies/workflows/anatomical/fit.py +++ b/nibabies/workflows/anatomical/fit.py @@ -36,7 +36,7 @@ from nibabies import config from nibabies.workflows.anatomical.brain_extraction import init_infant_brain_extraction_wf from nibabies.workflows.anatomical.outputs import init_anat_reports_wf -from nibabies.workflows.anatomical.preproc import init_anat_preproc_wf +from nibabies.workflows.anatomical.preproc import init_anat_preproc_wf, init_csf_norm_wf from nibabies.workflows.anatomical.registration import init_coregistration_wf from nibabies.workflows.anatomical.segmentation import init_segmentation_wf from nibabies.workflows.anatomical.surfaces import init_mcribs_dhcp_wf @@ -184,6 +184,14 @@ def init_infant_anat_fit_wf( name='anat_buffer', ) + # Additional buffer if CSF normalization is used + anat_preproc_buffer = pe.Node( + niu.IdentityInterface(fields=['anat_preproc']), + name='anat_preproc_buffer', + ) + if not config.workflow.norm_csf: + workflow.connect(anat_buffer, 'anat_preproc', anat_preproc_buffer, 'anat_preproc') + if reference_anat == 'T1w': LOGGER.info('ANAT: Using T1w as the reference anatomical') workflow.connect([ @@ -248,7 +256,7 @@ def init_infant_anat_fit_wf( msm_buffer = pe.Node(niu.IdentityInterface(fields=['sphere_reg_msm']), name='msm_buffer') workflow.connect([ - (anat_buffer, outputnode, [ + (anat_preproc_buffer, outputnode, [ ('anat_preproc', 'anat_preproc'), ]), (refined_buffer, outputnode, [ @@ -637,24 +645,6 @@ def init_infant_anat_fit_wf( (binarize_t2w, t2w_buffer, [('out_file', 't2w_mask')]), ]) # fmt:skip else: - # Check whether we can convert a previously computed T2w mask - # or need to run the atlas based brain extraction - - # if t1w_mask: - # LOGGER.info('ANAT T1w mask will be transformed into T2w space') - # transform_t1w_mask = pe.Node( - # ApplyTransforms(interpolation='MultiLabel'), - # name='transform_t1w_mask', - # ) - - # workflow.connect([ - # (t1w_buffer, transform_t1w_mask, [('t1w_mask', 'input_image')]), - # (coreg_buffer, transform_t1w_mask, [('t1w2t2w_xfm', 'transforms')]), - # (transform_t1w_mask, apply_t2w_mask, [('output_image', 'in_mask')]), - # (t2w_buffer, apply_t1w_mask, [('t2w_preproc', 'in_file')]), - # # TODO: Unsure about this connection^ - # ]) # fmt:skip - # else: LOGGER.info('ANAT Atlas-based brain mask will be calculated on the T2w') brain_extraction_wf = init_infant_brain_extraction_wf( omp_nthreads=omp_nthreads, @@ -898,6 +888,15 @@ def init_infant_anat_fit_wf( anat2std_buffer.inputs.in1 = [xfm['forward'] for xfm in found_xfms.values()] std2anat_buffer.inputs.in1 = [xfm['reverse'] for xfm in found_xfms.values()] + if config.workflow.norm_csf: + csf_norm_wf = init_csf_norm_wf() + + workflow.connect([ + (anat_buffer, csf_norm_wf, [('anat_preproc', 'inputnode.anat_preproc')]), + (seg_buffer, csf_norm_wf, [('anat_tpms', 'inputnode.anat_tpms')]), + (csf_norm_wf, anat_preproc_buffer, [('outputnode.anat_preproc', 'anat_preproc')]), + ]) # fmt:skip + if templates: LOGGER.info(f'ANAT Stage 5: Preparing normalization workflow for {templates}') register_template_wf = init_register_template_wf( @@ -913,7 +912,9 @@ def init_infant_anat_fit_wf( workflow.connect([ (inputnode, register_template_wf, [('roi', 'inputnode.lesion_mask')]), - (anat_buffer, register_template_wf, [('anat_preproc', 'inputnode.moving_image')]), + (anat_preproc_buffer, register_template_wf, [ + ('anat_preproc', 'inputnode.moving_image'), + ]), (refined_buffer, register_template_wf, [('anat_mask', 'inputnode.moving_mask')]), (sourcefile_buffer, ds_template_registration_wf, [ ('anat_source_files', 'inputnode.source_files') @@ -1106,7 +1107,7 @@ def init_infant_anat_fit_wf( (seg_buffer, refinement_wf, [ ('ants_segs', 'inputnode.ants_segs'), # TODO: Verify this is the same as dseg ]), - (anat_buffer, applyrefined, [('anat_preproc', 'in_file')]), + (anat_preproc_buffer, applyrefined, [('anat_preproc', 'in_file')]), (refinement_wf, applyrefined, [('outputnode.out_brainmask', 'in_mask')]), (refinement_wf, refined_buffer, [('outputnode.out_brainmask', 'anat_mask')]), (applyrefined, refined_buffer, [('out_file', 'anat_brain')]), @@ -1384,6 +1385,14 @@ def init_infant_single_anat_fit_wf( name='anat_buffer', ) + # Additional buffer if CSF normalization is used + anat_preproc_buffer = pe.Node( + niu.IdentityInterface(fields=['anat_preproc']), + name='anat_preproc_buffer', + ) + if not config.workflow.norm_csf: + workflow.connect(anat_buffer, 'anat_preproc', anat_preproc_buffer, 'anat_preproc') + aseg_buffer = pe.Node( niu.IdentityInterface(fields=['anat_aseg']), name='aseg_buffer', @@ -1423,7 +1432,7 @@ def init_infant_single_anat_fit_wf( msm_buffer = pe.Node(niu.IdentityInterface(fields=['sphere_reg_msm']), name='msm_buffer') workflow.connect([ - (anat_buffer, outputnode, [ + (anat_preproc_buffer, outputnode, [ ('anat_preproc', 'anat_preproc'), ]), (refined_buffer, outputnode, [ @@ -1724,6 +1733,15 @@ def init_infant_single_anat_fit_wf( anat2std_buffer.inputs.in1 = [xfm['forward'] for xfm in found_xfms.values()] std2anat_buffer.inputs.in1 = [xfm['reverse'] for xfm in found_xfms.values()] + if config.workflow.norm_csf: + csf_norm_wf = init_csf_norm_wf() + + workflow.connect([ + (anat_buffer, csf_norm_wf, [('anat_preproc', 'inputnode.anat_preproc')]), + (seg_buffer, csf_norm_wf, [('anat_tpms', 'inputnode.anat_tpms')]), + (csf_norm_wf, anat_preproc_buffer, [('outputnode.anat_preproc', 'anat_preproc')]), + ]) # fmt:skip + if templates: LOGGER.info(f'ANAT Stage 4: Preparing normalization workflow for {templates}') register_template_wf = init_register_template_wf( @@ -1739,7 +1757,9 @@ def init_infant_single_anat_fit_wf( workflow.connect([ (inputnode, register_template_wf, [('roi', 'inputnode.lesion_mask')]), - (anat_buffer, register_template_wf, [('anat_preproc', 'inputnode.moving_image')]), + (anat_preproc_buffer, register_template_wf, [ + ('anat_preproc', 'inputnode.moving_image'), + ]), (refined_buffer, register_template_wf, [('anat_mask', 'inputnode.moving_mask')]), (sourcefile_buffer, ds_template_registration_wf, [ ('anat_source_files', 'inputnode.source_files') @@ -1921,7 +1941,7 @@ def init_infant_single_anat_fit_wf( (seg_buffer, refinement_wf, [ ('ants_segs', 'inputnode.ants_segs'), ]), - (anat_buffer, applyrefined, [('anat_preproc', 'in_file')]), + (anat_preproc_buffer, applyrefined, [('anat_preproc', 'in_file')]), (refinement_wf, applyrefined, [('outputnode.out_brainmask', 'in_mask')]), (refinement_wf, refined_buffer, [('outputnode.out_brainmask', 'anat_mask')]), (applyrefined, refined_buffer, [('out_file', 'anat_brain')]), diff --git a/nibabies/workflows/anatomical/preproc.py b/nibabies/workflows/anatomical/preproc.py index 7cbc08aa..12fbe814 100644 --- a/nibabies/workflows/anatomical/preproc.py +++ b/nibabies/workflows/anatomical/preproc.py @@ -66,3 +66,56 @@ def init_anat_preproc_wf( (final_clip, outputnode, [('out_file', 'anat_preproc')]), ]) # fmt:skip return wf + + +def init_csf_norm_wf(name: str = 'csf_norm_wf') -> LiterateWorkflow: + """Replace low intensity voxels within the CSF mask with the median value.""" + + workflow = LiterateWorkflow(name=name) + workflow.__desc__ = ( + 'The CSF mask was used to normalize the anatomical template by the median of voxels ' + 'within the mask.' + ) + inputnode = pe.Node( + niu.IdentityInterface(fields=['anat_preproc', 'anat_tpms']), + name='inputnode', + ) + outputnode = pe.Node(niu.IdentityInterface(fields=['anat_preproc']), name='outputnode') + + # select CSF from BIDS-ordered list (GM, WM, CSF) + select_csf = pe.Node(niu.Select(index=2), name='select_csf') + norm_csf = pe.Node(niu.Function(function=_normalize_roi), name='norm_csf') + + workflow.connect([ + (inputnode, select_csf, [('anat_tpms', 'inlist')]), + (select_csf, norm_csf, [('out', 'mask_file')]), + (inputnode, norm_csf, [('anat_preproc', 'in_file')]), + (norm_csf, outputnode, [('out', 'anat_preproc')]), + ]) # fmt:skip + + return workflow + + +def _normalize_roi(in_file, mask_file, threshold=0.2, out_file=None): + """Normalize low intensity voxels that fall within a given mask.""" + import nibabel as nb + import numpy as np + + img = nb.load(in_file) + img_data = np.asanyarray(img.dataobj) + mask_img = nb.load(mask_file) + # binary mask + bin_mask = np.asanyarray(mask_img.dataobj) > threshold + mask_data = bin_mask * img_data + masked_data = mask_data[mask_data > 0] + + median = np.median(masked_data).astype(masked_data.dtype) + normed_data = np.maximum(img_data, bin_mask * median) + + oimg = img.__class__(normed_data, img.affine, img.header) + if not out_file: + from nipype.utils.filemanip import fname_presuffix + + out_file = fname_presuffix(in_file, suffix='normed') + oimg.to_filename(out_file) + return out_file diff --git a/nibabies/workflows/anatomical/tests/__init__.py b/nibabies/workflows/anatomical/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/nibabies/workflows/anatomical/tests/test_preproc.py b/nibabies/workflows/anatomical/tests/test_preproc.py new file mode 100644 index 00000000..32087016 --- /dev/null +++ b/nibabies/workflows/anatomical/tests/test_preproc.py @@ -0,0 +1,55 @@ +import typing as ty +from pathlib import Path + +import nibabel as nb +import numpy as np +import pytest + +from nibabies.workflows.anatomical.preproc import _normalize_roi, init_csf_norm_wf + +EXPECTED_CSF_NORM = np.array([[[10, 73], [73, 29]], [[77, 80], [6, 16]]], dtype='uint8') + + +@pytest.fixture +def csf_norm_data(tmp_path) -> ty.Generator[tuple[Path, list[Path]], None, None]: + np.random.seed(10) + + in_file = tmp_path / 'input.nii.gz' + data = np.random.randint(1, 101, size=(2, 2, 2), dtype='uint8') + img = nb.Nifti1Image(data, np.eye(4)) + img.to_filename(in_file) + + masks = [] + for tpm in ('gm', 'wm', 'csf'): + name = tmp_path / f'{tpm}.nii.gz' + binmask = data > np.random.randint(10, 90) + masked = (binmask * 1).astype('uint8') + mask = nb.Nifti1Image(masked, img.affine) + mask.to_filename(name) + masks.append(name) + + yield in_file, masks + + in_file.unlink() + for m in masks: + m.unlink() + + +def test_csf_norm_wf(tmp_path, csf_norm_data): + anat, tpms = csf_norm_data + wf = init_csf_norm_wf() + wf.base_dir = tmp_path + + wf.inputs.inputnode.anat_preproc = anat + wf.inputs.inputnode.anat_tpms = tpms + + # verify workflow runs + wf.run() + + # verify function works as expected + outfile = _normalize_roi(anat, tpms[2]) + assert np.array_equal( + np.asanyarray(nb.load(outfile).dataobj), + EXPECTED_CSF_NORM, + ) + Path(outfile).unlink() diff --git a/pyproject.toml b/pyproject.toml index 650ec51b..5316ab6c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ dependencies = [ "requests", "sdcflows >= 2.10.0", # "smriprep >= 0.16.1", - "smriprep @ git+https://github.com/nipreps/smriprep.git@master", + "smriprep @ git+https://github.com/nipreps/smriprep.git@dev-nibabies", "tedana >= 23.0.2", "templateflow >= 24.2.0", "toml",