diff --git a/pina/model/spline.py b/pina/model/spline.py index a276a6cfd..77bf18759 100644 --- a/pina/model/spline.py +++ b/pina/model/spline.py @@ -139,6 +139,9 @@ def __init__(self, order=4, knots=None, control_points=None): # Precompute boundary interval index self._boundary_interval_idx = self._compute_boundary_interval() + # Precompute denominators used in derivative formulas + self._compute_derivative_denominators() + def _compute_boundary_interval(self): """ Precompute the index of the rightmost non-degenerate interval to improve @@ -163,15 +166,49 @@ def _compute_boundary_interval(self): # Otherwise, return the last valid index return int(valid[-1]) - def basis(self, x): + def _compute_derivative_denominators(self): + """ + Precompute the denominators used in the derivatives for all orders up to + the spline order to avoid redundant calculations. + """ + # Precompute for orders 2 to k + for i in range(2, self.order + 1): + + # Denominators for the derivative recurrence relations + left_den = self.knots[i - 1 : -1] - self.knots[:-i] + right_den = self.knots[i:] - self.knots[1 : -i + 1] + + # If consecutive knots are equal, set left and right factors to zero + left_fac = torch.where( + torch.abs(left_den) > 1e-10, + (i - 1) / left_den, + torch.zeros_like(left_den), + ) + right_fac = torch.where( + torch.abs(right_den) > 1e-10, + (i - 1) / right_den, + torch.zeros_like(right_den), + ) + + # Register buffers + self.register_buffer(f"_left_factor_order_{i}", left_fac) + self.register_buffer(f"_right_factor_order_{i}", right_fac) + + def basis(self, x, collection=False): """ Compute the basis functions for the spline using an iterative approach. This is a vectorized implementation based on the Cox-de Boor recursion. :param torch.Tensor x: The points to be evaluated. + :param bool collection: If True, returns a list of basis functions for + all orders up to the spline order. Default is False. + :raise ValueError: If ``collection`` is not a boolean. :return: The basis functions evaluated at x. - :rtype: torch.Tensor + :rtype: torch.Tensor | list[torch.Tensor] """ + # Check consistency + check_consistency(collection, bool) + # Add a final dimension to x x = x.unsqueeze(-1) @@ -201,6 +238,10 @@ def basis(self, x): at_rightmost_boundary, ).to(basis.dtype) + # If returning the whole collection, initialize list + if collection: + basis_collection = [None, basis] + # Iterative case of recursion for i in range(1, self.order): @@ -222,8 +263,10 @@ def basis(self, x): # Combine terms to get the new basis basis = term1 + term2 + if collection: + basis_collection.append(basis) - return basis + return basis_collection if collection else basis def forward(self, x): """ @@ -240,6 +283,72 @@ def forward(self, x): self.control_points, ) + def derivative(self, x, degree): + """ + Compute the ``degree``-th derivative of the spline at given points. + + :param x: The input tensor. + :type x: torch.Tensor | LabelTensor + :param int degree: The derivative degree to compute. + :raise ValueError: If ``degree`` is not an integer. + :return: The derivative tensor. + :rtype: torch.Tensor + """ + # Check consistency + check_positive_integer(degree, strict=False) + + # Compute basis derivative + der = self._basis_derivative(x.as_subclass(torch.Tensor), degree=degree) + + return torch.einsum("...bi, i -> ...b", der, self.control_points) + + def _basis_derivative(self, x, degree): + """ + Compute the ``degree``-th derivative of the spline basis functions at + given points using an iterative approach. + + :param torch.Tensor x: The points to be evaluated. + :param int degree: The derivative degree to compute. + :return: The basis functions evaluated at x. + :rtype: torch.Tensor + """ + # Compute the whole basis collection + basis = self.basis(x, collection=True) + + # Derivatives initialization (with dummy at index 0 for convenience) + derivatives = [None] + [basis[o] for o in range(1, self.order + 1)] + + # Iterate over derivative degrees + for _ in range(1, degree + 1): + + # Current degree derivatives (with dummy at index 0 for convenience) + current_der = [None] * (self.order + 1) + current_der[1] = torch.zeros_like(derivatives[1]) + + # Iterate over basis orders + for o in range(2, self.order + 1): + + # Retrieve precomputed factors + left_fac = getattr(self, f"_left_factor_order_{o}") + right_fac = getattr(self, f"_right_factor_order_{o}") + + # Slice previous derivatives to align + left_part = derivatives[o - 1][..., :-1] + right_part = derivatives[o - 1][..., 1:] + + # Broadcast factors over batch dims + view_shape = (1,) * (left_part.ndim - 1) + (-1,) + left_fac = left_fac.reshape(*view_shape) + right_fac = right_fac.reshape(*view_shape) + + # Compute current derivatives + current_der[o] = left_fac * left_part - right_fac * right_part + + # Update derivatives for next degree + derivatives = current_der + + return derivatives[self.order].squeeze(-1) + @property def control_points(self): """ @@ -364,3 +473,6 @@ def knots(self, value): # Recompute boundary interval when knots change if hasattr(self, "_boundary_interval_idx"): self._boundary_interval_idx = self._compute_boundary_interval() + + # Recompute derivative denominators when knots change + self._compute_derivative_denominators() diff --git a/pina/model/spline_surface.py b/pina/model/spline_surface.py index 30d41bbde..61798fe7e 100644 --- a/pina/model/spline_surface.py +++ b/pina/model/spline_surface.py @@ -2,7 +2,8 @@ import torch from .spline import Spline -from ..utils import check_consistency +from ..label_tensor import LabelTensor +from ..utils import check_consistency, check_positive_integer class SplineSurface(torch.nn.Module): @@ -122,6 +123,71 @@ def forward(self, x): self.control_points, ).unsqueeze(-1) + def derivative(self, x, degree_u, degree_v): + """ + Compute the partial derivatives of the spline at the given points. + + :param x: The input tensor. + :type x: torch.Tensor | LabelTensor + :param int degree_u: The degree of the derivative along the first + parameter direction. + :param int degree_v: The degree of the derivative along the second + parameter direction. + :raise ValueError: If ``degree_u`` is not an integer. + :raise ValueError: If ``degree_v`` is not an integer. + :return: The derivative tensor. + :rtype: torch.Tensor + """ + # Check consistency + check_positive_integer(degree_u, strict=False) + check_positive_integer(degree_v, strict=False) + + # Split input into u and v components + if isinstance(x, LabelTensor): + u = x[x.labels[0]].as_subclass(torch.Tensor) + v = x[x.labels[1]].as_subclass(torch.Tensor) + else: + u = x[..., 0] + v = x[..., 1] + + # Compute basis derivatives + der_u = self.spline_u._basis_derivative(u, degree=degree_u) + der_v = self.spline_v._basis_derivative(v, degree=degree_v) + + return torch.einsum( + "...bi, ...bj, ij -> ...b", der_u, der_v, self.control_points + ) + + def gradient(self, x): + """ + Convenience method to compute the gradient of the spline surface. + + :param x: The input tensor. + :type x: torch.Tensor | LabelTensor + :return: The gradient tensor. + :rtype: torch.Tensor + """ + # Compute partial derivatives + du = self.derivative(x, degree_u=1, degree_v=0) + dv = self.derivative(x, degree_u=0, degree_v=1) + + return torch.cat((du, dv), dim=-1) + + def laplacian(self, x): + """ + Convenience method to compute the laplacian of the spline surface. + + :param x: The input tensor. + :type x: torch.Tensor | LabelTensor + :return: The laplacian tensor. + :rtype: torch.Tensor + """ + # Compute second partial derivatives + ddu = self.derivative(x, degree_u=2, degree_v=0) + ddv = self.derivative(x, degree_u=0, degree_v=2) + + return ddu + ddv + @property def knots(self): """ diff --git a/tests/test_model/test_spline.py b/tests/test_model/test_spline.py index d22de9f26..8c806580b 100644 --- a/tests/test_model/test_spline.py +++ b/tests/test_model/test_spline.py @@ -1,6 +1,7 @@ import torch import pytest from scipy.interpolate import BSpline +from pina.operator import grad from pina.model import Spline from pina import LabelTensor @@ -173,3 +174,21 @@ def test_backward(args, pts): loss = torch.mean(output_) loss.backward() assert model.control_points.grad.shape == model.control_points.shape + + +@pytest.mark.parametrize("args", valid_args) +@pytest.mark.parametrize("pts", points) +def test_derivative(args, pts): + + # Define and evaluate the model + model = Spline(**args) + pts.requires_grad_(True) + output_ = LabelTensor(model(pts), "u") + + # Compute derivatives + first_der = model.derivative(x=pts, degree=1) + first_der_auto = grad(output_, pts).tensor + + # Check shape and value + assert first_der.shape == pts.shape + assert torch.allclose(first_der, first_der_auto, atol=1e-4, rtol=1e-4) diff --git a/tests/test_model/test_spline_surface.py b/tests/test_model/test_spline_surface.py index feab587b5..0c288f932 100644 --- a/tests/test_model/test_spline_surface.py +++ b/tests/test_model/test_spline_surface.py @@ -2,6 +2,7 @@ import random import pytest from pina.model import SplineSurface +from pina.operator import grad from pina import LabelTensor @@ -178,3 +179,44 @@ def test_backward(knots_u, knots_v, control_points, pts): loss = torch.mean(output_) loss.backward() assert model.control_points.grad.shape == model.control_points.shape + + +@pytest.mark.parametrize( + "knots_u", + [ + torch.rand(n_knots[0]), + {"n": n_knots[0], "min": 0, "max": 1, "mode": "auto"}, + {"n": n_knots[0], "min": 0, "max": 1, "mode": "uniform"}, + ], +) +@pytest.mark.parametrize( + "knots_v", + [ + torch.rand(n_knots[1]), + {"n": n_knots[1], "min": 0, "max": 1, "mode": "auto"}, + {"n": n_knots[1], "min": 0, "max": 1, "mode": "uniform"}, + ], +) +@pytest.mark.parametrize( + "control_points", [torch.rand(n_ctrl_pts, n_ctrl_pts), None] +) +@pytest.mark.parametrize("pts", points) +def test_derivative(knots_u, knots_v, control_points, pts): + + # Define and evaluate the model + model = SplineSurface( + orders=orders, + knots_u=knots_u, + knots_v=knots_v, + control_points=control_points, + ) + pts.requires_grad_(True) + output_ = LabelTensor(model(pts), "u") + + # Compute derivatives + gradient = model.gradient(x=pts) + gradient_auto = grad(output_, pts).tensor + + # Check shape and value + assert gradient.shape == pts.shape + assert torch.allclose(gradient, gradient_auto, atol=1e-4, rtol=1e-4)