Skip to content

Commit 61c6986

Browse files
committed
OPT: Pre-filter dense field transform to speed up repeated map calls
1 parent 7845e64 commit 61c6986

File tree

1 file changed

+15
-4
lines changed

1 file changed

+15
-4
lines changed

nitransforms/nonlinear.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,9 @@
3636
class DenseFieldTransform(TransformBase):
3737
"""Represents dense field (voxel-wise) transforms."""
3838

39-
__slots__ = ("_field", "_deltas", "_is_deltas")
39+
__slots__ = ("_field", "_deltas", "_is_deltas", "_filtered_field")
4040

41-
def __init__(self, field=None, is_deltas=True, reference=None):
41+
def __init__(self, field=None, is_deltas=True, reference=None, do_prefilter=True):
4242
"""
4343
Create a dense field transform.
4444
@@ -107,6 +107,17 @@ def __init__(self, field=None, is_deltas=True, reference=None):
107107
else:
108108
self._field = _data.copy()
109109

110+
self._filtered_field = None
111+
if do_prefilter:
112+
# pre-cache filtered field to accelerate later mapping
113+
from scipy.ndimage import spline_filter
114+
for i in range(self.reference.ndim):
115+
filtered_field_i = spline_filter(self._field[..., i], order=3, output=np.float64, mode='constant')
116+
if self._filtered_field is None:
117+
self._filtered_field = np.repeat(filtered_field_i[..., np.newaxis], self.reference.ndim, axis=-1)
118+
else:
119+
self._filtered_field[..., i] = filtered_field_i
120+
110121
def __repr__(self):
111122
"""Beautify the python representation."""
112123
return f"<{self.__class__.__name__}[{self._field.shape[-1]}D] {self._field.shape[:3]}>"
@@ -193,12 +204,12 @@ def map(self, x, inverse=False):
193204
mapped_coords = np.vstack(
194205
tuple(
195206
map_coordinates(
196-
self._field[..., i],
207+
self._field[..., i] if self._filtered_field is None else self._filtered_field[..., i],
197208
ijk.T,
198209
order=3,
199210
mode="constant",
200211
cval=np.nan,
201-
prefilter=True,
212+
prefilter=self._filtered_field is None,
202213
)
203214
for i in range(self.reference.ndim)
204215
)

0 commit comments

Comments
 (0)