Skip to content

Commit 9e86d93

Browse files
dPysoesteban
authored andcommitted
Add test for affine_registration and add interfaces to use for emc
1 parent 3de1201 commit 9e86d93

File tree

3 files changed

+203
-0
lines changed

3 files changed

+203
-0
lines changed

dmriprep/interfaces/registration.py

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
"""Register tools interfaces."""
2+
import numpy as np
3+
import nibabel as nb
4+
import dmriprep
5+
from nipype import logging
6+
from pathlib import Path
7+
from nipype.utils.filemanip import fname_presuffix
8+
from nipype.interfaces.base import (
9+
traits,
10+
TraitedSpec,
11+
BaseInterfaceInputSpec,
12+
InputMultiObject,
13+
SimpleInterface,
14+
File,
15+
)
16+
17+
18+
LOGGER = logging.getLogger("nipype.interface")
19+
20+
21+
class _ApplyAffineInputSpec(BaseInterfaceInputSpec):
22+
moving_image = File(
23+
exists=True, mandatory=True, desc="image to apply transformation from"
24+
)
25+
fixed_image = File(
26+
exists=True, mandatory=True, desc="image to apply transformation to"
27+
)
28+
transform_affine = InputMultiObject(
29+
File(exists=True), mandatory=True, desc="transformation affine"
30+
)
31+
invert_transform = traits.Bool(False, usedefault=True)
32+
33+
34+
class _ApplyAffineOutputSpec(TraitedSpec):
35+
warped_image = File(exists=True, desc="Outputs warped image")
36+
37+
38+
class ApplyAffine(SimpleInterface):
39+
"""
40+
Interface to apply an affine transformation to an image.
41+
"""
42+
43+
input_spec = _ApplyAffineInputSpec
44+
output_spec = _ApplyAffineOutputSpec
45+
46+
def _run_interface(self, runtime):
47+
from dmriprep.utils.registration import apply_affine
48+
49+
warped_image_nifti = apply_affine(
50+
nb.load(self.inputs.moving_image),
51+
nb.load(self.inputs.fixed_image),
52+
np.load(self.inputs.transform_affine[0]),
53+
self.inputs.invert_transform,
54+
)
55+
cwd = Path(runtime.cwd).absolute()
56+
warped_file = fname_presuffix(
57+
self.inputs.moving_image,
58+
use_ext=False,
59+
suffix="_warped.nii.gz",
60+
newpath=str(cwd),
61+
)
62+
63+
warped_image_nifti.to_filename(warped_file)
64+
65+
self._results["warped_image"] = warped_file
66+
return runtime
67+
68+
69+
class _RegisterInputSpec(BaseInterfaceInputSpec):
70+
moving_image = File(
71+
exists=True, mandatory=True, desc="image to apply transformation from"
72+
)
73+
fixed_image = File(
74+
exists=True, mandatory=True, desc="image to apply transformation to"
75+
)
76+
nbins = traits.Int(default_value=32, usedefault=True)
77+
sampling_prop = traits.Float(defualt_value=1, usedefault=True)
78+
metric = traits.Str(default_value="MI", usedefault=True)
79+
level_iters = traits.List(
80+
trait=traits.Any(), value=[10000, 1000, 100], usedefault=True
81+
)
82+
sigmas = traits.List(trait=traits.Any(), value=[5.0, 2.5, 0.0], usedefault=True)
83+
factors = traits.List(trait=traits.Any(), value=[4, 2, 1], usedefault=True)
84+
params0 = traits.ArrayOrNone(value=None, usedefault=True)
85+
pipeline = traits.List(
86+
trait=traits.Any(),
87+
value=["c_of_mass", "translation", "rigid", "affine"],
88+
usedefault=True,
89+
)
90+
91+
92+
class _RegisterOutputSpec(TraitedSpec):
93+
forward_transforms = traits.List(
94+
File(exists=True), desc="List of output transforms for forward registration"
95+
)
96+
warped_image = File(exists=True, desc="Outputs warped image")
97+
98+
99+
class Register(SimpleInterface):
100+
"""
101+
Interface to perform affine registration.
102+
"""
103+
104+
input_spec = _RegisterInputSpec
105+
output_spec = _RegisterOutputSpec
106+
107+
def _run_interface(self, runtime):
108+
from dmriprep.utils.registration import affine_registration
109+
110+
reg_types = ["c_of_mass", "translation", "rigid", "affine"]
111+
pipeline = [
112+
getattr(dmriprep.utils.register, i)
113+
for i in self.inputs.pipeline
114+
if i in reg_types
115+
]
116+
117+
warped_image_nifti, forward_transform_mat = affine_registration(
118+
nb.load(self.inputs.moving_image),
119+
nb.load(self.inputs.fixed_image),
120+
self.inputs.nbins,
121+
self.inputs.sampling_prop,
122+
self.inputs.metric,
123+
pipeline,
124+
self.inputs.level_iters,
125+
self.inputs.sigmas,
126+
self.inputs.factors,
127+
self.inputs.params0,
128+
)
129+
cwd = Path(runtime.cwd).absolute()
130+
warped_file = fname_presuffix(
131+
self.inputs.moving_image,
132+
use_ext=False,
133+
suffix="_warped.nii.gz",
134+
newpath=str(cwd),
135+
)
136+
forward_transform_file = fname_presuffix(
137+
self.inputs.moving_image,
138+
use_ext=False,
139+
suffix="_forward_transform.npy",
140+
newpath=str(cwd),
141+
)
142+
warped_image_nifti.to_filename(warped_file)
143+
144+
np.save(forward_transform_file, forward_transform_mat)
145+
self._results["warped_image"] = warped_file
146+
self._results["forward_transforms"] = [forward_transform_file]
147+
return runtime
File renamed without changes.
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import pytest
2+
import numpy as np
3+
import numpy.testing as npt
4+
import nibabel as nb
5+
import dipy.data as dpd
6+
from dmriprep.utils.registration import affine_registration
7+
8+
9+
def setup_module():
10+
global subset_b0, subset_dwi_data, subset_t2, subset_b0_img, \
11+
subset_t2_img, gtab, hardi_affine, MNI_T2_affine
12+
MNI_T2 = dpd.read_mni_template()
13+
hardi_img, gtab = dpd.read_stanford_hardi()
14+
MNI_T2_data = MNI_T2.get_fdata()
15+
MNI_T2_affine = MNI_T2.affine
16+
hardi_data = hardi_img.get_fdata()
17+
hardi_affine = hardi_img.affine
18+
b0 = hardi_data[..., gtab.b0s_mask]
19+
mean_b0 = np.mean(b0, -1)
20+
21+
# Select some arbitrary chunks of data so this goes quicker
22+
subset_b0 = mean_b0[40:50, 40:50, 40:50]
23+
subset_dwi_data = nb.Nifti1Image(hardi_data[40:50, 40:50, 40:50],
24+
hardi_affine)
25+
subset_t2 = MNI_T2_data[40:60, 40:60, 40:60]
26+
subset_b0_img = nb.Nifti1Image(subset_b0, hardi_affine)
27+
subset_t2_img = nb.Nifti1Image(subset_t2, MNI_T2_affine)
28+
29+
30+
def test_affine_registration():
31+
moving = subset_b0
32+
static = subset_b0
33+
moving_affine = static_affine = np.eye(4)
34+
xformed, affine = affine_registration(moving, static)
35+
# We don't ask for much:
36+
npt.assert_almost_equal(affine[:3, :3], np.eye(3), decimal=1)
37+
38+
with pytest.raises(ValueError):
39+
# For array input, must provide affines:
40+
xformed, affine = affine_registration(moving, static)
41+
42+
# If providing nifti image objects, don't need to provide affines:
43+
moving_img = nb.Nifti1Image(moving, moving_affine)
44+
static_img = nb.Nifti1Image(static, static_affine)
45+
xformed, affine = affine_registration(moving_img, static_img)
46+
npt.assert_almost_equal(affine[:3, :3], np.eye(3), decimal=1)
47+
48+
# Using strings with full paths as inputs also works:
49+
t1_name, b0_name = dpd.get_fnames('syn_data')
50+
moving = b0_name
51+
static = t1_name
52+
xformed, affine = affine_registration(moving, static,
53+
level_iters=[5, 5],
54+
sigmas=[3, 1],
55+
factors=[4, 2])
56+
npt.assert_almost_equal(affine[:3, :3], np.eye(3), decimal=1)

0 commit comments

Comments
 (0)