Skip to content

Commit 4ff32ac

Browse files
committed
enh: base implementation of B-Spline transforms
1 parent f703445 commit 4ff32ac

File tree

3 files changed

+223
-71
lines changed

3 files changed

+223
-71
lines changed

nitransforms/interp/__init__.py

Whitespace-only changes.

nitransforms/interp/bspline.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# emacs: -*- mode: python-mode; py-indent-offset: 4; indent-tabs-mode: nil -*-
2+
# vi: set ft=python sts=4 ts=4 sw=4 et:
3+
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
4+
#
5+
# See COPYING file distributed along with the NiBabel package for the
6+
# copyright and license terms.
7+
#
8+
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
9+
"""Interpolate with 3D tensor-product B-Spline basis."""
10+
import numpy as np
11+
import nibabel as nb
12+
from scipy.sparse import csr_matrix, kron
13+
14+
15+
def _cubic_bspline(d, order=3):
16+
"""Evaluate the cubic bspline at distance d from the center."""
17+
if order != 3:
18+
raise NotImplementedError
19+
20+
return np.piecewise(
21+
d,
22+
[d < 1.0, d >= 1.0],
23+
[
24+
lambda d: (4.0 - 6.0 * d ** 2 + 3.0 * d ** 3) / 6.0,
25+
lambda d: (2.0 - d) ** 3 / 6.0,
26+
],
27+
)
28+
29+
30+
def grid_bspline_weights(target_nii, ctrl_nii):
31+
r"""
32+
Evaluate tensor-product B-Spline weights on a grid.
33+
34+
For each of the *N* input samples :math:`(s_1, s_2, s_3)` and *K* control
35+
points or *knots* :math:`\mathbf{k} =(k_1, k_2, k_3)`, the tensor-product
36+
cubic B-Spline kernel weights are calculated:
37+
38+
.. math::
39+
\Psi^3(\mathbf{k}, \mathbf{s}) =
40+
\beta^3(s_1 - k_1) \cdot \beta^3(s_2 - k_2) \cdot \beta^3(s_3 - k_3),
41+
\label{eq:1}\tag{1}
42+
43+
where each :math:`\beta^3` represents the cubic B-Spline for one dimension.
44+
The 1D B-Spline kernel implementation uses :obj:`numpy.piecewise`, and is based on the
45+
closed-form given by Eq. (6) of [Unser1999]_.
46+
47+
By iterating over dimensions, the data samples that fall outside of the compact
48+
support of the tensor-product kernel associated to each control point can be filtered
49+
out and dismissed to lighten computation.
50+
51+
Finally, the resulting weights matrix :math:`\Psi^3(\mathbf{k}, \mathbf{s})`
52+
can be easily identified in Eq. :math:`\eqref{eq:1}` and used as the design matrix
53+
for approximation of data.
54+
55+
Parameters
56+
----------
57+
target_nii : :obj:`nibabel.spatialimages`
58+
An spatial image object (typically, a :obj:`~nibabel.nifti1.Nifti1Image`)
59+
embedding the target EPI image to be corrected.
60+
Provides the location of the *N* samples (total number of voxels) in the space.
61+
ctrl_nii : :obj:`nibabel.spatialimages`
62+
An spatial image object (typically, a :obj:`~nibabel.nifti1.Nifti1Image`)
63+
embedding the location of the control points of the B-Spline grid.
64+
The data array should contain a total of :math:`K` knots (control points).
65+
66+
Returns
67+
-------
68+
weights : :obj:`numpy.ndarray` (:math:`K \times N`)
69+
A sparse matrix of interpolating weights :math:`\Psi^3(\mathbf{k}, \mathbf{s})`
70+
for the *N* voxels of the target EPI, for each of the total *K* knots.
71+
This sparse matrix can be directly used as design matrix for the fitting
72+
step of approximation/extrapolation.
73+
74+
"""
75+
shape = target_nii.shape[:3]
76+
ctrl_sp = ctrl_nii.header.get_zooms()[:3]
77+
ras2ijk = np.linalg.inv(ctrl_nii.affine)
78+
origin = nb.affines.apply_affine(ras2ijk, [tuple(target_nii.affine[:3, 3])])[0]
79+
80+
wd = []
81+
for i, (o, n, sp) in enumerate(
82+
zip(origin, shape, target_nii.header.get_zooms()[:3])
83+
):
84+
locations = np.arange(0, n, dtype="float16") * sp / ctrl_sp[i] + o
85+
knots = np.arange(0, ctrl_nii.shape[i], dtype="float16")
86+
distance = np.abs(locations[np.newaxis, ...] - knots[..., np.newaxis])
87+
88+
within_support = distance < 2.0
89+
d_vals, d_idxs = np.unique(distance[within_support], return_inverse=True)
90+
bs_w = _cubic_bspline(d_vals)
91+
weights = np.zeros_like(distance, dtype="float32")
92+
weights[within_support] = bs_w[d_idxs]
93+
wd.append(csr_matrix(weights))
94+
95+
return kron(kron(wd[0], wd[1]), wd[2])

