diff --git a/sdcflows/workflows/base.py b/sdcflows/workflows/base.py index ae6cc2229f..3c3d97ebe8 100644 --- a/sdcflows/workflows/base.py +++ b/sdcflows/workflows/base.py @@ -39,6 +39,7 @@ def init_fmap_preproc_wf( sloppy=False, debug=False, name="fmap_preproc_wf", + **kwargs, ): """ Create and combine estimator workflows. @@ -110,6 +111,7 @@ def init_fmap_preproc_wf( omp_nthreads=omp_nthreads, debug=debug, sloppy=sloppy, + **kwargs, ) source_files = [ str(f.path) for f in estimator.sources if f.suffix not in ("T1w", "T2w") diff --git a/sdcflows/workflows/fit/fieldmap.py b/sdcflows/workflows/fit/fieldmap.py index 2050bbf4ed..93307f00d3 100644 --- a/sdcflows/workflows/fit/fieldmap.py +++ b/sdcflows/workflows/fit/fieldmap.py @@ -29,7 +29,14 @@ INPUT_FIELDS = ("magnitude", "fieldmap") -def init_fmap_wf(omp_nthreads=1, sloppy=False, debug=False, mode="phasediff", name="fmap_wf"): +def init_fmap_wf( + omp_nthreads=1, + sloppy=False, + debug=False, + mode="phasediff", + name="fmap_wf", + **kwargs, +): """ Estimate the fieldmap based on a field-mapping MRI acquisition. diff --git a/sdcflows/workflows/fit/pepolar.py b/sdcflows/workflows/fit/pepolar.py index 1d6ed518a8..b0f6bd59fc 100644 --- a/sdcflows/workflows/fit/pepolar.py +++ b/sdcflows/workflows/fit/pepolar.py @@ -40,6 +40,7 @@ def init_topup_wf( sloppy=False, debug=False, name="pepolar_estimate_wf", + **kwargs, ): """ Create the PEPOLAR field estimation workflow based on FSL's ``topup``. diff --git a/sdcflows/workflows/fit/syn.py b/sdcflows/workflows/fit/syn.py index db00592fbb..1d4b7a55c4 100644 --- a/sdcflows/workflows/fit/syn.py +++ b/sdcflows/workflows/fit/syn.py @@ -23,8 +23,10 @@ """ Estimating the susceptibility distortions without fieldmaps. """ + from nipype.pipeline import engine as pe from nipype.interfaces import utility as niu +from nipype.interfaces.base import Undefined from niworkflows.engine.workflows import LiterateWorkflow as Workflow from ... import data @@ -46,6 +48,8 @@ def init_syn_sdc_wf( debug=False, name="syn_sdc_wf", omp_nthreads=1, + sd_prior=True, + **kwargs, ): """ Build the *fieldmap-less* susceptibility-distortion estimation workflow. @@ -117,7 +121,6 @@ def init_syn_sdc_wf( FixHeaderRegistration as Registration, ) from niworkflows.interfaces.nibabel import ( - Binarize, IntensityClip, RegridToZooms, ) @@ -171,7 +174,6 @@ def init_syn_sdc_wf( name="warp_dir", ) warp_dir.inputs.nlevels = 2 - atlas_msk = pe.Node(Binarize(thresh_low=atlas_threshold), name="atlas_msk") anat_dilmsk = pe.Node(BinaryDilation(), name="anat_dilmsk") amask2epi = pe.Node( ApplyTransforms(interpolation="MultiLabel", transforms="identity"), @@ -208,7 +210,7 @@ def init_syn_sdc_wf( ) fixed_masks = pe.Node( - niu.Merge(3), + niu.Merge(2), name="fixed_masks", mem_gb=DEFAULT_MEMORY_MIN_GB, run_without_submitting=True, @@ -220,9 +222,7 @@ def init_syn_sdc_wf( # SyN Registration Core syn = pe.Node( - Registration( - from_file=data.load(f"sd_syn{'_sloppy' * sloppy}.json") - ), + Registration(from_file=data.load(f"sd_syn{'_sloppy' * sloppy}.json")), name="syn", n_procs=omp_nthreads, ) @@ -233,9 +233,7 @@ def init_syn_sdc_wf( syn.inputs.args = "--write-interval-volumes 2" # Extract the corresponding fieldmap in Hz - extract_field = pe.Node( - DisplacementsField2Fieldmap(), name="extract_field" - ) + extract_field = pe.Node(DisplacementsField2Fieldmap(), name="extract_field") unwarp = pe.Node(ApplyCoeffsField(jacobian=False), name="unwarp") @@ -267,7 +265,6 @@ def init_syn_sdc_wf( workflow.connect([ (inputnode, readout_time, [(("epi_ref", _pop), "in_file"), (("epi_ref", _pull), "metadata")]), - (inputnode, atlas_msk, [("sd_prior", "in_file")]), (inputnode, clip_epi, [(("epi_ref", _pop), "in_file")]), (inputnode, unwarp, [(("epi_ref", _pop), "in_data")]), (inputnode, amask2epi, [("epi_mask", "reference_image")]), @@ -293,7 +290,6 @@ def init_syn_sdc_wf( (lap_epi, lap_epi_norm, [("output_image", "in_file")]), (lap_epi_norm, epi_merge, [("out", "in2")]), (find_zooms, zooms_epi, [("out", "zooms")]), - (atlas_msk, fixed_masks, [("out_mask", "in3")]), (anat_dilmsk, amask2epi, [("out_file", "input_image")]), (amask2epi, epi_umask, [("output_image", "in2")]), (readout_time, warp_dir, [("pe_direction", "pe_dir")]), @@ -331,6 +327,7 @@ def init_syn_preprocessing_wf( omp_nthreads=1, auto_bold_nss=False, t1w_inversion=False, + sd_prior=True, ): """ Prepare EPI references and co-registration to anatomical for SyN. @@ -356,6 +353,8 @@ def init_syn_preprocessing_wf( of BOLD images. t1w_inversion : :obj:`bool` Run T1w intensity inversion so that it looks more like a T2 contrast. + sd_prior : :obj:`bool` + Enable using a prior map to regularize the SyN cost function. Inputs ------ @@ -426,26 +425,6 @@ def init_syn_preprocessing_wf( deob_epi = pe.Node(Deoblique(), name="deob_epi") - # Mapping & preparing prior knowledge - # Concatenate transform files: - # 1) MNI -> anat; 2) ATLAS -> MNI - transform_list = pe.Node( - niu.Merge(3), - name="transform_list", - mem_gb=DEFAULT_MEMORY_MIN_GB, - run_without_submitting=True, - ) - transform_list.inputs.in3 = data.load("fmap_atlas_2_MNI152NLin2009cAsym_affine.mat") - prior2epi = pe.Node( - ApplyTransforms( - invert_transform_flags=[True, False, False], - input_image=str(data.load("fmap_atlas.nii.gz")), - ), - name="prior2epi", - n_procs=omp_nthreads, - mem_gb=0.3, - ) - anat2epi = pe.Node( ApplyTransforms(invert_transform_flags=[True]), name="anat2epi", @@ -502,9 +481,44 @@ def _remove_first_mask(in_file): sampling_ref = pe.Node(GenerateSamplingReference(), name="sampling_ref") - # fmt:off + if sd_prior: + # Mapping & preparing prior knowledge + # Concatenate transform files: + # 1) MNI -> anat; 2) ATLAS -> MNI + transform_list = pe.Node( + niu.Merge(3), + name="transform_list", + mem_gb=DEFAULT_MEMORY_MIN_GB, + run_without_submitting=True, + ) + transform_list.inputs.in3 = data.load( + "fmap_atlas_2_MNI152NLin2009cAsym_affine.mat" + ) + prior2epi = pe.Node( + ApplyTransforms( + invert_transform_flags=[True, False, False], + input_image=str(data.load("fmap_atlas.nii.gz")), + ), + name="prior2epi", + n_procs=omp_nthreads, + mem_gb=0.3, + ) + + workflow.connect([ + (inputnode, transform_list, [("std2anat_xfm", "in2")]), + (epi2anat, transform_list, [("forward_transforms", "in1")]), + (transform_list, prior2epi, [("out", "transforms")]), + (sampling_ref, prior2epi, [("out_file", "reference_image")]), + (prior2epi, outputnode, [("output_image", "sd_prior")]), + ]) # fmt:skip + + else: + # no prior to be used + # MG: Future goal is to allow using alternative mappings + # i.e. in the case of infants, where priors change depending on development + outputnode.inputs.sd_prior = Undefined + workflow.connect([ - (inputnode, transform_list, [("std2anat_xfm", "in2")]), (inputnode, epi_reference_wf, [("in_epis", "inputnode.in_files")]), (inputnode, merge_output, [("in_meta", "meta_list")]), (inputnode, anat_dilmsk, [("mask_anat", "in_file")]), @@ -523,9 +537,6 @@ def _remove_first_mask(in_file): (epi_dilmsk, epi2anat, [ (("out_file", _remove_first_mask), "moving_image_masks")]), (deob_epi, sampling_ref, [("out_file", "fixed_image")]), - (epi2anat, transform_list, [("forward_transforms", "in1")]), - (transform_list, prior2epi, [("out", "transforms")]), - (sampling_ref, prior2epi, [("out_file", "reference_image")]), (ref_anat, anat2epi, [("output_image", "input_image")]), (epi2anat, anat2epi, [("forward_transforms", "transforms")]), (sampling_ref, anat2epi, [("out_file", "reference_image")]), @@ -536,9 +547,7 @@ def _remove_first_mask(in_file): (mask_dtype, outputnode, [("out", "anat_mask")]), (merge_output, outputnode, [("out", "epi_ref")]), (epi_brain, outputnode, [("out_mask", "epi_mask")]), - (prior2epi, outputnode, [("output_image", "sd_prior")]), - ]) - # fmt:on + ]) # fmt:skip if debug: from niworkflows.interfaces.nibabel import RegridToZooms diff --git a/sdcflows/workflows/fit/tests/test_syn.py b/sdcflows/workflows/fit/tests/test_syn.py index 262890d03a..0db8bcbd7b 100644 --- a/sdcflows/workflows/fit/tests/test_syn.py +++ b/sdcflows/workflows/fit/tests/test_syn.py @@ -30,7 +30,8 @@ @pytest.mark.veryslow @pytest.mark.slow -def test_syn_wf(tmpdir, datadir, workdir, outdir, sloppy_mode): +@pytest.mark.parametrize("sd_prior", [True, False]) +def test_syn_wf(tmpdir, datadir, workdir, outdir, sloppy_mode, sd_prior): """Build and run an SDC-SyN workflow.""" derivs_path = datadir / "ds000054" / "derivatives" smriprep = derivs_path / "smriprep-0.6" / "sub-100185" / "anat" @@ -42,6 +43,7 @@ def test_syn_wf(tmpdir, datadir, workdir, outdir, sloppy_mode): debug=sloppy_mode, auto_bold_nss=True, t1w_inversion=True, + sd_prior=sd_prior, ) prep_wf.inputs.inputnode.in_epis = [ str( @@ -72,7 +74,12 @@ def test_syn_wf(tmpdir, datadir, workdir, outdir, sloppy_mode): smriprep / "sub-100185_desc-brain_mask.nii.gz" ) - syn_wf = init_syn_sdc_wf(debug=sloppy_mode, sloppy=sloppy_mode, omp_nthreads=4) + syn_wf = init_syn_sdc_wf( + debug=sloppy_mode, + sloppy=sloppy_mode, + omp_nthreads=4, + sd_prior=sd_prior, + ) # fmt: off wf.connect([