Skip to content

Commit 95a215e

Browse files
Julien Marabottooesteban
authored andcommitted
Updated: offsource apply
1 parent c5b86e1 commit 95a215e

File tree

4 files changed

+21
-20
lines changed

4 files changed

+21
-20
lines changed

nitransforms/nonlinear.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from nitransforms import io
1515
from nitransforms.io.base import _ensure_image
1616
from nitransforms.interp.bspline import grid_bspline_weights, _cubic_bspline
17+
from nitransforms.resampling import apply
1718
from nitransforms.base import (
1819
TransformBase,
1920
TransformError,
@@ -257,7 +258,7 @@ def __init__(self, coefficients, reference=None, order=3):
257258
if reference is not None:
258259
self.reference = reference
259260

260-
if coefficients.shape[-1] != self.ndim:
261+
if coefficients.shape[-1] != self.reference.ndim:
261262
raise TransformError(
262263
'Number of components of the coefficients does '
263264
'not match the number of dimensions')
@@ -310,19 +311,23 @@ def apply(
310311
spatialimage = _ensure_image(spatialimage)
311312

312313
# If locations to be interpolated are not on a grid, run map()
314+
#import pdb; pdb.set_trace()
313315
if not isinstance(_ref, ImageGrid):
314-
return super().apply(
316+
return apply(
317+
super(),
315318
spatialimage,
316319
reference=_ref,
320+
output_dtype=output_dtype,
317321
order=order,
318322
mode=mode,
319323
cval=cval,
320324
prefilter=prefilter,
321-
output_dtype=output_dtype,
325+
322326
)
323327

324328
# If locations to be interpolated are on a grid, generate a displacements field
325-
return self.to_field(reference=reference).apply(
329+
return apply(
330+
self.to_field(reference=reference),
326331
spatialimage,
327332
reference=reference,
328333
order=order,

nitransforms/resampling.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,10 +98,12 @@ def apply(
9898
if data.ndim < transform.ndim:
9999
data = data[..., np.newaxis]
100100

101-
import pdb; pdb.set_trace()
101+
if transform.ndim == 4:
102+
targets = _as_homogeneous(targets.reshape(-2, targets.shape[0])).T
103+
102104
resampled = ndi.map_coordinates(
103105
data,
104-
_as_homogeneous(targets.reshape(-2, targets.shape[0])).T,
106+
targets,
105107
output=output_dtype,
106108
order=order,
107109
mode=mode,

nitransforms/tests/test_base.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -94,10 +94,14 @@ def _to_hdf5(klass, x5_root):
9494
img = nb.load(fname)
9595
imgdata = np.asanyarray(img.dataobj, dtype=img.get_data_dtype())
9696

97+
# Test identity transform - setting reference
9798
xfm = TransformBase()
9899
with pytest.raises(TypeError):
99100
_ = xfm.ndim
100101

102+
# Test to_filename
103+
xfm.to_filename("data.x5")
104+
101105
# Test identity transform
102106
xfm = nitl.Affine()
103107
xfm.reference = fname
@@ -106,17 +110,6 @@ def _to_hdf5(klass, x5_root):
106110
imgdata == np.asanyarray(moved.dataobj, dtype=moved.get_data_dtype())
107111
)
108112

109-
# Test identity transform - setting reference
110-
xfm = TransformBase()
111-
xfm.reference = fname
112-
113-
with pytest.raises(TypeError):
114-
_ = xfm.ndim
115-
moved = apply(xfm, str(fname), reference=fname, order=0)
116-
assert np.all(
117-
imgdata == np.asanyarray(moved.dataobj, dtype=moved.get_data_dtype())
118-
)
119-
120113
# Test ndim returned by affine
121114
assert nitl.Affine().ndim == 3
122115
assert nitl.LinearTransformsMapping(
@@ -136,7 +129,7 @@ def _to_hdf5(klass, x5_root):
136129
assert np.allclose(giimoved.reshape(xfm.reference.shape), moved.get_fdata())
137130

138131
# Test to_filename
139-
xfm.to_filename("data.x5")
132+
xfm.to_filename("data.xfm", fmt='itk')
140133

141134

142135
def test_SampledSpatialData(testdata_path):

nitransforms/tests/test_nonlinear.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import numpy as np
1010
import nibabel as nb
11+
from nitransforms.resampling import apply
1112
from nitransforms.base import TransformError
1213
from nitransforms.io.base import TransformFileError
1314
from nitransforms.nonlinear import (
@@ -247,8 +248,8 @@ def test_bspline(tmp_path, testdata_path):
247248
bsplxfm = BSplineFieldTransform(bs_name, reference=img_name)
248249
dispxfm = DenseFieldTransform(disp_name)
249250

250-
out_disp = dispxfm.apply(img_name)
251-
out_bspl = bsplxfm.apply(img_name)
251+
out_disp = apply(dispxfm,img_name)
252+
out_bspl = apply(bsplxfm,img_name)
252253

253254
out_disp.to_filename("resampled_field.nii.gz")
254255
out_bspl.to_filename("resampled_bsplines.nii.gz")

0 commit comments

Comments
 (0)