Skip to content

Commit bd03cce

Browse files
authored
Merge pull request #50 from smoia/master
ENH: First implementation of AFNI displacement fields
2 parents ed7e8ad + 845f5b0 commit bd03cce

File tree

7 files changed

+106
-28
lines changed

7 files changed

+106
-28
lines changed

nitransforms/io/afni.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,12 @@
44
from nibabel.affines import obliquity, voxel_sizes
55

66
from ..patched import shape_zoom_affine
7-
from .base import BaseLinearTransformList, LinearParameters, TransformFileError
7+
from .base import (
8+
BaseLinearTransformList,
9+
DisplacementsField,
10+
LinearParameters,
11+
TransformFileError,
12+
)
813

914
LPS = np.diag([-1, -1, 1, 1])
1015
OBLIQUITY_THRESHOLD_DEG = 0.01
@@ -119,5 +124,29 @@ def from_string(cls, string):
119124
return _self
120125

121126

127+
class AFNIDisplacementsField(DisplacementsField):
128+
"""A data structure representing displacements fields."""
129+
130+
@classmethod
131+
def from_image(cls, imgobj):
132+
"""Import a displacements field from a NIfTI file."""
133+
hdr = imgobj.header.copy()
134+
shape = hdr.get_data_shape()
135+
136+
if (
137+
len(shape) != 5 or
138+
shape[-2] != 1 or
139+
not shape[-1] in (2, 3)
140+
):
141+
raise TransformFileError(
142+
'Displacements field "%s" does not come from AFNI.' %
143+
imgobj.file_map['image'].filename)
144+
145+
field = np.squeeze(np.asanyarray(imgobj.dataobj))
146+
field[..., (0, 1)] *= -1.0
147+
148+
return imgobj.__class__(field, imgobj.affine, hdr)
149+
150+
122151
def _is_oblique(affine, thres=OBLIQUITY_THRESHOLD_DEG):
123152
return (obliquity(affine).min() * 180 / pi) > thres

