Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions nitransforms/nonlinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@
class DenseFieldTransform(TransformBase):
"""Represents dense field (voxel-wise) transforms."""

__slots__ = ("_field", "_deltas", "_is_deltas")
__slots__ = ("_field", "_deltas", "_is_deltas", "_filtered_field")

def __init__(self, field=None, is_deltas=True, reference=None):
def __init__(self, field=None, is_deltas=True, reference=None, do_prefilter=True):
"""
Create a dense field transform.

Expand Down Expand Up @@ -107,6 +107,17 @@ def __init__(self, field=None, is_deltas=True, reference=None):
else:
self._field = _data.copy()

self._filtered_field = None
if do_prefilter:
# pre-cache filtered field to accelerate later mapping
from scipy.ndimage import spline_filter
for i in range(self.reference.ndim):
filtered_field_i = spline_filter(self._field[..., i], order=3, output=np.float64, mode='constant')
if self._filtered_field is None:
self._filtered_field = np.repeat(filtered_field_i[..., np.newaxis], self.reference.ndim, axis=-1)
else:
self._filtered_field[..., i] = filtered_field_i

def __repr__(self):
"""Beautify the python representation."""
return f"<{self.__class__.__name__}[{self._field.shape[-1]}D] {self._field.shape[:3]}>"
Expand Down Expand Up @@ -193,12 +204,12 @@ def map(self, x, inverse=False):
mapped_coords = np.vstack(
tuple(
map_coordinates(
self._field[..., i],
self._field[..., i] if self._filtered_field is None else self._filtered_field[..., i],
ijk.T,
order=3,
mode="constant",
cval=np.nan,
prefilter=True,
prefilter=self._filtered_field is None,
)
for i in range(self.reference.ndim)
)
Expand Down