Skip to content

Commit 44c4966

Browse files
committed
feat: add test reading a real h5 generated by ANTs
1 parent a3372bc commit 44c4966

File tree

3 files changed

+29
-22
lines changed

3 files changed

+29
-22
lines changed

nitransforms/io/itk.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@ def __init__(self, parameters=None, offset=None):
3030
"""Initialize with default offset and index."""
3131
super().__init__()
3232
self.structarr['index'] = 0
33-
self.structarr['offset'] = offset or [0, 0, 0]
33+
if offset is None:
34+
offset = np.zeros((3,), dtype='float')
35+
self.structarr['offset'] = offset
3436
self.structarr['parameters'] = np.eye(4)
3537
if parameters is not None:
3638
self.structarr['parameters'] = parameters
@@ -307,24 +309,25 @@ def from_h5obj(cls, fileobj, check=True):
307309
except KeyError:
308310
typo_fallback = "Tranform"
309311

310-
for xfm in reversed(h5group.keys())[:-1]:
311-
if h5group[xfm]["TransformType"][0].startswith(b"AffineTransform"):
312+
for xfm in reversed(list(h5group.values())[1:]):
313+
if xfm["TransformType"][0].startswith(b"AffineTransform"):
314+
_params = np.asanyarray(xfm[f"{typo_fallback}Parameters"])
312315
xfm_list.append(
313316
ITKLinearTransform(
314-
parameters=np.asanyarray(h5group[xfm][f"{typo_fallback}Parameters"]),
315-
offset=np.asanyarray(h5group[xfm][f"{typo_fallback}FixedParameters"])
317+
parameters=from_matvec(_params[:-3].reshape(3, 3), _params[-3:]),
318+
offset=np.asanyarray(xfm[f"{typo_fallback}FixedParameters"])
316319
)
317320
)
318321
continue
319-
if h5group[xfm]["TransformType"][0].startswith(b"DisplacementFieldTransform"):
320-
_fixed = np.asanyarray(h5group[xfm][f"{typo_fallback}FixedParameters"])
322+
if xfm["TransformType"][0].startswith(b"DisplacementFieldTransform"):
323+
_fixed = np.asanyarray(xfm[f"{typo_fallback}FixedParameters"])
321324
shape = _fixed[:3].astype('uint16').tolist()
322325
offset = _fixed[3:6].astype('uint16')
323326
zooms = _fixed[6:9].astype('float')
324327
directions = _fixed[9:].astype('float').reshape((3, 3))
325328
affine = from_matvec(directions * zooms, offset)
326-
field = np.asanyarray(h5group[xfm][f"{typo_fallback}Parameters"]).reshape(
327-
tuple(shape + [-1])
329+
field = np.asanyarray(xfm[f"{typo_fallback}Parameters"]).reshape(
330+
tuple(shape + [1, -1])
328331
)
329332
hdr = Nifti1Header()
330333
hdr.set_intent("vector")
@@ -338,7 +341,7 @@ def from_h5obj(cls, fileobj, check=True):
338341
continue
339342

340343
raise NotImplementedError(
341-
f"Unsupported transform type {h5group[xfm]['TransformType'][0]}"
344+
f"Unsupported transform type {xfm['TransformType'][0]}"
342345
)
343346

344347
return xfm_list

nitransforms/tests/test_io.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -324,22 +324,26 @@ def _mockreturn(arg):
324324
_read_mat(f)
325325

326326

327-
@pytest.mark.parametrize('sw_tool', ['afni'])
328-
def test_Displacements(sw_tool):
327+
def test_afni_Displacements():
329328
"""Test displacements fields."""
329+
field = nb.Nifti1Image(np.zeros((10, 10, 10)), None, None)
330+
with pytest.raises(TransformFileError):
331+
afni.AFNIDisplacementsField.from_image(field)
330332

331-
if sw_tool == 'afni':
332-
field = nb.Nifti1Image(np.zeros((10, 10, 10)), None, None)
333-
with pytest.raises(TransformFileError):
334-
afni.AFNIDisplacementsField.from_image(field)
333+
field = nb.Nifti1Image(np.zeros((10, 10, 10, 2, 3)), None, None)
334+
with pytest.raises(TransformFileError):
335+
afni.AFNIDisplacementsField.from_image(field)
335336

336-
field = nb.Nifti1Image(np.zeros((10, 10, 10, 2, 3)), None, None)
337-
with pytest.raises(TransformFileError):
338-
afni.AFNIDisplacementsField.from_image(field)
337+
field = nb.Nifti1Image(np.zeros((10, 10, 10, 1, 4)), None, None)
338+
with pytest.raises(TransformFileError):
339+
afni.AFNIDisplacementsField.from_image(field)
339340

340-
field = nb.Nifti1Image(np.zeros((10, 10, 10, 1, 4)), None, None)
341-
with pytest.raises(TransformFileError):
342-
afni.AFNIDisplacementsField.from_image(field)
341+
342+
def test_itk_h5(data_path):
343+
"""Test displacements fields."""
344+
itk.ITKCompositeH5.from_filename(
345+
data_path / 'ds-005_sub-01_from-T1w_to-MNI152NLin2009cAsym_mode-image_xfm.h5'
346+
)
343347

344348

345349
@pytest.mark.parametrize('file_type, test_file', [

0 commit comments

Comments
 (0)