Skip to content

Commit 23d35f9

Browse files
add analytical derivatives for splines
1 parent 6bae7ca commit 23d35f9

File tree

4 files changed

+243
-4
lines changed

4 files changed

+243
-4
lines changed

pina/model/spline.py

Lines changed: 115 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,9 @@ def __init__(self, order=4, knots=None, control_points=None):
139139
# Precompute boundary interval index
140140
self._boundary_interval_idx = self._compute_boundary_interval()
141141

142+
# Precompute denominators used in derivative formulas
143+
self._compute_derivative_denominators()
144+
142145
def _compute_boundary_interval(self):
143146
"""
144147
Precompute the index of the rightmost non-degenerate interval to improve
@@ -163,15 +166,49 @@ def _compute_boundary_interval(self):
163166
# Otherwise, return the last valid index
164167
return int(valid[-1])
165168

166-
def basis(self, x):
169+
def _compute_derivative_denominators(self):
170+
"""
171+
Precompute the denominators used in the derivatives for all orders up to
172+
the spline order to avoid redundant calculations.
173+
"""
174+
# Precompute for orders 2 to k
175+
for i in range(2, self.order + 1):
176+
177+
# Denominators for the derivative recurrence relations
178+
left_den = self.knots[i - 1 : -1] - self.knots[:-i]
179+
right_den = self.knots[i:] - self.knots[1 : -i + 1]
180+
181+
# If consecutive knots are equal, set left and right factors to zero
182+
left_fac = torch.where(
183+
torch.abs(left_den) > 1e-10,
184+
(i - 1) / left_den,
185+
torch.zeros_like(left_den),
186+
)
187+
right_fac = torch.where(
188+
torch.abs(right_den) > 1e-10,
189+
(i - 1) / right_den,
190+
torch.zeros_like(right_den),
191+
)
192+
193+
# Register buffers
194+
self.register_buffer(f"_left_factor_order_{i}", left_fac)
195+
self.register_buffer(f"_right_factor_order_{i}", right_fac)
196+
197+
def basis(self, x, collection=False):
167198
"""
168199
Compute the basis functions for the spline using an iterative approach.
169200
This is a vectorized implementation based on the Cox-de Boor recursion.
170201
171202
:param torch.Tensor x: The points to be evaluated.
203+
:param bool collection: If True, returns a list of basis functions for
204+
all orders up to the spline order. Default is False.
205+
:raise ValueError: If ``collection`` is not a boolean.
172206
:return: The basis functions evaluated at x.
173-
:rtype: torch.Tensor
207+
:rtype: torch.Tensor | list[torch.Tensor]
174208
"""
209+
# Check consistency
210+
check_consistency(collection, bool)
211+
175212
# Add a final dimension to x
176213
x = x.unsqueeze(-1)
177214

@@ -201,6 +238,10 @@ def basis(self, x):
201238
at_rightmost_boundary,
202239
).to(basis.dtype)
203240

241+
# If returning the whole collection, initialize list
242+
if collection:
243+
basis_collection = [None, basis]
244+
204245
# Iterative case of recursion
205246
for i in range(1, self.order):
206247

@@ -222,8 +263,10 @@ def basis(self, x):
222263

223264
# Combine terms to get the new basis
224265
basis = term1 + term2
266+
if collection:
267+
basis_collection.append(basis)
225268

226-
return basis
269+
return basis_collection if collection else basis
227270

228271
def forward(self, x):
229272
"""
@@ -240,6 +283,72 @@ def forward(self, x):
240283
self.control_points,
241284
)
242285

