5
5
import os
6
6
from subprocess import check_call
7
7
import shutil
8
+
9
+ import SimpleITK as sitk
8
10
import pytest
9
11
10
12
import numpy as np
@@ -289,9 +291,15 @@ def test_constant_field_vs_ants(tmp_path):
289
291
field [..., 2 ] = 5
290
292
291
293
field_img = nb .Nifti1Image (field , ref_affine )
292
- itk_img = ITKDisplacementsField . to_image ( field_img )
294
+
293
295
warpfile = tmp_path / "const_disp.nii.gz"
294
- itk_img .to_filename (warpfile )
296
+ itk_img = sitk .GetImageFromArray (field , isVector = True )
297
+ itk_img .SetOrigin (tuple (ref_affine [:3 , 3 ]))
298
+ zooms = np .sqrt ((ref_affine [:3 , :3 ] ** 2 ).sum (0 ))
299
+ itk_img .SetSpacing (tuple (zooms ))
300
+ direction = (ref_affine [:3 , :3 ] / zooms ).ravel ()
301
+ itk_img .SetDirection (tuple (direction ))
302
+ sitk .WriteImage (itk_img , str (warpfile ))
295
303
296
304
points = np .array ([[0.0 , 0.0 , 0.0 ], [1.0 , 2.0 , 3.0 ]])
297
305
csvin = tmp_path / "points.csv"
@@ -307,7 +315,7 @@ def test_constant_field_vs_ants(tmp_path):
307
315
ants_res = np .genfromtxt (csvout , delimiter = "," , names = True )
308
316
ants_pts = np .vstack ([ants_res [n ] for n in ("x" , "y" , "z" )]).T
309
317
310
- xfm = DenseFieldTransform (warpfile )
318
+ xfm = DenseFieldTransform (field_img )
311
319
mapped = xfm .map (points )
312
320
313
321
assert not np .allclose (mapped , ants_pts , atol = 1e-6 )
0 commit comments