nitransforms/io/base.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Read/write linear transforms."""
22
import numpy as np
3+
from nibabel import load as loadimg
34
from scipy.io.matlab.miobase import get_matfile_version
45
from scipy.io.matlab.mio4 import MatFile4Reader
56
from scipy.io.matlab.mio5 import MatFile5Reader
@@ -157,6 +158,21 @@ def from_string(cls, string):
157158
raise NotImplementedError
158159

159160

161+
class DisplacementsField:
162+
"""A data structure representing displacements fields."""
163+
164+
@classmethod
165+
def from_filename(cls, filename):
166+
"""Import a displacements field from a NIfTI file."""
167+
imgobj = loadimg(str(filename))
168+
return cls.from_image(imgobj)
169+
170+
@classmethod
171+
def from_image(cls, imgobj):
172+
"""Import a displacements field from a nibabel image object."""
173+
raise NotImplementedError
174+
175+
160176
def _read_mat(byte_stream):
161177
mjv, _ = get_matfile_version(byte_stream)
162178
if mjv == 0:

nitransforms/io/itk.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,13 @@
44
from scipy.io import savemat as _save_mat
55
from nibabel.loadsave import load as loadimg
66
from nibabel.affines import from_matvec
7-
from .base import BaseLinearTransformList, LinearParameters, _read_mat, TransformFileError
7+
from .base import (
8+
BaseLinearTransformList,
9+
DisplacementsField,
10+
LinearParameters,
11+
TransformFileError,
12+
_read_mat,
13+
)
814

915
LPS = np.diag([-1, -1, 1, 1])
1016

@@ -249,35 +255,29 @@ def from_string(cls, string):
249255
return _self
250256

251257

252-
class ITKDisplacementsField:
258+
class ITKDisplacementsField(DisplacementsField):
253259
"""A data structure representing displacements fields."""
254260

255-
@classmethod
256-
def from_filename(cls, filename):
257-
"""Import a displacements field from a NIfTI file."""
258-
imgobj = loadimg(str(filename))
259-
return cls.from_image(imgobj)
260-
261261
@classmethod
262262
def from_image(cls, imgobj):
263263
"""Import a displacements field from a NIfTI file."""
264-
_hdr = imgobj.header.copy()
265-
_shape = _hdr.get_data_shape()
264+
hdr = imgobj.header.copy()
265+
shape = hdr.get_data_shape()
266266

267267
if (
268-
len(_shape) != 5 or
269-
_shape[-2] != 1 or
270-
not _shape[-1] in (2, 3)
268+
len(shape) != 5 or
269+
shape[-2] != 1 or
270+
not shape[-1] in (2, 3)
271271
):
272272
raise TransformFileError(
273273
'Displacements field "%s" does not come from ITK.' %
274274
imgobj.file_map['image'].filename)
275275

276-
if _hdr.get_intent()[0] != 'vector':
276+
if hdr.get_intent()[0] != 'vector':
277277
warnings.warn('Incorrect intent identified.')
278-
_hdr.set_intent('vector')
278+
hdr.set_intent('vector')
279279

280-
_field = np.squeeze(np.asanyarray(imgobj.dataobj))
281-
_field[..., (0, 1)] *= -1.0
280+
field = np.squeeze(np.asanyarray(imgobj.dataobj))
281+
field[..., (0, 1)] *= -1.0
282282

283-
return imgobj.__class__(_field, imgobj.affine, _hdr)
283+
return imgobj.__class__(field, imgobj.affine, hdr)
Binary file not shown.

nitransforms/tests/test_io.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import pytest
66

77
import filecmp
8+
import nibabel as nb
89
from nibabel.eulerangles import euler2mat
910
from nibabel.affines import from_matvec
1011
from scipy.io import loadmat, savemat
@@ -321,3 +322,20 @@ def _mockreturn(arg):
321322
with pytest.raises(TransformFileError):
322323
with open('val.mat', 'rb') as f:
323324
_read_mat(f)
325+
326+
@pytest.mark.parametrize('sw_tool', ['afni'])
327+
def test_Displacements(sw_tool):
328+
"""Test displacements fields."""
329+
330+
if sw_tool == 'afni':
331+
field = nb.Nifti1Image(np.zeros((10, 10, 10)), None, None)
332+
with pytest.raises(TransformFileError):
333+
afni.AFNIDisplacementsField.from_image(field)
334+
335+
field = nb.Nifti1Image(np.zeros((10, 10, 10, 2, 3)), None, None)
336+
with pytest.raises(TransformFileError):
337+
afni.AFNIDisplacementsField.from_image(field)
338+
339+
field = nb.Nifti1Image(np.zeros((10, 10, 10, 1, 4)), None, None)
340+
with pytest.raises(TransformFileError):
341+
afni.AFNIDisplacementsField.from_image(field)

nitransforms/tests/test_nonlinear.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,17 @@
1111
from ..io.base import TransformFileError
1212
from ..nonlinear import DisplacementsFieldTransform
1313
from ..io.itk import ITKDisplacementsField
14+
from ..io.afni import AFNIDisplacementsField
1415

1516
TESTS_BORDER_TOLERANCE = 0.05
1617
APPLY_NONLINEAR_CMD = {
1718
'itk': """\
1819
antsApplyTransforms -d 3 -r {reference} -i {moving} \
1920
-o resampled.nii.gz -n NearestNeighbor -t {transform} --float\
21+
""".format,
22+
'afni': """\
23+
3dNwarpApply -nwarp {transform} -source {moving} \
24+
-master {reference} -interp NN -prefix resampled.nii.gz
2025
""".format,
2126
}
2227

@@ -46,8 +51,9 @@ def test_itk_disp_load_intent():
4651
assert field.header.get_intent()[0] == 'vector'
4752

4853

54+
@pytest.mark.xfail(reason="Oblique datasets not fully implemented")
4955
@pytest.mark.parametrize('image_orientation', ['RAS', 'LAS', 'LPS', 'oblique'])
50-
@pytest.mark.parametrize('sw_tool', ['itk'])
56+
@pytest.mark.parametrize('sw_tool', ['itk', 'afni'])
5157
@pytest.mark.parametrize('axis', [0, 1, 2, (0, 1), (1, 2), (0, 1, 2)])
5258
def test_displacements_field1(tmp_path, get_testdata, image_orientation, sw_tool, axis):
5359
"""Check a translation-only field on one or more axes, different image orientations."""
@@ -58,15 +64,20 @@ def test_displacements_field1(tmp_path, get_testdata, image_orientation, sw_tool
5864
fieldmap[..., axis] = -10.0
5965

6066
_hdr = nii.header.copy()
61-
_hdr.set_intent('vector')
67+
if sw_tool in ('itk', ):
68+
_hdr.set_intent('vector')
6269
_hdr.set_data_dtype('float32')
6370

6471
xfm_fname = 'warp.nii.gz'
6572
field = nb.Nifti1Image(fieldmap, nii.affine, _hdr)
6673
field.to_filename(xfm_fname)
6774

68-
xfm = DisplacementsFieldTransform(
69-
ITKDisplacementsField.from_image(field))
75+
if sw_tool == 'itk':
76+
xfm = DisplacementsFieldTransform(
77+
ITKDisplacementsField.from_image(field))
78+
elif sw_tool == 'afni':
79+
xfm = DisplacementsFieldTransform(
80+
AFNIDisplacementsField.from_image(field))
7081

7182
# Then apply the transform and cross-check with software
7283
cmd = APPLY_NONLINEAR_CMD[sw_tool](
@@ -90,15 +101,19 @@ def test_displacements_field1(tmp_path, get_testdata, image_orientation, sw_tool
90101
assert (np.abs(diff) > 1e-3).sum() / diff.size < TESTS_BORDER_TOLERANCE
91102

92103

93-
@pytest.mark.parametrize('sw_tool', ['itk'])
104+
@pytest.mark.parametrize('sw_tool', ['itk', 'afni'])
94105
def test_displacements_field2(tmp_path, data_path, sw_tool):
95106
"""Check a translation-only field on one or more axes, different image orientations."""
96107
os.chdir(str(tmp_path))
97108
img_fname = data_path / 'tpl-OASIS30ANTs_T1w.nii.gz'
98-
xfm_fname = data_path / 'ds-005_sub-01_from-OASIS_to-T1_warp.nii.gz'
99-
100-
xfm = DisplacementsFieldTransform(
101-
ITKDisplacementsField.from_filename(xfm_fname))
109+
xfm_fname = data_path / 'ds-005_sub-01_from-OASIS_to-T1_warp_{}.nii.gz'.format(sw_tool)
110+
111+
if sw_tool == 'itk':
112+
xfm = DisplacementsFieldTransform(
113+
ITKDisplacementsField.from_filename(xfm_fname))
114+
elif sw_tool == 'afni':
115+
xfm = DisplacementsFieldTransform(
116+
AFNIDisplacementsField.from_filename(xfm_fname))
102117

103118
# Then apply the transform and cross-check with software
104119
cmd = APPLY_NONLINEAR_CMD[sw_tool](

0 commit comments

Comments
 (0)