286+
def derivative(self, x, degree):
287+
"""
288+
Compute the ``degree``-th derivative of the spline at given points.
289+
290+
:param x: The input tensor.
291+
:type x: torch.Tensor | LabelTensor
292+
:param int degree: The derivative degree to compute.
293+
:raise ValueError: If ``degree`` is not an integer.
294+
:return: The derivative tensor.
295+
:rtype: torch.Tensor
296+
"""
297+
# Check consistency
298+
check_positive_integer(degree, strict=False)
299+
300+
# Compute basis derivative
301+
der = self._basis_derivative(x.as_subclass(torch.Tensor), degree=degree)
302+
303+
return torch.einsum("...bi, i -> ...b", der, self.control_points)
304+
305+
def _basis_derivative(self, x, degree):
306+
"""
307+
Compute the ``degree``-th derivative of the spline basis functions at
308+
given points using an iterative approach.
309+
310+
:param torch.Tensor x: The points to be evaluated.
311+
:param int degree: The derivative degree to compute.
312+
:return: The basis functions evaluated at x.
313+
:rtype: torch.Tensor
314+
"""
315+
# Compute the whole basis collection
316+
basis = self.basis(x, collection=True)
317+
318+
# Derivatives initialization (with dummy at index 0 for convenience)
319+
derivatives = [None] + [basis[o] for o in range(1, self.order + 1)]
320+
321+
# Iterate over derivative degrees
322+
for _ in range(1, degree + 1):
323+
324+
# Current degree derivatives (with dummy at index 0 for convenience)
325+
current_der = [None] * (self.order + 1)
326+
current_der[1] = torch.zeros_like(derivatives[1])
327+
328+
# Iterate over basis orders
329+
for o in range(2, self.order + 1):
330+
331+
# Retrieve precomputed factors
332+
left_fac = getattr(self, f"_left_factor_order_{o}")
333+
right_fac = getattr(self, f"_right_factor_order_{o}")
334+
335+
# Slice previous derivatives to align
336+
left_part = derivatives[o - 1][..., :-1]
337+
right_part = derivatives[o - 1][..., 1:]
338+
339+
# Broadcast factors over batch dims
340+
view_shape = (1,) * (left_part.ndim - 1) + (-1,)
341+
left_fac = left_fac.reshape(*view_shape)
342+
right_fac = right_fac.reshape(*view_shape)
343+
344+
# Compute current derivatives
345+
current_der[o] = left_fac * left_part - right_fac * right_part
346+
347+
# Update derivatives for next degree
348+
derivatives = current_der
349+
350+
return derivatives[self.order].squeeze(-1)
351+
243352
@property
244353
def control_points(self):
245354
"""
@@ -364,3 +473,6 @@ def knots(self, value):
364473
# Recompute boundary interval when knots change
365474
if hasattr(self, "_boundary_interval_idx"):
366475
self._boundary_interval_idx = self._compute_boundary_interval()
476+
477+
# Recompute derivative denominators when knots change
478+
self._compute_derivative_denominators()

pina/model/spline_surface.py

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
import torch
44
from .spline import Spline
5-
from ..utils import check_consistency
5+
from ..label_tensor import LabelTensor
6+
from ..utils import check_consistency, check_positive_integer
67

78

89
class SplineSurface(torch.nn.Module):
@@ -122,6 +123,71 @@ def forward(self, x):
122123
self.control_points,
123124
).unsqueeze(-1)
124125

126+
def derivative(self, x, degree_u, degree_v):
127+
"""
128+
Compute the partial derivatives of the spline at the given points.
129+
130+
:param x: The input tensor.
131+
:type x: torch.Tensor | LabelTensor
132+
:param int degree_u: The degree of the derivative along the first
133+
parameter direction.
134+
:param int degree_v: The degree of the derivative along the second
135+
parameter direction.
136+
:raise ValueError: If ``degree_u`` is not an integer.
137+
:raise ValueError: If ``degree_v`` is not an integer.
138+
:return: The derivative tensor.
139+
:rtype: torch.Tensor
140+
"""
141+
# Check consistency
142+
check_positive_integer(degree_u, strict=False)
143+
check_positive_integer(degree_v, strict=False)
144+
145+
# Split input into u and v components
146+
if isinstance(x, LabelTensor):
147+
u = x[x.labels[0]].as_subclass(torch.Tensor)
148+
v = x[x.labels[1]].as_subclass(torch.Tensor)
149+
else:
150+
u = x[..., 0]
151+
v = x[..., 1]
152+
153+
# Compute basis derivatives
154+
der_u = self.spline_u._basis_derivative(u, degree=degree_u)
155+
der_v = self.spline_v._basis_derivative(v, degree=degree_v)
156+
157+
return torch.einsum(
158+
"...bi, ...bj, ij -> ...b", der_u, der_v, self.control_points
159+
)
160+
161+
def gradient(self, x):
162+
"""
163+
Convenience method to compute the gradient of the spline surface.
164+
165+
:param x: The input tensor.
166+
:type x: torch.Tensor | LabelTensor
167+
:return: The gradient tensor.
168+
:rtype: torch.Tensor
169+
"""
170+
# Compute partial derivatives
171+
du = self.derivative(x, degree_u=1, degree_v=0)
172+
dv = self.derivative(x, degree_u=0, degree_v=1)
173+
174+
return torch.cat((du, dv), dim=-1)
175+
176+
def laplacian(self, x):
177+
"""
178+
Convenience method to compute the laplacian of the spline surface.
179+
180+
:param x: The input tensor.
181+
:type x: torch.Tensor | LabelTensor
182+
:return: The laplacian tensor.
183+
:rtype: torch.Tensor
184+
"""
185+
# Compute second partial derivatives
186+
ddu = self.derivative(x, degree_u=2, degree_v=0)
187+
ddv = self.derivative(x, degree_u=0, degree_v=2)
188+
189+
return ddu + ddv
190+
125191
@property
126192
def knots(self):
127193
"""

