Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sdcflows/data/sd_syn.json
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"shrink_factors": [ [ 1, 1 ], [ 1 ] ],
"sigma_units": [ "vox", "vox" ],
"smoothing_sigmas": [ [ 2, 0 ], [ 0 ] ],
"transform_parameters": [ [ 0.8, 6.0, 3.0 ], [ 0.8, 2.0, 1.0 ] ],
"transform_parameters": [ [ 0.8, 6.0, 10.0 ], [ 0.8, 2.0, 0.5 ] ],
"transforms": [ "SyN", "SyN" ],
"use_histogram_matching": [ true, true ],
"winsorize_lower_quantile": 0.001,
Expand Down
2 changes: 1 addition & 1 deletion sdcflows/data/sd_syn_sloppy.json
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"shrink_factors": [ [ 1, 1 ], [ 1 ] ],
"sigma_units": [ "vox", "vox" ],
"smoothing_sigmas": [ [ 2, 0 ], [ 0 ] ],
"transform_parameters": [ [ 0.8, 6.0, 3.0 ], [ 0.8, 2.0, 1.0 ] ],
"transform_parameters": [ [ 0.8, 6.0, 10.0 ], [ 0.8, 2.0, 0.5 ] ],
"transforms": [ "SyN", "SyN" ],
"use_histogram_matching": [ true, true ],
"verbose": true,
Expand Down
54 changes: 46 additions & 8 deletions sdcflows/workflows/fit/syn.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
"""
Estimating the susceptibility distortions without fieldmaps.
"""
import json

