3
3
"""Tests of nonlinear transforms."""
4
4
5
5
import os
6
+ from subprocess import check_call
7
+ import shutil
8
+
9
+ import SimpleITK as sitk
6
10
import pytest
7
11
8
12
import numpy as np
9
13
import nibabel as nb
14
+ from nibabel .affines import from_matvec
10
15
from nitransforms .resampling import apply
11
16
from nitransforms .base import TransformError
12
17
from nitransforms .io .base import TransformFileError
15
20
DenseFieldTransform ,
16
21
)
17
22
from nitransforms import io
18
- from . .io .itk import ITKDisplacementsField
23
+ from nitransforms .io .itk import ITKDisplacementsField
19
24
20
25
21
26
@pytest .mark .parametrize ("size" , [(20 , 20 , 20 ), (20 , 20 , 20 , 3 )])
@@ -34,16 +39,6 @@ def test_displacements_bad_sizes(size):
34
39
DenseFieldTransform (nb .Nifti1Image (np .zeros (size ), np .eye (4 ), None ))
35
40
36
41
37
- def test_itk_disp_load_intent ():
38
- """Checks whether the NIfTI intent is fixed."""
39
- with pytest .warns (UserWarning ):
40
- field = ITKDisplacementsField .from_image (
41
- nb .Nifti1Image (np .zeros ((20 , 20 , 20 , 1 , 3 )), np .eye (4 ), None )
42
- )
43
-
44
- assert field .header .get_intent ()[0 ] == "vector"
45
-
46
-
47
42
def test_displacements_init ():
48
43
identity1 = DenseFieldTransform (
49
44
np .zeros ((10 , 10 , 10 , 3 )),
@@ -67,6 +62,30 @@ def test_displacements_init():
67
62
)
68
63
69
64
65
+ @pytest .mark .parametrize ("is_deltas" , [True , False ])
66
+ def test_densefield_oob_resampling (is_deltas ):
67
+ """Ensure mapping outside the field returns input coordinates."""
68
+ ref = nb .Nifti1Image (np .zeros ((2 , 2 , 2 ), dtype = "uint8" ), np .eye (4 ))
69
+
70
+ if is_deltas :
71
+ field = nb .Nifti1Image (np .ones ((2 , 2 , 2 , 3 ), dtype = "float32" ), np .eye (4 ))
72
+ else :
73
+ grid = np .stack (
74
+ np .meshgrid (* [np .arange (2 ) for _ in range (3 )], indexing = "ij" ),
75
+ axis = - 1 ,
76
+ ).astype ("float32" )
77
+ field = nb .Nifti1Image (grid + 1.0 , np .eye (4 ))
78
+
79
+ xfm = DenseFieldTransform (field , is_deltas = is_deltas , reference = ref )
80
+
81
+ points = np .array ([[- 1.0 , - 1.0 , - 1.0 ], [0.5 , 0.5 , 0.5 ], [3.0 , 3.0 , 3.0 ]])
82
+ mapped = xfm .map (points )
83
+
84
+ assert np .allclose (mapped [0 ], points [0 ])
85
+ assert np .allclose (mapped [2 ], points [2 ])
86
+ assert np .allclose (mapped [1 ], points [1 ] + 1 )
87
+
88
+
70
89
def test_bsplines_init ():
71
90
with pytest .raises (TransformError ):
72
91
BSplineFieldTransform (
@@ -122,76 +141,6 @@ def test_bspline(tmp_path, testdata_path):
122
141
)
123
142
124
143
125
- @pytest .mark .parametrize ("is_deltas" , [True , False ])
126
- def test_densefield_x5_roundtrip (tmp_path , is_deltas ):
127
- """Ensure dense field transforms roundtrip via X5."""
128
- ref = nb .Nifti1Image (np .zeros ((2 , 2 , 2 ), dtype = "uint8" ), np .eye (4 ))
129
- disp = nb .Nifti1Image (np .random .rand (2 , 2 , 2 , 3 ).astype ("float32" ), np .eye (4 ))
130
-
131
- xfm = DenseFieldTransform (disp , is_deltas = is_deltas , reference = ref )
132
-
133
- node = xfm .to_x5 (metadata = {"GeneratedBy" : "pytest" })
134
- assert node .type == "nonlinear"
135
- assert node .subtype == "densefield"
136
- assert node .representation == "displacements" if is_deltas else "deformations"
137
- assert node .domain .size == ref .shape
138
- assert node .metadata ["GeneratedBy" ] == "pytest"
139
-
140
- fname = tmp_path / "test.x5"
141
- io .x5 .to_filename (fname , [node ])
142
-
143
- xfm2 = DenseFieldTransform .from_filename (fname , fmt = "X5" )
144
-
145
- assert xfm2 .reference .shape == ref .shape
146
- assert np .allclose (xfm2 .reference .affine , ref .affine )
147
- assert xfm == xfm2
148
-
149
-
150
- def test_bspline_to_x5 (tmp_path ):
151
- """Check BSpline transforms export to X5."""
152
- coeff = nb .Nifti1Image (np .zeros ((2 , 2 , 2 , 3 ), dtype = "float32" ), np .eye (4 ))
153
- ref = nb .Nifti1Image (np .zeros ((2 , 2 , 2 ), dtype = "uint8" ), np .eye (4 ))
154
-
155
- xfm = BSplineFieldTransform (coeff , reference = ref )
156
- node = xfm .to_x5 (metadata = {"tool" : "pytest" })
157
- assert node .type == "nonlinear"
158
- assert node .subtype == "bspline"
159
- assert node .representation == "coefficients"
160
- assert node .metadata ["tool" ] == "pytest"
161
-
162
- fname = tmp_path / "bspline.x5"
163
- io .x5 .to_filename (fname , [node ])
164
-
165
- xfm2 = BSplineFieldTransform .from_filename (fname , fmt = "X5" )
166
- assert np .allclose (xfm ._coeffs , xfm2 ._coeffs )
167
- assert xfm2 .reference .shape == ref .shape
168
- assert np .allclose (xfm2 .reference .affine , ref .affine )
169
-
170
-
171
- @pytest .mark .parametrize ("is_deltas" , [True , False ])
172
- def test_densefield_oob_resampling (is_deltas ):
173
- """Ensure mapping outside the field returns input coordinates."""
174
- ref = nb .Nifti1Image (np .zeros ((2 , 2 , 2 ), dtype = "uint8" ), np .eye (4 ))
175
-
176
- if is_deltas :
177
- field = nb .Nifti1Image (np .ones ((2 , 2 , 2 , 3 ), dtype = "float32" ), np .eye (4 ))
178
- else :
179
- grid = np .stack (
180
- np .meshgrid (* [np .arange (2 ) for _ in range (3 )], indexing = "ij" ),
181
- axis = - 1 ,
182
- ).astype ("float32" )
183
- field = nb .Nifti1Image (grid + 1.0 , np .eye (4 ))
184
-
185
- xfm = DenseFieldTransform (field , is_deltas = is_deltas , reference = ref )
186
-
187
- points = np .array ([[- 1.0 , - 1.0 , - 1.0 ], [0.5 , 0.5 , 0.5 ], [3.0 , 3.0 , 3.0 ]])
188
- mapped = xfm .map (points )
189
-
190
- assert np .allclose (mapped [0 ], points [0 ])
191
- assert np .allclose (mapped [2 ], points [2 ])
192
- assert np .allclose (mapped [1 ], points [1 ] + 1 )
193
-
194
-
195
144
def test_bspline_map_gridpoints ():
196
145
"""BSpline mapping matches dense field on grid points."""
197
146
ref = nb .Nifti1Image (np .zeros ((5 , 5 , 5 ), dtype = "uint8" ), np .eye (4 ))
@@ -243,3 +192,128 @@ def manual_map(x):
243
192
pts = np .array ([[1.2 , 1.5 , 2.0 ], [3.3 , 1.7 , 2.4 ]])
244
193
expected = np .vstack ([manual_map (p ) for p in pts ])
245
194
assert np .allclose (bspline .map (pts ), expected , atol = 1e-6 )
195
+
196
+
197
+ def test_densefield_map_against_ants (testdata_path , tmp_path ):
198
+ """Map points with DenseFieldTransform and compare to ANTs."""
199
+ warpfile = (
200
+ testdata_path
201
+ / "regressions"
202
+ / ("01_ants_t1_to_mniComposite_DisplacementFieldTransform.nii.gz" )
203
+ )
204
+ if not warpfile .exists ():
205
+ pytest .skip ("Composite transform test data not available" )
206
+
207
+ points = np .array (
208
+ [
209
+ [0.0 , 0.0 , 0.0 ],
210
+ [1.0 , 2.0 , 3.0 ],
211
+ [10.0 , - 10.0 , 5.0 ],
212
+ [- 5.0 , 7.0 , - 2.0 ],
213
+ [- 12.0 , 12.0 , 0.0 ],
214
+ ]
215
+ )
216
+ csvin = tmp_path / "points.csv"
217
+ np .savetxt (csvin , points , delimiter = "," , header = "x,y,z" , comments = "" )
218
+
219
+ csvout = tmp_path / "out.csv"
220
+ cmd = f"antsApplyTransformsToPoints -d 3 -i { csvin } -o { csvout } -t { warpfile } "
221
+ exe = cmd .split ()[0 ]
222
+ if not shutil .which (exe ):
223
+ pytest .skip (f"Command { exe } not found on host" )
224
+ check_call (cmd , shell = True )
225
+
226
+ ants_res = np .genfromtxt (csvout , delimiter = "," , names = True )
227
+ ants_pts = np .vstack ([ants_res [n ] for n in ("x" , "y" , "z" )]).T
228
+
229
+ xfm = DenseFieldTransform (ITKDisplacementsField .from_filename (warpfile ))
230
+ mapped = xfm .map (points )
231
+
232
+ assert np .allclose (mapped , ants_pts , atol = 1e-6 )
233
+
234
+
235
+ @pytest .mark .parametrize ("image_orientation" , ["RAS" , "LAS" , "LPS" , "oblique" ])
236
+ @pytest .mark .parametrize ("gridpoints" , [True , False ])
237
+ def test_constant_field_vs_ants (tmp_path , get_testdata , image_orientation , gridpoints ):
238
+ """Create a constant displacement field and compare mappings."""
239
+
240
+ nii = get_testdata [image_orientation ]
241
+
242
+ # Create a reference centered at the origin with various axis orders/flips
243
+ shape = nii .shape
244
+ ref_affine = nii .affine .copy ()
245
+
246
+ field = np .hstack ((
247
+ np .zeros (np .prod (shape )),
248
+ np .linspace (- 80 , 80 , num = np .prod (shape )),
249
+ np .linspace (- 50 , 50 , num = np .prod (shape )),
250
+ )).reshape (shape + (3 , ))
251
+ fieldnii = nb .Nifti1Image (field , ref_affine , None )
252
+
253
+ warpfile = tmp_path / "itk_transform.nii.gz"
254
+ ITKDisplacementsField .to_filename (fieldnii , warpfile )
255
+
256
+ # Ensure direct (xfm) and ITK roundtrip (itk_xfm) are equivalent
257
+ xfm = DenseFieldTransform (fieldnii )
258
+ itk_xfm = DenseFieldTransform (ITKDisplacementsField .from_filename (warpfile ))
259
+
260
+ assert xfm == itk_xfm
261
+ np .testing .assert_allclose (xfm .reference .affine , itk_xfm .reference .affine )
262
+ np .testing .assert_allclose (ref_affine , itk_xfm .reference .affine )
263
+ np .testing .assert_allclose (xfm .reference .shape , itk_xfm .reference .shape )
264
+ np .testing .assert_allclose (xfm ._field , itk_xfm ._field )
265
+
266
+ points = (
267
+ xfm .reference .ndcoords .T if gridpoints
268
+ else np .array (
269
+ [
270
+ [0.0 , 0.0 , 0.0 ],
271
+ [1.0 , 2.0 , 3.0 ],
272
+ [10.0 , - 10.0 , 5.0 ],
273
+ [- 5.0 , 7.0 , - 2.0 ],
274
+ [12.0 , 0.0 , - 11.0 ],
275
+ ]
276
+ )
277
+ )
278
+
279
+ mapped = xfm .map (points )
280
+ nit_deltas = mapped - points
281
+
282
+ if gridpoints :
283
+ np .testing .assert_array_equal (field , nit_deltas .reshape (* shape , - 1 ))
284
+
285
+ csvin = tmp_path / "points.csv"
286
+ np .savetxt (csvin , points , delimiter = "," , header = "x,y,z" , comments = "" )
287
+
288
+ csvout = tmp_path / "out.csv"
289
+ cmd = f"antsApplyTransformsToPoints -d 3 -i { csvin } -o { csvout } -t { warpfile } "
290
+ exe = cmd .split ()[0 ]
291
+ if not shutil .which (exe ):
292
+ pytest .skip (f"Command { exe } not found on host" )
293
+ check_call (cmd , shell = True )
294
+
295
+ ants_res = np .genfromtxt (csvout , delimiter = "," , names = True )
296
+ ants_pts = np .vstack ([ants_res [n ] for n in ("x" , "y" , "z" )]).T
297
+
298
+ # if gridpoints:
299
+ # ants_field = ants_pts.reshape(shape + (3, ))
300
+ # diff = xfm._field[..., 0] - ants_field[..., 0]
301
+ # mask = np.argwhere(np.abs(diff) > 1e-2)[:, 0]
302
+ # assert len(mask) == 0, f"A total of {len(mask)}/{ants_pts.shape[0]} contained errors:\n{diff[mask]}"
303
+
304
+ # diff = xfm._field[..., 1] - ants_field[..., 1]
305
+ # mask = np.argwhere(np.abs(diff) > 1e-2)[:, 0]
306
+ # assert len(mask) == 0, f"A total of {len(mask)}/{ants_pts.shape[0]} contained errors:\n{diff[mask]}"
307
+
308
+ # diff = xfm._field[..., 2] - ants_field[..., 2]
309
+ # mask = np.argwhere(np.abs(diff) > 1e-2)[:, 0]
310
+ # assert len(mask) == 0, f"A total of {len(mask)}/{ants_pts.shape[0]} contained errors:\n{diff[mask]}"
311
+
312
+ ants_deltas = ants_pts - points
313
+ np .testing .assert_array_equal (nit_deltas , ants_deltas )
314
+ np .testing .assert_array_equal (mapped , ants_pts )
315
+
316
+ diff = mapped - ants_pts
317
+ mask = np .argwhere (np .abs (diff ) > 1e-2 )[:, 0 ]
318
+
319
+ assert len (mask ) == 0 , f"A total of { len (mask )} /{ ants_pts .shape [0 ]} contained errors:\n { diff [mask ]} "
0 commit comments