nitransforms/nonlinear.py

Lines changed: 128 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,21 @@
88
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
99
"""Nonlinear transforms."""
1010
import warnings
11+
from pathlib import Path
1112
import numpy as np
12-
from .base import TransformBase
13-
from . import io
13+
from scipy.sparse import sparse_vstack
14+
from scipy import ndimage as ndi
15+
from nibabel.funcs import four_to_three
16+
from nibabel.loadsave import load as _nbload
1417

15-
# from .base import ImageGrid
16-
# from nibabel.funcs import four_to_three
18+
from . import io
19+
from .interp.bspline import grid_bspline_weights
20+
from .base import (
21+
TransformBase,
22+
ImageGrid,
23+
SpatialReference,
24+
_as_homogeneous,
25+
)
1726

1827

1928
class DisplacementsFieldTransform(TransformBase):
@@ -90,70 +99,118 @@ def from_filename(cls, filename, fmt="X5"):
9099

91100
load = DisplacementsFieldTransform.from_filename
92101

93-
# class BSplineFieldTransform(TransformBase):
94-
# """Represent a nonlinear transform parameterized by BSpline basis."""
95-
96-
# __slots__ = ['_coeffs', '_knots', '_refknots', '_order', '_moving']
97-
# __s = (slice(None), )
98-
99-
# def __init__(self, reference, coefficients, order=3):
100-
# """Create a smooth deformation field using B-Spline basis."""
101-
# super(BSplineFieldTransform, self).__init__()
102-
# self._order = order
103-
# self.reference = reference
104-
105-
# if coefficients.shape[-1] != self.ndim:
106-
# raise ValueError(
107-
# 'Number of components of the coefficients does '
108-
# 'not match the number of dimensions')
109-
110-
# self._coeffs = np.asanyarray(coefficients.dataobj)
111-
# self._knots = ImageGrid(four_to_three(coefficients)[0])
112-
# self._cache_moving()
113-
114-
# def _cache_moving(self):
115-
# self._moving = np.zeros((self.reference.shape) + (3, ),
116-
# dtype='float32')
117-
# ijk = np.moveaxis(self.reference.ndindex, 0, -1).reshape(-1, self.ndim)
118-
# xyz = np.moveaxis(self.reference.ndcoords, 0, -1).reshape(-1, self.ndim)
119-
# print(np.shape(xyz))
120-
121-
# for i in range(np.shape(xyz)[0]):
122-
# print(i, xyz[i, :])
123-
# self._moving[tuple(ijk[i]) + self.__s] = self._interp_transform(xyz[i, :])
124-
125-
# def _interp_transform(self, coords):
126-
# # Calculate position in the grid of control points
127-
# knots_ijk = self._knots.inverse.dot(np.hstack((coords, 1)))[:3]
128-
# neighbors = []
129-
# offset = 0.0 if self._order & 1 else 0.5
130-
# # Calculate neighbors along each dimension
131-
# for dim in range(self.ndim):
132-
# first = int(np.floor(knots_ijk[dim] + offset) - self._order // 2)
133-
# neighbors.append(list(range(first, first + self._order + 1)))
134-
135-
# # Get indexes of the neighborings clique
136-
# ndindex = np.moveaxis(
137-
# np.array(np.meshgrid(*neighbors, indexing='ij')), 0, -1).reshape(
138-
# -1, self.ndim)
139-
140-
# # Calculate the tensor B-spline weights of each neighbor
141-
# # weights = np.prod(vbspl(ndindex - knots_ijk), axis=-1)
142-
# ndindex = [tuple(v) for v in ndindex]
143-
144-
# # Retrieve coefficients and deal with boundary conditions
145-
# zero = np.zeros(self.ndim)
146-
# shape = np.array(self._knots.shape)
147-
# coeffs = []
148-
# for ijk in ndindex:
149-
# offbounds = (zero > ijk) | (shape <= ijk)
150-
# coeffs.append(
151-
# self._coeffs[ijk] if not np.any(offbounds)
152-
# else [0.0] * self.ndim)
153-
154-
# # coords[:3] += weights.dot(np.array(coeffs, dtype=float))
155-
# return self.reference.inverse.dot(np.hstack((coords, 1)))[:3]
156-
157-
# def _map_voxel(self, index, moving=None):
158-
# """Apply ijk' = f_ijk((i, j, k)), equivalent to the above with indexes."""
159-
# return tuple(self._moving[index + self.__s])
102+
103+
class BSplineFieldTransform(TransformBase):
104+
"""Represent a nonlinear transform parameterized by BSpline basis."""
105+
106+
__slots__ = ['_coeffs', '_knots', '_weights', '_order', '_moving']
107+
__s = (slice(None), )
108+
109+
def __init__(self, reference, coefficients, order=3):
110+
"""Create a smooth deformation field using B-Spline basis."""
111+
super(BSplineFieldTransform, self).__init__()
112+
self._order = order
113+
self.reference = reference
114+
115+
if coefficients.shape[-1] != self.ndim:
116+
raise ValueError(
117+
'Number of components of the coefficients does '
118+
'not match the number of dimensions')
119+
120+
self._coeffs = np.asanyarray(coefficients.dataobj)
121+
self._knots = ImageGrid(four_to_three(coefficients)[0])
122+
self._weights = None
123+
124+
def apply(
125+
self,
126+
spatialimage,
127+
reference=None,
128+
order=3,
129+
mode="constant",
130+
cval=0.0,
131+
prefilter=True,
132+
output_dtype=None,
133+
):
134+
"""Apply a B-Spline transform on input data."""
135+
136+
if reference is not None and isinstance(reference, (str, Path)):
137+
reference = _nbload(str(reference))
138+
139+
_ref = (
140+
self.reference if reference is None else SpatialReference.factory(reference)
141+
)
142+
143+
if isinstance(spatialimage, (str, Path)):
144+
spatialimage = _nbload(str(spatialimage))
145+
146+
if not isinstance(_ref, ImageGrid):
147+
return super().apply(
148+
spatialimage,
149+
reference=reference,
150+
order=order,
151+
mode=mode,
152+
cval=cval,
153+
prefilter=prefilter,
154+
output_dtype=output_dtype,
155+
)
156+
157+
# If locations to be interpolated are on a grid, use faster tensor-bspline calculation
158+
if self._weights is None:
159+
self._weights = grid_bspline_weights(_ref, self._knots)
160+
161+
ycoords = _ref.ndcoords.T + (
162+
np.squeeze(np.hstack(self._coeffs).T) @ sparse_vstack(self._weights)
163+
)
164+
165+
data = np.squeeze(np.asanyarray(spatialimage.dataobj))
166+
output_dtype = output_dtype or data.dtype
167+
targets = ImageGrid(spatialimage).index( # data should be an image
168+
_as_homogeneous(np.vstack(ycoords), dim=_ref.ndim)
169+
)
170+
171+
if data.ndim == 4:
172+
if len(self) != data.shape[-1]:
173+
raise ValueError(
174+
"Attempting to apply %d transforms on a file with "
175+
"%d timepoints" % (len(self), data.shape[-1])
176+
)
177+
targets = targets.reshape((len(self), -1, targets.shape[-1]))
178+
resampled = np.stack(
179+
[
180+
ndi.map_coordinates(
181+
data[..., t],
182+
targets[t, ..., : _ref.ndim].T,
183+
output=output_dtype,
184+
order=order,
185+
mode=mode,
186+
cval=cval,
187+
prefilter=prefilter,
188+
)
189+
for t in range(data.shape[-1])
190+
],
191+
axis=0,
192+
)
193+
elif data.ndim in (2, 3):
194+
resampled = ndi.map_coordinates(
195+
data,
196+
targets[..., : _ref.ndim].T,
197+
output=output_dtype,
198+
order=order,
199+
mode=mode,
200+
cval=cval,
201+
prefilter=prefilter,
202+
)
203+
204+
newdata = resampled.reshape((len(self), *_ref.shape))
205+
moved = spatialimage.__class__(
206+
np.moveaxis(newdata, 0, -1), _ref.affine, spatialimage.header
207+
)
208+
moved.header.set_data_dtype(output_dtype)
209+
return moved
210+
211+
def map(self, x, inverse=False):
212+
raise NotImplementedError
213+
214+
def _map_voxel(self, index, moving=None):
215+
"""Apply ijk' = f_ijk((i, j, k)), equivalent to the above with indexes."""
216+
raise NotImplementedError

0 commit comments

Comments
 (0)