Skip to content

Commit 6dff850

Browse files
committed
enh: initialize AFNI's displacements fields loader
1 parent ed7e8ad commit 6dff850

File tree

4 files changed

+55
-13
lines changed

4 files changed

+55
-13
lines changed

nitransforms/io/afni.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,12 @@
44
from nibabel.affines import obliquity, voxel_sizes
55

66
from ..patched import shape_zoom_affine
7-
from .base import BaseLinearTransformList, LinearParameters, TransformFileError
7+
from .base import (
8+
BaseLinearTransformList,
9+
DisplacementsField,
10+
LinearParameters,
11+
TransformFileError,
12+
)
813

914
LPS = np.diag([-1, -1, 1, 1])
1015
OBLIQUITY_THRESHOLD_DEG = 0.01
@@ -119,5 +124,29 @@ def from_string(cls, string):
119124
return _self
120125

121126

127+
class AFNIDisplacementsField(DisplacementsField):
128+
"""A data structure representing displacements fields."""
129+
130+
@classmethod
131+
def from_image(cls, imgobj):
132+
"""Import a displacements field from a NIfTI file."""
133+
_hdr = imgobj.header.copy()
134+
_shape = _hdr.get_data_shape()
135+
136+
if (
137+
len(_shape) != 5 or
138+
_shape[-2] != 1 or
139+
not _shape[-1] in (2, 3)
140+
):
141+
raise TransformFileError(
142+
'Displacements field "%s" does not come from AFNI.' %
143+
imgobj.file_map['image'].filename)
144+
145+
_field = np.squeeze(np.asanyarray(imgobj.dataobj))
146+
_field[..., (0, 1)] *= -1.0
147+
148+
return imgobj.__class__(_field, imgobj.affine, _hdr)
149+
150+
122151
def _is_oblique(affine, thres=OBLIQUITY_THRESHOLD_DEG):
123152
return (obliquity(affine).min() * 180 / pi) > thres

nitransforms/io/base.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,21 @@ def from_string(cls, string):
157157
raise NotImplementedError
158158

159159

160+
class DisplacementsField:
161+
"""A data structure representing displacements fields."""
162+
163+
@classmethod
164+
def from_filename(cls, filename):
165+
"""Import a displacements field from a NIfTI file."""
166+
imgobj = loadimg(str(filename))
167+
return cls.from_image(imgobj)
168+
169+
@classmethod
170+
def from_image(cls, imgobj):
171+
"""Import a displacements field from a nibabel image object."""
172+
raise NotImplementedError
173+
174+
160175
def _read_mat(byte_stream):
161176
mjv, _ = get_matfile_version(byte_stream)
162177
if mjv == 0:

nitransforms/io/itk.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,13 @@
44
from scipy.io import savemat as _save_mat
55
from nibabel.loadsave import load as loadimg
66
from nibabel.affines import from_matvec
7-
from .base import BaseLinearTransformList, LinearParameters, _read_mat, TransformFileError
7+
from .base import (
8+
BaseLinearTransformList,
9+
DisplacementsField,
10+
LinearParameters,
11+
TransformFileError,
12+
_read_mat,
13+
)
814

915
LPS = np.diag([-1, -1, 1, 1])
1016

@@ -249,15 +255,9 @@ def from_string(cls, string):
249255
return _self
250256

251257

252-
class ITKDisplacementsField:
258+
class ITKDisplacementsField(DisplacementsField):
253259
"""A data structure representing displacements fields."""
254260

255-
@classmethod
256-
def from_filename(cls, filename):
257-
"""Import a displacements field from a NIfTI file."""
258-
imgobj = loadimg(str(filename))
259-
return cls.from_image(imgobj)
260-
261261
@classmethod
262262
def from_image(cls, imgobj):
263263
"""Import a displacements field from a NIfTI file."""

nitransforms/tests/test_nonlinear.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,16 +58,14 @@ def test_displacements_field1(tmp_path, get_testdata, image_orientation, sw_tool
5858
fieldmap[..., axis] = -10.0
5959

6060
_hdr = nii.header.copy()
61-
_hdr.set_intent('vector')
61+
if sw_tool in ('itk', ):
62+
_hdr.set_intent('vector')
6263
_hdr.set_data_dtype('float32')
6364

6465
xfm_fname = 'warp.nii.gz'
6566
field = nb.Nifti1Image(fieldmap, nii.affine, _hdr)
6667
field.to_filename(xfm_fname)
6768

68-
xfm = DisplacementsFieldTransform(
69-
ITKDisplacementsField.from_image(field))
70-
7169
# Then apply the transform and cross-check with software
7270
cmd = APPLY_NONLINEAR_CMD[sw_tool](
7371
transform=os.path.abspath(xfm_fname),

0 commit comments

Comments
 (0)