from nipype.pipeline import engine as pe
from nipype.interfaces import utility as niu
Expand Down Expand Up @@ -214,9 +215,14 @@ def init_syn_sdc_wf(
find_zooms = pe.Node(niu.Function(function=_adjust_zooms), name="find_zooms")
zooms_epi = pe.Node(RegridToZooms(), name="zooms_epi")

syn_config = data.load(f"sd_syn{'_sloppy' * sloppy}.json")

vox_params = pe.Node(niu.Function(function=_mm2vox), name="vox_params")
vox_params.inputs.registration_config = json.loads(syn_config.read_text())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we need to read the config outside the Registration node, we might want to extract the number of levels from it (e.g., len(config_dict['transforms']) instead of hard-code the magic number (of 2). It will be a bit trickier for handling masks, but a bit more robust, IMHO.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But probably for another PR btw


# SyN Registration Core
syn = pe.Node(
Registration(from_file=data.load(f"sd_syn{'_sloppy' * sloppy}.json")),
Registration(from_file=syn_config),
name="syn",
n_procs=omp_nthreads,
)
Expand Down Expand Up @@ -277,6 +283,7 @@ def init_syn_sdc_wf(
("sd_prior", "in2")]),
(inputnode, anat_dilmsk, [("anat_mask", "in_file")]),
(inputnode, warp_dir, [("anat_ref", "fixed_image")]),
(inputnode, vox_params, [("anat_ref", "fixed_image")]),
(inputnode, anat_merge, [("anat_ref", "in1")]),
(inputnode, lap_anat, [("anat_ref", "op1")]),
(inputnode, find_zooms, [("anat_ref", "in_anat"),
Expand All @@ -295,11 +302,15 @@ def init_syn_sdc_wf(
(anat_dilmsk, amask2epi, [("out_file", "input_image")]),
(amask2epi, epi_umask, [("output_image", "in2")]),
(readout_time, warp_dir, [("pe_direction", "pe_dir")]),
(readout_time, vox_params, [("pe_direction", "pe_dir")]),
(clip_epi, warp_dir, [("out_file", "moving_image")]),
(clip_epi, vox_params, [("out_file", "moving_image")]),
(warp_dir, syn, [("out", "restrict_deformation")]),
(anat_merge, syn, [("out", "fixed_image")]),
(fixed_masks, syn, [("out", "fixed_image_masks")]),
(epi_merge, syn, [("out", "moving_image")]),
(moving_masks, syn, [("out", "moving_image_masks")]),
(vox_params, syn, [("out", "transform_parameters")]),
(syn, extract_field, [(("forward_transforms", _pop), "transform")]),
(clip_epi, extract_field, [("out_file", "epi")]),
(readout_time, extract_field, [("readout_time", "ro_time"),
Expand Down Expand Up @@ -583,25 +594,52 @@ def _remove_first_mask(in_file):
return workflow


def _warp_dir(fixed_image, pe_dir, nlevels=3):
def _warp_dir(moving_image, fixed_image, pe_dir, nlevels=2):
"""Extract the ``restrict_deformation`` argument from metadata."""
import numpy as np
import nibabel as nb

img = nb.load(fixed_image)
moving = nb.load(moving_image)
fixed = nb.load(fixed_image)

if np.any(nb.affines.obliquity(img.affine) > 0.05):
if np.any(nb.affines.obliquity(fixed.affine) > 0.05):
from nipype import logging

logging.getLogger("nipype.interface").warn(
"Running fieldmap-less registration on an oblique dataset"
)

vs = nb.affines.voxel_sizes(img.affine)
order = np.around(np.abs(img.affine[:3, :3] / vs))
retval = order @ [1 if pe_dir[0] == ax else 0.1 for ax in "ijk"]
moving_axcodes = nb.aff2axcodes(moving.affine, ["RR", "AA", "SS"])
moving_pe_axis = moving_axcodes["ijk".index(pe_dir[0])]

fixed_axcodes = nb.aff2axcodes(fixed.affine, ["RR", "AA", "SS"])

deformation = [0.1, 0.1, 0.1]
deformation[fixed_axcodes.index(moving_pe_axis)] = 1.0

return nlevels * [deformation]


def _mm2vox(moving_image, fixed_image, pe_dir, registration_config):
import nibabel as nb

params = registration_config['transform_parameters']

moving = nb.load(moving_image)
# Use duplicate axcodes to ignore sign
moving_axcodes = nb.aff2axcodes(moving.affine, ["RR", "AA", "SS"])
moving_pe_axis = moving_axcodes["ijk".index(pe_dir[0])]

fixed = nb.load(fixed_image)
fixed_axcodes = nb.aff2axcodes(fixed.affine, ["RR", "AA", "SS"])

zooms = nb.affines.voxel_sizes(fixed.affine)
pe_res = zooms[fixed_axcodes.index(moving_pe_axis)]

return nlevels * [retval.tolist()]
return [
[*level_params[:2], level_params[2] / pe_res]
for level_params in params
]


def _merge_meta(epi_ref, meta_list):
Expand Down
100 changes: 99 additions & 1 deletion sdcflows/workflows/fit/tests/test_syn.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,21 @@
"""Test fieldmap-less SDC-SyN."""

import json

import numpy as np
import nibabel as nb
import pytest
from nipype.pipeline import engine as pe

from ..syn import init_syn_sdc_wf, init_syn_preprocessing_wf, _adjust_zooms, _set_dtype
from .... import data
from ..syn import (
init_syn_sdc_wf,
init_syn_preprocessing_wf,
_adjust_zooms,
_set_dtype,
_mm2vox,
_warp_dir,
)


@pytest.mark.veryslow
Expand Down Expand Up @@ -254,3 +265,90 @@ def test_ensure_dtype(in_dtype, out_dtype, tmpdir):
assert out_file == f"{in_dtype}.nii.gz"
else:
assert out_file == f"{in_dtype}_{out_dtype}.nii.gz"


def axcodes2aff(axcodes):
"""Return an affine matrix from axis codes."""
return nb.orientations.inv_ornt_aff(
nb.orientations.ornt_transform(
nb.orientations.axcodes2ornt("RAS"),
nb.orientations.axcodes2ornt(axcodes),
),
(10, 10, 10),
)


@pytest.mark.parametrize(
("fixed_ornt", "moving_ornt", "ijk", "index"),
[
("RAS", "RAS", "i", 0),
("RAS", "RAS", "j", 1),
("RAS", "RAS", "k", 2),
("RAS", "PSL", "i", 1),
("RAS", "PSL", "j", 2),
("RAS", "PSL", "k", 0),
("PSL", "RAS", "i", 2),
("PSL", "RAS", "j", 0),
("PSL", "RAS", "k", 1),
],
)
def test_mm2vox(tmp_path, fixed_ornt, moving_ornt, ijk, index):
fixed_path = tmp_path / "fixed.nii.gz"
moving_path = tmp_path / "moving.nii.gz"

# Use separate zooms to make identifying the conversion easier
fixed_aff = np.diag((2, 3, 4, 1))
nb.save(
nb.Nifti1Image(np.zeros((10, 10, 10)), axcodes2aff(fixed_ornt) @ fixed_aff),
fixed_path,
)
nb.save(
nb.Nifti1Image(np.zeros((10, 10, 10)), axcodes2aff(moving_ornt)),
moving_path,
)

config = json.loads(data.load.readable("sd_syn.json").read_text())

params = config["transform_parameters"]
mm_values = np.array([level[2] for level in params])

vox_params = _mm2vox(str(moving_path), str(fixed_path), ijk, config)
vox_values = [level[2] for level in vox_params]
assert [
mm_level[:2] == vox_level[:2] for mm_level, vox_level in zip(params, vox_params)
]
assert np.array_equal(vox_values, mm_values / [2, 3, 4][index])


@pytest.mark.parametrize(
("fixed_ornt", "moving_ornt", "ijk", "index"),
[
("RAS", "RAS", "i", 0),
("RAS", "RAS", "j", 1),
("RAS", "RAS", "k", 2),
("RAS", "PSL", "i", 1),
("RAS", "PSL", "j", 2),
("RAS", "PSL", "k", 0),
("PSL", "RAS", "i", 2),
("PSL", "RAS", "j", 0),
("PSL", "RAS", "k", 1),
],
)
def test_warp_dir(tmp_path, fixed_ornt, moving_ornt, ijk, index):
fixed_path = tmp_path / "fixed.nii.gz"
moving_path = tmp_path / "moving.nii.gz"

nb.save(
nb.Nifti1Image(np.zeros((10, 10, 10)), axcodes2aff(fixed_ornt)),
fixed_path,
)
nb.save(
nb.Nifti1Image(np.zeros((10, 10, 10)), axcodes2aff(moving_ornt)),
moving_path,
)

for nlevels in range(1, 3):
deformations = _warp_dir(str(moving_path), str(fixed_path), ijk, nlevels)
assert len(deformations) == nlevels
for val in deformations:
assert val == [1.0 if i == index else 0.1 for i in range(3)]
Loading