Skip to content

Commit f59720d

Browse files
Julien Marabottooesteban
authored andcommitted
FIX: Offsource Apply
Apply function offsourced. Tests: 139 passed, 163 Skipped, 15 Warnings
1 parent 9f93a67 commit f59720d

File tree

3 files changed

+31
-62
lines changed

3 files changed

+31
-62
lines changed

nitransforms/nonlinear.py

Lines changed: 15 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,11 @@ class DenseFieldTransform(TransformBase):
3030

3131
__slots__ = ("_field", "_deltas")
3232

33+
@property
34+
def ndim(self):
35+
"""Access the dimensions of this Desne Field Transform."""
36+
return self._field.ndim - 1
37+
3338
def __init__(self, field=None, is_deltas=True, reference=None):
3439
"""
3540
Create a dense field transform.
@@ -82,11 +87,10 @@ def __init__(self, field=None, is_deltas=True, reference=None):
8287
"Reference is not a spatial image"
8388
)
8489

85-
ndim = self._field.ndim - 1
86-
if self._field.shape[-1] != ndim:
90+
if self._field.shape[-1] != self.ndim:
8791
raise TransformError(
8892
"The number of components of the field (%d) does not match "
89-
"the number of dimensions (%d)" % (self._field.shape[-1], ndim)
93+
"the number of dimensions (%d)" % (self._field.shape[-1], self.ndim)
9094
)
9195

9296
if is_deltas:
@@ -245,6 +249,12 @@ class BSplineFieldTransform(TransformBase):
245249

246250
__slots__ = ['_coeffs', '_knots', '_weights', '_order', '_moving']
247251

252+
@property
253+
def ndim(self):
254+
"""Access the dimensions of this BSpline."""
255+
#return ndim = self._coeffs.shape[-1]
256+
return self._coeffs.ndim - 1
257+
248258
def __init__(self, coefficients, reference=None, order=3):
249259
"""Create a smooth deformation field using B-Spline basis."""
250260
super().__init__()
@@ -277,66 +287,19 @@ def to_field(self, reference=None, dtype="float32"):
277287
if _ref is None:
278288
raise TransformError("A reference must be defined")
279289

280-
ndim = self._coeffs.shape[-1]
281-
282290
if self._weights is None:
283291
self._weights = grid_bspline_weights(_ref, self._knots)
284292

285-
field = np.zeros((_ref.npoints, ndim))
293+
field = np.zeros((_ref.npoints, self.ndim))
286294

287-
for d in range(ndim):
295+
for d in range(self.ndim):
288296
# 1 x Nvox : (1 x K) @ (K x Nvox)
289297
field[:, d] = self._coeffs[..., d].reshape(-1) @ self._weights
290298

291299
return DenseFieldTransform(
292300
field.astype(dtype).reshape(*_ref.shape, -1), reference=_ref
293301
)
294302

295-
def apply(
296-
self,
297-
spatialimage,
298-
reference=None,
299-
order=3,
300-
mode="constant",
301-
cval=0.0,
302-
prefilter=True,
303-
output_dtype=None,
304-
):
305-
"""Apply a B-Spline transform on input data."""
306-
307-
_ref = (
308-
self.reference if reference is None else
309-
SpatialReference.factory(_ensure_image(reference))
310-
)
311-
spatialimage = _ensure_image(spatialimage)
312-
313-
# If locations to be interpolated are not on a grid, run map()
314-
#import pdb; pdb.set_trace()
315-
if not isinstance(_ref, ImageGrid):
316-
return apply(
317-
super(),
318-
spatialimage,
319-
reference=_ref,
320-
output_dtype=output_dtype,
321-
order=order,
322-
mode=mode,
323-
cval=cval,
324-
prefilter=prefilter,
325-
326-
)
327-
328-
# If locations to be interpolated are on a grid, generate a displacements field
329-
return apply(
330-
self.to_field(reference=reference),
331-
spatialimage,
332-
reference=reference,
333-
order=order,
334-
mode=mode,
335-
cval=cval,
336-
prefilter=prefilter,
337-
output_dtype=output_dtype,
338-
)
339-
340303
def map(self, x, inverse=False):
341304
r"""
342305
Apply the transformation to a list of physical coordinate points.

nitransforms/resampling.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,16 +87,21 @@ def apply(
8787
spatialimage = _nbload(str(spatialimage))
8888

8989
data = np.asanyarray(spatialimage.dataobj)
90-
91-
targets = ImageGrid(spatialimage).index( # data should be an image
92-
_as_homogeneous(transform.map(_ref.ndcoords.T), dim=_ref.ndim)
93-
)
9490

9591
if data.ndim == 4 and data.shape[-1] != len(transform):
9692
raise ValueError("The fourth dimension of the data does not match the tranform's shape.")
9793

9894
if data.ndim < transform.ndim:
9995
data = data[..., np.newaxis]
96+
97+
if hasattr(transform, 'to_field') and callable(transform.to_field):
98+
targets = ImageGrid(spatialimage).index(
99+
_as_homogeneous(transform.to_field(reference=reference).map(_ref.ndcoords.T), dim=_ref.ndim)
100+
)
101+
else:
102+
targets = ImageGrid(spatialimage).index( # data should be an image
103+
_as_homogeneous(transform.map(_ref.ndcoords.T), dim=_ref.ndim)
104+
)
100105

101106
if transform.ndim == 4:
102107
targets = _as_homogeneous(targets.reshape(-2, targets.shape[0])).T

nitransforms/tests/test_nonlinear.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -97,13 +97,14 @@ def test_bsplines_references(testdata_path):
9797
).to_field()
9898

9999
with pytest.raises(TransformError):
100-
BSplineFieldTransform(
101-
testdata_path / "someones_bspline_coefficients.nii.gz"
102-
).apply(testdata_path / "someones_anatomy.nii.gz")
100+
apply(
101+
BSplineFieldTransform(testdata_path / "someones_bspline_coefficients.nii.gz"),
102+
testdata_path / "someones_anatomy.nii.gz",
103+
)
103104

104-
BSplineFieldTransform(
105-
testdata_path / "someones_bspline_coefficients.nii.gz"
106-
).apply(
105+
apply(
106+
BSplineFieldTransform(
107+
testdata_path / "someones_bspline_coefficients.nii.gz"),
107108
testdata_path / "someones_anatomy.nii.gz",
108109
reference=testdata_path / "someones_anatomy.nii.gz"
109110
)

0 commit comments

Comments
 (0)