tests/test_model/test_spline.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import torch
22
import pytest
33
from scipy.interpolate import BSpline
4+
from pina.operator import grad
45
from pina.model import Spline
56
from pina import LabelTensor
67

@@ -173,3 +174,21 @@ def test_backward(args, pts):
173174
loss = torch.mean(output_)
174175
loss.backward()
175176
assert model.control_points.grad.shape == model.control_points.shape
177+
178+
179+
@pytest.mark.parametrize("args", valid_args)
180+
@pytest.mark.parametrize("pts", points)
181+
def test_derivative(args, pts):
182+
183+
# Define and evaluate the model
184+
model = Spline(**args)
185+
pts.requires_grad_(True)
186+
output_ = LabelTensor(model(pts), "u")
187+
188+
# Compute derivatives
189+
first_der = model.derivative(x=pts, degree=1)
190+
first_der_auto = grad(output_, pts).tensor
191+
192+
# Check shape and value
193+
assert first_der.shape == pts.shape
194+
assert torch.allclose(first_der, first_der_auto, atol=1e-4, rtol=1e-4)

tests/test_model/test_spline_surface.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import random
33
import pytest
44
from pina.model import SplineSurface
5+
from pina.operator import grad
56
from pina import LabelTensor
67

78

@@ -178,3 +179,44 @@ def test_backward(knots_u, knots_v, control_points, pts):
178179
loss = torch.mean(output_)
179180
loss.backward()
180181
assert model.control_points.grad.shape == model.control_points.shape
182+
183+
184+
@pytest.mark.parametrize(
185+
"knots_u",
186+
[
187+
torch.rand(n_knots[0]),
188+
{"n": n_knots[0], "min": 0, "max": 1, "mode": "auto"},
189+
{"n": n_knots[0], "min": 0, "max": 1, "mode": "uniform"},
190+
],
191+
)
192+
@pytest.mark.parametrize(
193+
"knots_v",
194+
[
195+
torch.rand(n_knots[1]),
196+
{"n": n_knots[1], "min": 0, "max": 1, "mode": "auto"},
197+
{"n": n_knots[1], "min": 0, "max": 1, "mode": "uniform"},
198+
],
199+
)
200+
@pytest.mark.parametrize(
201+
"control_points", [torch.rand(n_ctrl_pts, n_ctrl_pts), None]
202+
)
203+
@pytest.mark.parametrize("pts", points)
204+
def test_derivative(knots_u, knots_v, control_points, pts):
205+
206+
# Define and evaluate the model
207+
model = SplineSurface(
208+
orders=orders,
209+
knots_u=knots_u,
210+
knots_v=knots_v,
211+
control_points=control_points,
212+
)
213+
pts.requires_grad_(True)
214+
output_ = LabelTensor(model(pts), "u")
215+
216+
# Compute derivatives
217+
gradient = model.gradient(x=pts)
218+
gradient_auto = grad(output_, pts).tensor
219+
220+
# Check shape and value
221+
assert gradient.shape == pts.shape
222+
assert torch.allclose(gradient, gradient_auto, atol=1e-4, rtol=1e-4)

0 commit comments

Comments
 (0)