Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sdcflows/data/sd_syn.json
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"shrink_factors": [ [ 1, 1 ], [ 1 ] ],
"sigma_units": [ "vox", "vox" ],
"smoothing_sigmas": [ [ 2, 0 ], [ 0 ] ],
"transform_parameters": [ [ 0.8, 6.0, 3.0 ], [ 0.8, 2.0, 1.0 ] ],
"transform_parameters": [ [ 0.8, 6.0, 10.0 ], [ 0.8, 2.0, 0.5 ] ],
"transforms": [ "SyN", "SyN" ],
"use_histogram_matching": [ true, true ],
"winsorize_lower_quantile": 0.001,
Expand Down
2 changes: 1 addition & 1 deletion sdcflows/data/sd_syn_sloppy.json
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"shrink_factors": [ [ 1, 1 ], [ 1 ] ],
"sigma_units": [ "vox", "vox" ],
"smoothing_sigmas": [ [ 2, 0 ], [ 0 ] ],
"transform_parameters": [ [ 0.8, 6.0, 3.0 ], [ 0.8, 2.0, 1.0 ] ],
"transform_parameters": [ [ 0.8, 6.0, 10.0 ], [ 0.8, 2.0, 0.5 ] ],
"transforms": [ "SyN", "SyN" ],
"use_histogram_matching": [ true, true ],
"verbose": true,
Expand Down
54 changes: 46 additions & 8 deletions sdcflows/workflows/fit/syn.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
"""
Estimating the susceptibility distortions without fieldmaps.
"""
import json

from nipype.pipeline import engine as pe
from nipype.interfaces import utility as niu
Expand Down Expand Up @@ -214,9 +215,14 @@
find_zooms = pe.Node(niu.Function(function=_adjust_zooms), name="find_zooms")
zooms_epi = pe.Node(RegridToZooms(), name="zooms_epi")

syn_config = data.load(f"sd_syn{'_sloppy' * sloppy}.json")

