diff --git a/sdcflows/data/sd_syn.json b/sdcflows/data/sd_syn.json index d317223db7..46aa07e7a2 100644 --- a/sdcflows/data/sd_syn.json +++ b/sdcflows/data/sd_syn.json @@ -14,7 +14,7 @@ "shrink_factors": [ [ 1, 1 ], [ 1 ] ], "sigma_units": [ "vox", "vox" ], "smoothing_sigmas": [ [ 2, 0 ], [ 0 ] ], - "transform_parameters": [ [ 0.8, 6.0, 3.0 ], [ 0.8, 2.0, 1.0 ] ], + "transform_parameters": [ [ 0.8, 6.0, 10.0 ], [ 0.8, 2.0, 0.5 ] ], "transforms": [ "SyN", "SyN" ], "use_histogram_matching": [ true, true ], "winsorize_lower_quantile": 0.001, diff --git a/sdcflows/data/sd_syn_sloppy.json b/sdcflows/data/sd_syn_sloppy.json index a2aab7c042..460d76226d 100644 --- a/sdcflows/data/sd_syn_sloppy.json +++ b/sdcflows/data/sd_syn_sloppy.json @@ -14,7 +14,7 @@ "shrink_factors": [ [ 1, 1 ], [ 1 ] ], "sigma_units": [ "vox", "vox" ], "smoothing_sigmas": [ [ 2, 0 ], [ 0 ] ], - "transform_parameters": [ [ 0.8, 6.0, 3.0 ], [ 0.8, 2.0, 1.0 ] ], + "transform_parameters": [ [ 0.8, 6.0, 10.0 ], [ 0.8, 2.0, 0.5 ] ], "transforms": [ "SyN", "SyN" ], "use_histogram_matching": [ true, true ], "verbose": true, diff --git a/sdcflows/workflows/fit/syn.py b/sdcflows/workflows/fit/syn.py index debbaf0cb5..1ee021d495 100644 --- a/sdcflows/workflows/fit/syn.py +++ b/sdcflows/workflows/fit/syn.py @@ -23,6 +23,7 @@ """ Estimating the susceptibility distortions without fieldmaps. """ +import json from nipype.pipeline import engine as pe from nipype.interfaces import utility as niu @@ -214,9 +215,14 @@ def init_syn_sdc_wf( find_zooms = pe.Node(niu.Function(function=_adjust_zooms), name="find_zooms") zooms_epi = pe.Node(RegridToZooms(), name="zooms_epi") + syn_config = data.load(f"sd_syn{'_sloppy' * sloppy}.json") + + vox_params = pe.Node(niu.Function(function=_mm2vox), name="vox_params") + vox_params.inputs.registration_config = json.loads(syn_config.read_text()) + # SyN Registration Core syn = pe.Node( - Registration(from_file=data.load(f"sd_syn{'_sloppy' * sloppy}.json")), + Registration(from_file=syn_config), name="syn", n_procs=omp_nthreads, ) @@ -277,6 +283,7 @@ def init_syn_sdc_wf( ("sd_prior", "in2")]), (inputnode, anat_dilmsk, [("anat_mask", "in_file")]), (inputnode, warp_dir, [("anat_ref", "fixed_image")]), + (inputnode, vox_params, [("anat_ref", "fixed_image")]), (inputnode, anat_merge, [("anat_ref", "in1")]), (inputnode, lap_anat, [("anat_ref", "op1")]), (inputnode, find_zooms, [("anat_ref", "in_anat"), @@ -295,11 +302,15 @@ def init_syn_sdc_wf( (anat_dilmsk, amask2epi, [("out_file", "input_image")]), (amask2epi, epi_umask, [("output_image", "in2")]), (readout_time, warp_dir, [("pe_direction", "pe_dir")]), + (readout_time, vox_params, [("pe_direction", "pe_dir")]), + (clip_epi, warp_dir, [("out_file", "moving_image")]), + (clip_epi, vox_params, [("out_file", "moving_image")]), (warp_dir, syn, [("out", "restrict_deformation")]), (anat_merge, syn, [("out", "fixed_image")]), (fixed_masks, syn, [("out", "fixed_image_masks")]), (epi_merge, syn, [("out", "moving_image")]), (moving_masks, syn, [("out", "moving_image_masks")]), + (vox_params, syn, [("out", "transform_parameters")]), (syn, extract_field, [(("forward_transforms", _pop), "transform")]), (clip_epi, extract_field, [("out_file", "epi")]), (readout_time, extract_field, [("readout_time", "ro_time"), @@ -583,25 +594,52 @@ def _remove_first_mask(in_file): return workflow -def _warp_dir(fixed_image, pe_dir, nlevels=3): +def _warp_dir(moving_image, fixed_image, pe_dir, nlevels=2): """Extract the ``restrict_deformation`` argument from metadata.""" import numpy as np import nibabel as nb - img = nb.load(fixed_image) + moving = nb.load(moving_image) + fixed = nb.load(fixed_image) - if np.any(nb.affines.obliquity(img.affine) > 0.05): + if np.any(nb.affines.obliquity(fixed.affine) > 0.05): from nipype import logging logging.getLogger("nipype.interface").warn( "Running fieldmap-less registration on an oblique dataset" ) - vs = nb.affines.voxel_sizes(img.affine) - order = np.around(np.abs(img.affine[:3, :3] / vs)) - retval = order @ [1 if pe_dir[0] == ax else 0.1 for ax in "ijk"] + moving_axcodes = nb.aff2axcodes(moving.affine, ["RR", "AA", "SS"]) + moving_pe_axis = moving_axcodes["ijk".index(pe_dir[0])] + + fixed_axcodes = nb.aff2axcodes(fixed.affine, ["RR", "AA", "SS"]) + + deformation = [0.1, 0.1, 0.1] + deformation[fixed_axcodes.index(moving_pe_axis)] = 1.0 + + return nlevels * [deformation] + + +def _mm2vox(moving_image, fixed_image, pe_dir, registration_config): + import nibabel as nb + + params = registration_config['transform_parameters'] + + moving = nb.load(moving_image) + # Use duplicate axcodes to ignore sign + moving_axcodes = nb.aff2axcodes(moving.affine, ["RR", "AA", "SS"]) + moving_pe_axis = moving_axcodes["ijk".index(pe_dir[0])] + + fixed = nb.load(fixed_image) + fixed_axcodes = nb.aff2axcodes(fixed.affine, ["RR", "AA", "SS"]) + + zooms = nb.affines.voxel_sizes(fixed.affine) + pe_res = zooms[fixed_axcodes.index(moving_pe_axis)] - return nlevels * [retval.tolist()] + return [ + [*level_params[:2], level_params[2] / pe_res] + for level_params in params + ] def _merge_meta(epi_ref, meta_list): diff --git a/sdcflows/workflows/fit/tests/test_syn.py b/sdcflows/workflows/fit/tests/test_syn.py index 78dd3e339a..44fc6e93a8 100644 --- a/sdcflows/workflows/fit/tests/test_syn.py +++ b/sdcflows/workflows/fit/tests/test_syn.py @@ -23,10 +23,21 @@ """Test fieldmap-less SDC-SyN.""" import json + +import numpy as np +import nibabel as nb import pytest from nipype.pipeline import engine as pe -from ..syn import init_syn_sdc_wf, init_syn_preprocessing_wf, _adjust_zooms, _set_dtype +from .... import data +from ..syn import ( + init_syn_sdc_wf, + init_syn_preprocessing_wf, + _adjust_zooms, + _set_dtype, + _mm2vox, + _warp_dir, +) @pytest.mark.veryslow @@ -254,3 +265,90 @@ def test_ensure_dtype(in_dtype, out_dtype, tmpdir): assert out_file == f"{in_dtype}.nii.gz" else: assert out_file == f"{in_dtype}_{out_dtype}.nii.gz" + + +def axcodes2aff(axcodes): + """Return an affine matrix from axis codes.""" + return nb.orientations.inv_ornt_aff( + nb.orientations.ornt_transform( + nb.orientations.axcodes2ornt("RAS"), + nb.orientations.axcodes2ornt(axcodes), + ), + (10, 10, 10), + ) + + +@pytest.mark.parametrize( + ("fixed_ornt", "moving_ornt", "ijk", "index"), + [ + ("RAS", "RAS", "i", 0), + ("RAS", "RAS", "j", 1), + ("RAS", "RAS", "k", 2), + ("RAS", "PSL", "i", 1), + ("RAS", "PSL", "j", 2), + ("RAS", "PSL", "k", 0), + ("PSL", "RAS", "i", 2), + ("PSL", "RAS", "j", 0), + ("PSL", "RAS", "k", 1), + ], +) +def test_mm2vox(tmp_path, fixed_ornt, moving_ornt, ijk, index): + fixed_path = tmp_path / "fixed.nii.gz" + moving_path = tmp_path / "moving.nii.gz" + + # Use separate zooms to make identifying the conversion easier + fixed_aff = np.diag((2, 3, 4, 1)) + nb.save( + nb.Nifti1Image(np.zeros((10, 10, 10)), axcodes2aff(fixed_ornt) @ fixed_aff), + fixed_path, + ) + nb.save( + nb.Nifti1Image(np.zeros((10, 10, 10)), axcodes2aff(moving_ornt)), + moving_path, + ) + + config = json.loads(data.load.readable("sd_syn.json").read_text()) + + params = config["transform_parameters"] + mm_values = np.array([level[2] for level in params]) + + vox_params = _mm2vox(str(moving_path), str(fixed_path), ijk, config) + vox_values = [level[2] for level in vox_params] + assert [ + mm_level[:2] == vox_level[:2] for mm_level, vox_level in zip(params, vox_params) + ] + assert np.array_equal(vox_values, mm_values / [2, 3, 4][index]) + + +@pytest.mark.parametrize( + ("fixed_ornt", "moving_ornt", "ijk", "index"), + [ + ("RAS", "RAS", "i", 0), + ("RAS", "RAS", "j", 1), + ("RAS", "RAS", "k", 2), + ("RAS", "PSL", "i", 1), + ("RAS", "PSL", "j", 2), + ("RAS", "PSL", "k", 0), + ("PSL", "RAS", "i", 2), + ("PSL", "RAS", "j", 0), + ("PSL", "RAS", "k", 1), + ], +) +def test_warp_dir(tmp_path, fixed_ornt, moving_ornt, ijk, index): + fixed_path = tmp_path / "fixed.nii.gz" + moving_path = tmp_path / "moving.nii.gz" + + nb.save( + nb.Nifti1Image(np.zeros((10, 10, 10)), axcodes2aff(fixed_ornt)), + fixed_path, + ) + nb.save( + nb.Nifti1Image(np.zeros((10, 10, 10)), axcodes2aff(moving_ornt)), + moving_path, + ) + + for nlevels in range(1, 3): + deformations = _warp_dir(str(moving_path), str(fixed_path), ijk, nlevels) + assert len(deformations) == nlevels + for val in deformations: + assert val == [1.0 if i == index else 0.1 for i in range(3)]