118 changes: 115 additions & 3 deletions pina/model/spline.py
@@ -139,6 +139,9 @@ def __init__(self, order=4, knots=None, control_points=None):
# Precompute boundary interval index
self._boundary_interval_idx = self._compute_boundary_interval()

# Precompute denominators used in derivative formulas
self._compute_derivative_denominators()

def _compute_boundary_interval(self):
"""
Precompute the index of the rightmost non-degenerate interval to improve
@@ -163,15 +166,49 @@ def _compute_boundary_interval(self):
# Otherwise, return the last valid index
return int(valid[-1])

def basis(self, x):
def _compute_derivative_denominators(self):
"""
Precompute the denominators used in the derivative recurrence relations for
all orders up to the spline order, to avoid redundant calculations.
"""
# Precompute for orders 2 to k
for i in range(2, self.order + 1):

# Denominators for the derivative recurrence relations
left_den = self.knots[i - 1 : -1] - self.knots[:-i]
right_den = self.knots[i:] - self.knots[1 : -i + 1]

# If consecutive knots are equal, set left and right factors to zero
left_fac = torch.where(
torch.abs(left_den) > 1e-10,
(i - 1) / left_den,
torch.zeros_like(left_den),
)
right_fac = torch.where(
torch.abs(right_den) > 1e-10,
(i - 1) / right_den,
torch.zeros_like(right_den),
)

# Register buffers
self.register_buffer(f"_left_factor_order_{i}", left_fac)
self.register_buffer(f"_right_factor_order_{i}", right_fac)

def basis(self, x, collection=False):
"""
Compute the basis functions for the spline using an iterative approach.
This is a vectorized implementation based on the Cox-de Boor recursion.

:param torch.Tensor x: The points to be evaluated.
:param bool collection: If True, returns a list of basis functions for
all orders up to the spline order. Default is False.
:raise ValueError: If ``collection`` is not a boolean.
:return: The basis functions evaluated at x, or a list of the basis
functions for all orders up to the spline order if ``collection`` is True.
:rtype: torch.Tensor
:rtype: torch.Tensor | list[torch.Tensor]
"""
# Check consistency
check_consistency(collection, bool)

# Add a final dimension to x
x = x.unsqueeze(-1)

@@ -201,6 +238,10 @@ def basis(self, x):
at_rightmost_boundary,
).to(basis.dtype)

# If returning the whole collection, initialize list
if collection:
basis_collection = [None, basis]

# Iterative case of recursion
for i in range(1, self.order):

@@ -222,8 +263,10 @@

# Combine terms to get the new basis
basis = term1 + term2
if collection:
basis_collection.append(basis)

return basis
return basis_collection if collection else basis
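
For reference, the recursion implemented above is the Cox-de Boor formula (notation illustrative; t_j are the knots and k the order):

B_{j,1}(x) = 1 if t_j <= x < t_{j+1}, else 0,
B_{j,k}(x) = \frac{x - t_j}{t_{j+k-1} - t_j}\, B_{j,k-1}(x) + \frac{t_{j+k} - x}{t_{j+k} - t_{j+1}}\, B_{j+1,k-1}(x),

with 0/0 terms conventionally set to zero, and the rightmost non-degenerate interval treated as closed (cf. the at_rightmost_boundary handling above) so that the right boundary point is included.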

def forward(self, x):
"""
@@ -240,6 +283,72 @@
self.control_points,
)

def derivative(self, x, degree):
"""
Compute the ``degree``-th derivative of the spline at given points.

:param x: The input tensor.
:type x: torch.Tensor | LabelTensor
:param int degree: The derivative degree to compute.
:raise ValueError: If ``degree`` is not a non-negative integer.
:return: The derivative tensor.
:rtype: torch.Tensor
"""
# Check consistency
check_positive_integer(degree, strict=False)

# Compute basis derivative
der = self._basis_derivative(x.as_subclass(torch.Tensor), degree=degree)

return torch.einsum("...bi, i -> ...b", der, self.control_points)
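
A minimal usage sketch of the new method (illustrative only; ``spline`` stands for an already-constructed ``Spline`` whose knots span [0, 1], an assumption rather than part of this diff):

import torch

x = torch.linspace(0.0, 1.0, 50).unsqueeze(-1)  # (50, 1) evaluation points
values = spline(x)                       # spline values,     shape (50, 1)
first = spline.derivative(x, degree=1)   # first derivative,  shape (50, 1)
second = spline.derivative(x, degree=2)  # second derivative, shape (50, 1)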

def _basis_derivative(self, x, degree):
"""
Compute the ``degree``-th derivative of the spline basis functions at
given points using an iterative approach.

:param torch.Tensor x: The points to be evaluated.
:param int degree: The derivative degree to compute.
:return: The ``degree``-th derivatives of the basis functions evaluated at x.
:rtype: torch.Tensor
"""
# Compute the whole basis collection
basis = self.basis(x, collection=True)

# Derivatives initialization (with dummy at index 0 for convenience)
derivatives = [None] + [basis[o] for o in range(1, self.order + 1)]

# Iterate over derivative degrees
for _ in range(1, degree + 1):

# Current degree derivatives (with dummy at index 0 for convenience)
current_der = [None] * (self.order + 1)
current_der[1] = torch.zeros_like(derivatives[1])

# Iterate over basis orders
for o in range(2, self.order + 1):

# Retrieve precomputed factors
left_fac = getattr(self, f"_left_factor_order_{o}")
right_fac = getattr(self, f"_right_factor_order_{o}")

# Slice previous derivatives to align
left_part = derivatives[o - 1][..., :-1]
right_part = derivatives[o - 1][..., 1:]

# Broadcast factors over batch dims
view_shape = (1,) * (left_part.ndim - 1) + (-1,)
left_fac = left_fac.reshape(*view_shape)
right_fac = right_fac.reshape(*view_shape)

# Compute current derivatives
current_der[o] = left_fac * left_part - right_fac * right_part

# Update derivatives for next degree
derivatives = current_der

return derivatives[self.order].squeeze(-1)
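
A consequence worth noting: an order-k B-spline is a piecewise polynomial of degree k - 1, so

B^{(m)}_{j,k}(x) \equiv 0 \quad \text{for } m \ge k,

which is exactly what the loop reproduces by propagating ``current_der[1] = 0`` upward; hence ``derivative(x, degree=m)`` returns zeros whenever ``m >= order``.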

@property
def control_points(self):
"""
@@ -364,3 +473,6 @@ def knots(self, value):
# Recompute boundary interval when knots change
if hasattr(self, "_boundary_interval_idx"):
self._boundary_interval_idx = self._compute_boundary_interval()

# Recompute derivative denominators when knots change
self._compute_derivative_denominators()
68 changes: 67 additions & 1 deletion pina/model/spline_surface.py
@@ -2,7 +2,8 @@

import torch
from .spline import Spline
from ..utils import check_consistency
from ..label_tensor import LabelTensor
from ..utils import check_consistency, check_positive_integer


class SplineSurface(torch.nn.Module):
@@ -122,6 +123,71 @@ def forward(self, x):
self.control_points,
).unsqueeze(-1)

def derivative(self, x, degree_u, degree_v):
"""
Compute the (``degree_u``, ``degree_v``) partial derivative of the spline surface at the given points.

:param x: The input tensor.
:type x: torch.Tensor | LabelTensor
:param int degree_u: The degree of the derivative along the first
parameter direction.
:param int degree_v: The degree of the derivative along the second
parameter direction.
:raise ValueError: If ``degree_u`` is not a non-negative integer.
:raise ValueError: If ``degree_v`` is not a non-negative integer.
:return: The derivative tensor.
:rtype: torch.Tensor
"""
# Check consistency
check_positive_integer(degree_u, strict=False)
check_positive_integer(degree_v, strict=False)

# Split input into u and v components
if isinstance(x, LabelTensor):
u = x[x.labels[0]].as_subclass(torch.Tensor)
v = x[x.labels[1]].as_subclass(torch.Tensor)
else:
u = x[..., 0]
v = x[..., 1]

# Compute basis derivatives
der_u = self.spline_u._basis_derivative(u, degree=degree_u)
der_v = self.spline_v._basis_derivative(v, degree=degree_v)

return torch.einsum(
"...bi, ...bj, ij -> ...b", der_u, der_v, self.control_points
)
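
For reference, the contraction above implements the tensor-product rule for the surface S(u, v) (notation illustrative; N_i and M_j are the basis functions of ``spline_u`` and ``spline_v``, c_{ij} the control points, a = degree_u and b = degree_v):

\frac{\partial^{a+b} S}{\partial u^{a}\,\partial v^{b}}(u, v) = \sum_{i}\sum_{j} N_i^{(a)}(u)\, M_j^{(b)}(v)\, c_{ij}.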

def gradient(self, x):
"""
Convenience method to compute the gradient of the spline surface.

:param x: The input tensor.
:type x: torch.Tensor | LabelTensor
:return: The gradient tensor.
:rtype: torch.Tensor
"""
# Compute partial derivatives
du = self.derivative(x, degree_u=1, degree_v=0)
dv = self.derivative(x, degree_u=0, degree_v=1)

return torch.cat((du, dv), dim=-1)

def laplacian(self, x):
"""
Convenience method to compute the Laplacian of the spline surface.

:param x: The input tensor.
:type x: torch.Tensor | LabelTensor
:return: The laplacian tensor.
:rtype: torch.Tensor
"""
# Compute second partial derivatives
ddu = self.derivative(x, degree_u=2, degree_v=0)
ddv = self.derivative(x, degree_u=0, degree_v=2)

return ddu + ddv
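
A minimal usage sketch for the convenience methods (illustrative only; ``surface`` stands for an already-constructed ``SplineSurface``, and the label names and point range are assumptions rather than part of this diff):

import torch
from pina import LabelTensor

pts = LabelTensor(torch.rand(100, 2), ["u", "v"])  # assumed label names
grad_uv = surface.gradient(pts)   # shape (100, 2): [dS/du, dS/dv]
lap = surface.laplacian(pts)      # shape (100, 1): d2S/du2 + d2S/dv2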

@property
def knots(self):
"""
19 changes: 19 additions & 0 deletions tests/test_model/test_spline.py
@@ -1,6 +1,7 @@
import torch
import pytest
from scipy.interpolate import BSpline
from pina.operator import grad
from pina.model import Spline
from pina import LabelTensor

@@ -173,3 +174,21 @@ def test_backward(args, pts):
loss = torch.mean(output_)
loss.backward()
assert model.control_points.grad.shape == model.control_points.shape


@pytest.mark.parametrize("args", valid_args)
@pytest.mark.parametrize("pts", points)
def test_derivative(args, pts):

# Define and evaluate the model
model = Spline(**args)
pts.requires_grad_(True)
output_ = LabelTensor(model(pts), "u")

# Compute derivatives
first_der = model.derivative(x=pts, degree=1)
first_der_auto = grad(output_, pts).tensor

# Check shape and value
assert first_der.shape == pts.shape
assert torch.allclose(first_der, first_der_auto, atol=1e-4, rtol=1e-4)
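
A further check one could add here, sketched only: comparing against SciPy's reference B-spline, assuming the fixture provides clamped, sorted knots covering the evaluation points and one-dimensional control points.

import numpy as np

t = model.knots.detach().numpy()            # knot vector
c = model.control_points.detach().numpy()   # coefficients
k = model.order - 1                         # SciPy expects the degree, not the order
reference = BSpline(t, c, k).derivative(1)(pts.detach().numpy().ravel())
assert np.allclose(first_der.detach().numpy().ravel(), reference, atol=1e-4)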
42 changes: 42 additions & 0 deletions tests/test_model/test_spline_surface.py
@@ -2,6 +2,7 @@
import random
import pytest
from pina.model import SplineSurface
from pina.operator import grad
from pina import LabelTensor


@@ -178,3 +179,44 @@ def test_backward(knots_u, knots_v, control_points, pts):
loss = torch.mean(output_)
loss.backward()
assert model.control_points.grad.shape == model.control_points.shape


@pytest.mark.parametrize(
"knots_u",
[
torch.rand(n_knots[0]),
{"n": n_knots[0], "min": 0, "max": 1, "mode": "auto"},
{"n": n_knots[0], "min": 0, "max": 1, "mode": "uniform"},
],
)
@pytest.mark.parametrize(
"knots_v",
[
torch.rand(n_knots[1]),
{"n": n_knots[1], "min": 0, "max": 1, "mode": "auto"},
{"n": n_knots[1], "min": 0, "max": 1, "mode": "uniform"},
],
)
@pytest.mark.parametrize(
"control_points", [torch.rand(n_ctrl_pts, n_ctrl_pts), None]
)
@pytest.mark.parametrize("pts", points)
def test_derivative(knots_u, knots_v, control_points, pts):

# Define and evaluate the model
model = SplineSurface(
orders=orders,
knots_u=knots_u,
knots_v=knots_v,
control_points=control_points,
)
pts.requires_grad_(True)
output_ = LabelTensor(model(pts), "u")

# Compute derivatives
gradient = model.gradient(x=pts)
gradient_auto = grad(output_, pts).tensor

# Check shape and value
assert gradient.shape == pts.shape
assert torch.allclose(gradient, gradient_auto, atol=1e-4, rtol=1e-4)