vox_params = pe.Node(niu.Function(function=_mm2vox), name="vox_params")
vox_params.inputs.registration_config = json.loads(syn_config.read_text())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we need to read the config outside the Registration node, we might want to extract the number of levels from it (e.g., len(config_dict['transforms']) instead of hard-code the magic number (of 2). It will be a bit trickier for handling masks, but a bit more robust, IMHO.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But probably for another PR btw


# SyN Registration Core
syn = pe.Node(
Registration(from_file=data.load(f"sd_syn{'_sloppy' * sloppy}.json")),
Registration(from_file=syn_config),
name="syn",
n_procs=omp_nthreads,
)
Expand Down Expand Up @@ -277,6 +283,7 @@
("sd_prior", "in2")]),
(inputnode, anat_dilmsk, [("anat_mask", "in_file")]),
(inputnode, warp_dir, [("anat_ref", "fixed_image")]),
(inputnode, vox_params, [("anat_ref", "fixed_image")]),
(inputnode, anat_merge, [("anat_ref", "in1")]),
(inputnode, lap_anat, [("anat_ref", "op1")]),
(inputnode, find_zooms, [("anat_ref", "in_anat"),
Expand All @@ -295,11 +302,15 @@
(anat_dilmsk, amask2epi, [("out_file", "input_image")]),
(amask2epi, epi_umask, [("output_image", "in2")]),
(readout_time, warp_dir, [("pe_direction", "pe_dir")]),
(readout_time, vox_params, [("pe_direction", "pe_dir")]),
(clip_epi, warp_dir, [("out_file", "moving_image")]),
(clip_epi, vox_params, [("out_file", "moving_image")]),
(warp_dir, syn, [("out", "restrict_deformation")]),
(anat_merge, syn, [("out", "fixed_image")]),
(fixed_masks, syn, [("out", "fixed_image_masks")]),
(epi_merge, syn, [("out", "moving_image")]),
(moving_masks, syn, [("out", "moving_image_masks")]),
(vox_params, syn, [("out", "transform_parameters")]),
(syn, extract_field, [(("forward_transforms", _pop), "transform")]),
(clip_epi, extract_field, [("out_file", "epi")]),
(readout_time, extract_field, [("readout_time", "ro_time"),
Expand Down Expand Up @@ -583,25 +594,52 @@
return workflow


def _warp_dir(fixed_image, pe_dir, nlevels=3):
def _warp_dir(moving_image, fixed_image, pe_dir, nlevels=3):
"""Extract the ``restrict_deformation`` argument from metadata."""
import numpy as np
import nibabel as nb

img = nb.load(fixed_image)
moving = nb.load(moving_image)
fixed = nb.load(fixed_image)

Check warning on line 603 in sdcflows/workflows/fit/syn.py

View check run for this annotation

Codecov / codecov/patch

sdcflows/workflows/fit/syn.py#L602-L603

Added lines #L602 - L603 were not covered by tests

if np.any(nb.affines.obliquity(img.affine) > 0.05):
if np.any(nb.affines.obliquity(fixed.affine) > 0.05):
from nipype import logging

logging.getLogger("nipype.interface").warn(
"Running fieldmap-less registration on an oblique dataset"
)

vs = nb.affines.voxel_sizes(img.affine)
order = np.around(np.abs(img.affine[:3, :3] / vs))
retval = order @ [1 if pe_dir[0] == ax else 0.1 for ax in "ijk"]
moving_axcodes = nb.aff2axcodes(moving.affine, ["RR", "AA", "SS"])
moving_pe_axis = moving_axcodes["ijk".index(pe_dir[0])]

Check warning on line 613 in sdcflows/workflows/fit/syn.py

View check run for this annotation

Codecov / codecov/patch

sdcflows/workflows/fit/syn.py#L612-L613

Added lines #L612 - L613 were not covered by tests

fixed_axcodes = nb.aff2axcodes(fixed.affine, ["RR", "AA", "SS"])

Check warning on line 615 in sdcflows/workflows/fit/syn.py

View check run for this annotation

Codecov / codecov/patch

sdcflows/workflows/fit/syn.py#L615

Added line #L615 was not covered by tests

deformation = [0.1, 0.1, 0.1]
deformation[fixed_axcodes.index(moving_pe_axis)] = 1.0

Check warning on line 618 in sdcflows/workflows/fit/syn.py

View check run for this annotation

Codecov / codecov/patch

sdcflows/workflows/fit/syn.py#L617-L618

Added lines #L617 - L618 were not covered by tests

return nlevels * [deformation]

Check warning on line 620 in sdcflows/workflows/fit/syn.py

View check run for this annotation

Codecov / codecov/patch

sdcflows/workflows/fit/syn.py#L620

Added line #L620 was not covered by tests


def _mm2vox(moving_image, fixed_image, pe_dir, registration_config):
import nibabel as nb

params = registration_config['transform_parameters']

moving = nb.load(moving_image)
# Use duplicate axcodes to ignore sign
moving_axcodes = nb.aff2axcodes(moving.affine, ["RR", "AA", "SS"])
moving_pe_axis = moving_axcodes["ijk".index(pe_dir[0])]

fixed = nb.load(fixed_image)
fixed_axcodes = nb.aff2axcodes(fixed.affine, ["RR", "AA", "SS"])

zooms = nb.affines.voxel_sizes(fixed.affine)
pe_res = zooms[fixed_axcodes.index(moving_pe_axis)]

return nlevels * [retval.tolist()]
return [
[*level_params[:2], level_params[2] / pe_res]
for level_params in params
]


def _merge_meta(epi_ref, meta_list):
Expand Down
65 changes: 64 additions & 1 deletion sdcflows/workflows/fit/tests/test_syn.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,20 @@
"""Test fieldmap-less SDC-SyN."""

import json

import numpy as np
import nibabel as nb
import pytest
from nipype.pipeline import engine as pe

from ..syn import init_syn_sdc_wf, init_syn_preprocessing_wf, _adjust_zooms, _set_dtype
from .... import data
from ..syn import (
init_syn_sdc_wf,
init_syn_preprocessing_wf,
_adjust_zooms,
_set_dtype,
_mm2vox,
)


@pytest.mark.veryslow
Expand Down Expand Up @@ -254,3 +264,56 @@ def test_ensure_dtype(in_dtype, out_dtype, tmpdir):
assert out_file == f"{in_dtype}.nii.gz"
else:
assert out_file == f"{in_dtype}_{out_dtype}.nii.gz"


def axcodes2aff(axcodes):
"""Return an affine matrix from axis codes."""
return nb.orientations.inv_ornt_aff(
nb.orientations.ornt_transform(
nb.orientations.axcodes2ornt("RAS"),
nb.orientations.axcodes2ornt(axcodes),
),
(10, 10, 10),
)


@pytest.mark.parametrize(
("fixed_ornt", "moving_ornt", "ijk", "index"),
[
("RAS", "RAS", "i", 0),
("RAS", "RAS", "j", 1),
("RAS", "RAS", "k", 2),
("RAS", "PSL", "i", 1),
("RAS", "PSL", "j", 2),
("RAS", "PSL", "k", 0),
("PSL", "RAS", "i", 2),
("PSL", "RAS", "j", 0),
("PSL", "RAS", "k", 1),
],
)
def test_mm2vox(tmp_path, fixed_ornt, moving_ornt, ijk, index):
fixed_path = tmp_path / "fixed.nii.gz"
moving_path = tmp_path / "moving.nii.gz"

# Use separate zooms to make identifying the conversion easier
fixed_aff = np.diag((2, 3, 4, 1))
nb.save(
nb.Nifti1Image(np.zeros((10, 10, 10)), axcodes2aff(fixed_ornt) @ fixed_aff),
fixed_path,
)
nb.save(
nb.Nifti1Image(np.zeros((10, 10, 10)), axcodes2aff(moving_ornt)),
moving_path,
)

config = json.loads(data.load.readable("sd_syn.json").read_text())

params = config["transform_parameters"]
mm_values = np.array([level[2] for level in params])

vox_params = _mm2vox(str(moving_path), str(fixed_path), ijk, config)
vox_values = [level[2] for level in vox_params]
assert [
mm_level[:2] == vox_level[:2] for mm_level, vox_level in zip(params, vox_params)
]
assert np.array_equal(vox_values, mm_values / [2, 3, 4][index])
Loading