From 57e458451d0482d6f390701d6c18ec67bfc0b484 Mon Sep 17 00:00:00 2001 From: Giovanni Canali <115086358+GiovanniCanali@users.noreply.github.com> Date: Wed, 15 Oct 2025 17:04:18 +0200 Subject: [PATCH] Revert "Improve and fix spline (#610)" This reverts commit 084428233589046b8adf6f1d3be3005b735b990b. --- docs/source/_rst/_code.rst | 1 - docs/source/_rst/model/spline_surface.rst | 7 - pina/model/__init__.py | 1 - pina/model/spline.py | 411 +++++++--------------- pina/model/spline_surface.py | 212 ----------- tests/test_model/test_spline.py | 198 +++-------- tests/test_model/test_spline_surface.py | 180 ---------- 7 files changed, 174 insertions(+), 836 deletions(-) delete mode 100644 docs/source/_rst/model/spline_surface.rst delete mode 100644 pina/model/spline_surface.py delete mode 100644 tests/test_model/test_spline_surface.py diff --git a/docs/source/_rst/_code.rst b/docs/source/_rst/_code.rst index 151699449..965a286b5 100644 --- a/docs/source/_rst/_code.rst +++ b/docs/source/_rst/_code.rst @@ -95,7 +95,6 @@ Models MultiFeedForward ResidualFeedForward Spline - SplineSurface DeepONet MIONet KernelNeuralOperator diff --git a/docs/source/_rst/model/spline_surface.rst b/docs/source/_rst/model/spline_surface.rst deleted file mode 100644 index 6bbf137d8..000000000 --- a/docs/source/_rst/model/spline_surface.rst +++ /dev/null @@ -1,7 +0,0 @@ -Spline Surface -================ -.. currentmodule:: pina.model.spline_surface - -.. autoclass:: SplineSurface - :members: - :show-inheritance: \ No newline at end of file diff --git a/pina/model/__init__.py b/pina/model/__init__.py index 05ccc6c8c..1edeacd1a 100644 --- a/pina/model/__init__.py +++ b/pina/model/__init__.py @@ -26,7 +26,6 @@ from .average_neural_operator import AveragingNeuralOperator from .low_rank_neural_operator import LowRankNeuralOperator from .spline import Spline -from .spline_surface import SplineSurface from .graph_neural_operator import GraphNeuralOperator from .pirate_network import PirateNet from .equivariant_graph_neural_operator import EquivariantGraphNeuralOperator diff --git a/pina/model/spline.py b/pina/model/spline.py index a276a6cfd..c22c7937c 100644 --- a/pina/model/spline.py +++ b/pina/model/spline.py @@ -1,244 +1,109 @@ -"""Module for the B-Spline model class.""" +"""Module for the Spline model class.""" -import warnings import torch -from ..utils import check_positive_integer, check_consistency +from ..utils import check_consistency class Spline(torch.nn.Module): - r""" - The univariate B-Spline curve model class. - - A univariate B-spline curve of order :math:`k` is a parametric curve defined - as a linear combination of B-spline basis functions and control points: - - .. math:: - - S(x) = \sum_{i=1}^{n} B_{i,k}(x) C_i, \quad x \in [x_1, x_m] - - where: - - - :math:`C_i \in \mathbb{R}` are the control points. These fixed points - influence the shape of the curve but are not generally interpolated, - except at the boundaries under certain knot multiplicities. - - :math:`B_{i,k}(x)` are the B-spline basis functions of order :math:`k`, - i.e., piecewise polynomials of degree :math:`k-1` with support on the - interval :math:`[x_i, x_{i+k}]`. - - :math:`X = \{ x_1, x_2, \dots, x_m \}` is the non-decreasing knot vector. - - If the first and last knots are repeated :math:`k` times, then the curve - interpolates the first and last control points. - - - .. note:: - - The curve is forced to be zero outside the interval defined by the - first and last knots. - - - :Example: - - >>> from pina.model import Spline - >>> import torch - - >>> knots1 = torch.tensor([0.0, 0.0, 0.0, 1.0, 2.0, 2.0, 2.0]) - >>> spline1 = Spline(order=3, knots=knots1, control_points=None) - - >>> knots2 = {"n": 7, "min": 0.0, "max": 2.0, "mode": "auto"} - >>> spline2 = Spline(order=3, knots=knots2, control_points=None) - - >>> knots3 = torch.tensor([0.0, 0.0, 0.0, 1.0, 2.0, 2.0, 2.0]) - >>> control_points3 = torch.tensor([0.0, 1.0, 3.0, 2.0]) - >>> spline3 = Spline(order=3, knots=knots3, control_points=control_points3) + """ + Spline model class. """ - def __init__(self, order=4, knots=None, control_points=None): + def __init__(self, order=4, knots=None, control_points=None) -> None: """ Initialization of the :class:`Spline` class. - :param int order: The order of the spline. The corresponding basis - functions are polynomials of degree ``order - 1``. Default is 4. - :param knots: The knots of the spline. If a tensor is provided, knots - are set directly from the tensor. If a dictionary is provided, it - must contain the keys ``"n"``, ``"min"``, ``"max"``, and ``"mode"``. - Here, ``"n"`` specifies the number of knots, ``"min"`` and ``"max"`` - define the interval, and ``"mode"`` selects the sampling strategy. - The supported modes are ``"uniform"``, where the knots are evenly - spaced over :math:`[min, max]`, and ``"auto"``, where knots are - constructed to ensure that the spline interpolates the first and - last control points. In this case, the number of knots is adjusted - if :math:`n < 2 * order`. If None is given, knots are initialized - automatically over :math:`[0, 1]` ensuring interpolation of the - first and last control points. Default is None. - :type knots: torch.Tensor | dict - :param torch.Tensor control_points: The control points of the spline. - If None, they are initialized as learnable parameters with an - initial value of zero. Default is None. - :raises AssertionError: If ``order`` is not a positive integer. - :raises ValueError: If ``knots`` is neither a torch.Tensor nor a - dictionary, when provided. - :raises ValueError: If ``control_points`` is not a torch.Tensor, - when provided. - :raises ValueError: If both ``knots`` and ``control_points`` are None. - :raises ValueError: If ``knots`` is not one-dimensional. - :raises ValueError: If ``control_points`` is not one-dimensional. - :raises ValueError: If the number of ``knots`` is not equal to the sum - of ``order`` and the number of ``control_points.`` - :raises UserWarning: If the number of control points is lower than the - order, resulting in a degenerate spline. + :param int order: The order of the spline. Default is ``4``. + :param torch.Tensor knots: The tensor representing knots. If ``None``, + the knots will be initialized automatically. Default is ``None``. + :param torch.Tensor control_points: The control points. Default is + ``None``. + :raises ValueError: If the order is negative. + :raises ValueError: If both knots and control points are ``None``. + :raises ValueError: If the knot tensor is not one-dimensional. """ super().__init__() - # Check consistency - check_positive_integer(value=order, strict=True) - check_consistency(knots, (type(None), torch.Tensor, dict)) - check_consistency(control_points, (type(None), torch.Tensor)) + check_consistency(order, int) - # Raise error if neither knots nor control points are provided + if order < 0: + raise ValueError("Spline order cannot be negative.") if knots is None and control_points is None: - raise ValueError("knots and control_points cannot both be None.") - - # Initialize knots if not provided - if knots is None and control_points is not None: - knots = { - "n": len(control_points) + order, - "min": 0, - "max": 1, - "mode": "auto", - } + raise ValueError("Knots and control points cannot be both None.") - # Initialization - knots and control points managed by their setters self.order = order - self.knots = knots - self.control_points = control_points - - # Check dimensionality of knots - if self.knots.ndim > 1: - raise ValueError("knots must be one-dimensional.") - - # Check dimensionality of control points - if self.control_points.ndim > 1: - raise ValueError("control_points must be one-dimensional.") - - # Raise error if #knots != order + #control_points - if len(self.knots) != self.order + len(self.control_points): - raise ValueError( - f" The number of knots must be equal to order + number of" - f" control points. Got {len(self.knots)} knots, {self.order}" - f" order and {len(self.control_points)} control points." - ) + self.k = order - 1 - # Raise warning if spline is degenerate - if len(self.control_points) < self.order: - warnings.warn( - "The number of control points is smaller than the spline order." - " This creates a degenerate spline with limited flexibility.", - UserWarning, - ) + if knots is not None and control_points is not None: + self.knots = knots + self.control_points = control_points - # Precompute boundary interval index - self._boundary_interval_idx = self._compute_boundary_interval() + elif knots is not None: + print("Warning: control points will be initialized automatically.") + print(" experimental feature") - def _compute_boundary_interval(self): - """ - Precompute the index of the rightmost non-degenerate interval to improve - performance, eliminating the need to perform a search loop in the basis - function on each call. + self.knots = knots + n = len(knots) - order + self.control_points = torch.nn.Parameter( + torch.zeros(n), requires_grad=True + ) - :return: The index of the rightmost non-degenerate interval. - :rtype: int - """ - # Return 0 if there is a single interval - if len(self.knots) < 2: - return 0 + elif control_points is not None: + print("Warning: knots will be initialized automatically.") + print(" experimental feature") - # Find all indices where knots are strictly increasing - diffs = self.knots[1:] - self.knots[:-1] - valid = torch.nonzero(diffs > 0, as_tuple=False) + self.control_points = control_points - # If all knots are equal, return 0 for degenerate spline - if valid.numel() == 0: - return 0 + n = len(self.control_points) - 1 + self.knots = { + "type": "auto", + "min": 0, + "max": 1, + "n": n + 2 + self.order, + } + + else: + raise ValueError("Knots and control points cannot be both None.") - # Otherwise, return the last valid index - return int(valid[-1]) + if self.knots.ndim != 1: + raise ValueError("Knot vector must be one-dimensional.") - def basis(self, x): + def basis(self, x, k, i, t): """ - Compute the basis functions for the spline using an iterative approach. - This is a vectorized implementation based on the Cox-de Boor recursion. + Recursive method to compute the basis functions of the spline. :param torch.Tensor x: The points to be evaluated. - :return: The basis functions evaluated at x. + :param int k: The spline degree. + :param int i: The index of the interval. + :param torch.Tensor t: The tensor of knots. + :return: The basis functions evaluated at x :rtype: torch.Tensor """ - # Add a final dimension to x - x = x.unsqueeze(-1) - - # Add an initial dimension to knots - knots = self.knots.unsqueeze(0) - # Base case of recursion: indicator functions for the intervals - basis = (x >= knots[..., :-1]) & (x < knots[..., 1:]) - basis = basis.to(x.dtype) - - # One-dimensional knots case: ensure rightmost boundary inclusion - if self._boundary_interval_idx is not None: - - # Extract left and right knots of the rightmost interval - knot_left = knots[..., self._boundary_interval_idx] - knot_right = knots[..., self._boundary_interval_idx + 1] - - # Identify points at the rightmost boundary - at_rightmost_boundary = ( - x.squeeze(-1) >= knot_left - ) & torch.isclose(x.squeeze(-1), knot_right, rtol=1e-8, atol=1e-10) - - # Ensure the correct value is set at the rightmost boundary - if torch.any(at_rightmost_boundary): - basis[..., self._boundary_interval_idx] = torch.logical_or( - basis[..., self._boundary_interval_idx].bool(), - at_rightmost_boundary, - ).to(basis.dtype) - - # Iterative case of recursion - for i in range(1, self.order): - - # Compute the denominators for both terms - denom1 = knots[..., i:-1] - knots[..., : -(i + 1)] - denom2 = knots[..., i + 1 :] - knots[..., 1:-i] - - # Ensure no division by zero - denom1 = torch.where( - torch.abs(denom1) < 1e-8, torch.ones_like(denom1), denom1 + if k == 0: + a = torch.where( + torch.logical_and(t[i] <= x, x < t[i + 1]), 1.0, 0.0 ) - denom2 = torch.where( - torch.abs(denom2) < 1e-8, torch.ones_like(denom2), denom2 + if i == len(t) - self.order - 1: + a = torch.where(x == t[-1], 1.0, a) + a.requires_grad_(True) + return a + + if t[i + k] == t[i]: + c1 = torch.tensor([0.0] * len(x), requires_grad=True) + else: + c1 = (x - t[i]) / (t[i + k] - t[i]) * self.basis(x, k - 1, i, t) + + if t[i + k + 1] == t[i + 1]: + c2 = torch.tensor([0.0] * len(x), requires_grad=True) + else: + c2 = ( + (t[i + k + 1] - x) + / (t[i + k + 1] - t[i + 1]) + * self.basis(x, k - 1, i + 1, t) ) - # Compute the two terms of the recursion - term1 = ((x - knots[..., : -(i + 1)]) / denom1) * basis[..., :-1] - term2 = ((knots[..., i + 1 :] - x) / denom2) * basis[..., 1:] - - # Combine terms to get the new basis - basis = term1 + term2 - - return basis - - def forward(self, x): - """ - Forward pass for the :class:`Spline` model. - - :param x: The input tensor. - :type x: torch.Tensor | LabelTensor - :return: The output tensor. - :rtype: torch.Tensor - """ - return torch.einsum( - "...bi, i -> ...b", - self.basis(x.as_subclass(torch.Tensor)).squeeze(-1), - self.control_points, - ) + return c1 + c2 @property def control_points(self): @@ -251,34 +116,24 @@ def control_points(self): return self._control_points @control_points.setter - def control_points(self, control_points): + def control_points(self, value): """ Set the control points of the spline. - :param torch.Tensor control_points: The control points tensor. If None, - control points are initialized to learnable parameters with zero - initial value. Default is None. - :raises ValueError: If there are not enough knots to define the control - points, due to the relation: #knots = order + #control_points. + :param value: The control points. + :type value: torch.Tensor | dict + :raises ValueError: If invalid value is passed. """ - # If control points are not provided, initialize them - if control_points is None: - - # Check that there are enough knots to define control points - if len(self.knots) < self.order + 1: - raise ValueError( - f"Not enough knots to define control points. Got " - f"{len(self.knots)} knots, but need at least " - f"{self.order + 1}." - ) - - # Initialize control points to zero - control_points = torch.zeros(len(self.knots) - self.order) + if isinstance(value, dict): + if "n" not in value: + raise ValueError("Invalid value for control_points") + n = value["n"] + dim = value.get("dim", 1) + value = torch.zeros(n, dim) - # Set control points - self._control_points = torch.nn.Parameter( - control_points, requires_grad=True - ) + if not isinstance(value, torch.Tensor): + raise ValueError("Invalid value for control_points") + self._control_points = torch.nn.Parameter(value, requires_grad=True) @property def knots(self): @@ -295,72 +150,50 @@ def knots(self, value): """ Set the knots of the spline. - :param value: The knots of the spline. If a tensor is provided, knots - are set directly from the tensor. If a dictionary is provided, it - must contain the keys ``"n"``, ``"min"``, ``"max"``, and ``"mode"``. - Here, ``"n"`` specifies the number of knots, ``"min"`` and ``"max"`` - define the interval, and ``"mode"`` selects the sampling strategy. - The supported modes are ``"uniform"``, where the knots are evenly - spaced over :math:`[min, max]`, and ``"auto"``, where knots are - constructed to ensure that the spline interpolates the first and - last control points. In this case, the number of knots is inferred - and the ``"n"`` key is ignored. + :param value: The knots. :type value: torch.Tensor | dict - :raises ValueError: If a dictionary is provided but does not contain - the required keys. - :raises ValueError: If the mode specified in the dictionary is invalid. + :raises ValueError: If invalid value is passed. """ - # If a dictionary is provided, initialize knots accordingly if isinstance(value, dict): - # Check that required keys are present - required_keys = {"n", "min", "max", "mode"} - if not required_keys.issubset(value.keys()): - raise ValueError( - f"When providing knots as a dictionary, the following " - f"keys must be present: {required_keys}. Got " - f"{value.keys()}." - ) + type_ = value.get("type", "auto") + min_ = value.get("min", 0) + max_ = value.get("max", 1) + n = value.get("n", 10) + + if type_ == "uniform": + value = torch.linspace(min_, max_, n + self.k + 1) + elif type_ == "auto": + initial_knots = torch.ones(self.order + 1) * min_ + final_knots = torch.ones(self.order + 1) * max_ + + if n < self.order + 1: + value = torch.concatenate((initial_knots, final_knots)) + elif n - 2 * self.order + 1 == 1: + value = torch.Tensor([(max_ + min_) / 2]) + else: + value = torch.linspace(min_, max_, n - 2 * self.order - 1) - # Uniform sampling of knots - if value["mode"] == "uniform": - value = torch.linspace(value["min"], value["max"], value["n"]) + value = torch.concatenate((initial_knots, value, final_knots)) - # Automatic sampling of interpolating knots - elif value["mode"] == "auto": + if not isinstance(value, torch.Tensor): + raise ValueError("Invalid value for knots") - # Repeat the first and last knots 'order' times - initial_knots = torch.ones(self.order) * value["min"] - final_knots = torch.ones(self.order) * value["max"] + self._knots = value - # Number of internal knots - n_internal = value["n"] - 2 * self.order + def forward(self, x): + """ + Forward pass for the :class:`Spline` model. - # If no internal knots are needed, just concatenate boundaries - if n_internal <= 0: - value = torch.cat((initial_knots, final_knots)) + :param torch.Tensor x: The input tensor. + :return: The output tensor. + :rtype: torch.Tensor + """ + t = self.knots + k = self.k + c = self.control_points - # Else, sample internal knots uniformly and exclude boundaries - # Recover the correct number of internal knots when slicing by - # adding 2 to n_internal - else: - internal_knots = torch.linspace( - value["min"], value["max"], n_internal + 2 - )[1:-1] - value = torch.cat( - (initial_knots, internal_knots, final_knots) - ) - - # Raise error if mode is invalid - else: - raise ValueError( - f"Invalid mode for knots initialization. Got " - f"{value['mode']}, but expected 'uniform' or 'auto'." - ) - - # Set knots - self.register_buffer("_knots", value.sort(dim=0).values) - - # Recompute boundary interval when knots change - if hasattr(self, "_boundary_interval_idx"): - self._boundary_interval_idx = self._compute_boundary_interval() + basis = map(lambda i: self.basis(x, k, i, t)[:, None], range(len(c))) + y = (torch.cat(list(basis), dim=1) * c).sum(axis=1) + + return y diff --git a/pina/model/spline_surface.py b/pina/model/spline_surface.py deleted file mode 100644 index 30d41bbde..000000000 --- a/pina/model/spline_surface.py +++ /dev/null @@ -1,212 +0,0 @@ -"""Module for the bivariate B-Spline surface model class.""" - -import torch -from .spline import Spline -from ..utils import check_consistency - - -class SplineSurface(torch.nn.Module): - r""" - The bivariate B-Spline surface model class. - - A bivariate B-spline surface is a parametric surface defined as the tensor - product of two univariate B-spline curves: - - .. math:: - - S(x, y) = \sum_{i,j=1}^{n_x, n_y} B_{i,k}(x) B_{j,s}(y) C_{i,j}, - \quad x \in [x_1, x_m], y \in [y_1, y_l] - - where: - - - :math:`C_{i,j} \in \mathbb{R}^2` are the control points. These fixed - points influence the shape of the surface but are not generally - interpolated, except at the boundaries under certain knot multiplicities. - - :math:`B_{i,k}(x)` and :math:`B_{j,s}(y)` are the B-spline basis functions - defined over two orthogonal directions, with orders :math:`k` and - :math:`s`, respectively. - - :math:`X = \{ x_1, x_2, \dots, x_m \}` and - :math:`Y = \{ y_1, y_2, \dots, y_l \}` are the non-decreasing knot - vectors along the two directions. - """ - - def __init__(self, orders, knots_u=None, knots_v=None, control_points=None): - """ - Initialization of the :class:`SplineSurface` class. - - :param list[int] orders: The orders of the spline along each parametric - direction. Each order defines the degree of the corresponding basis - as ``degree = order - 1``. - :param knots_u: The knots of the spline along the first direction. - For details on valid formats and initialization modes, see the - :class:`Spline` class. Default is None. - :type knots_u: torch.Tensor | dict - :param knots_v: The knots of the spline along the second direction. - For details on valid formats and initialization modes, see the - :class:`Spline` class. Default is None. - :type knots_v: torch.Tensor | dict - :param torch.Tensor control_points: The control points defining the - surface geometry. It must be a two-dimensional tensor of shape - ``[len(knots_u) - orders[0], len(knots_v) - orders[1]]``. - If None, they are initialized as learnable parameters with zero - values. Default is None. - :raises ValueError: If ``orders`` is not a list of integers. - :raises ValueError: If ``knots_u`` is neither a torch.Tensor nor a - dictionary, when provided. - :raises ValueError: If ``knots_v`` is neither a torch.Tensor nor a - dictionary, when provided. - :raises ValueError: If ``control_points`` is not a torch.Tensor, - when provided. - :raises ValueError: If ``orders`` is not a list of two elements. - :raises ValueError: If ``knots_u``, ``knots_v``, and ``control_points`` - are all None. - """ - super().__init__() - - # Check consistency - check_consistency(orders, int) - check_consistency(control_points, (type(None), torch.Tensor)) - check_consistency(knots_u, (type(None), torch.Tensor, dict)) - check_consistency(knots_v, (type(None), torch.Tensor, dict)) - - # Check orders is a list of two elements - if len(orders) != 2: - raise ValueError("orders must be a list of two elements.") - - # Raise error if neither knots nor control points are provided - if (knots_u is None or knots_v is None) and control_points is None: - raise ValueError( - "control_points cannot be None if knots_u or knots_v is None." - ) - - # Initialize knots_u if not provided - if knots_u is None and control_points is not None: - knots_u = { - "n": control_points.shape[0] + orders[0], - "min": 0, - "max": 1, - "mode": "auto", - } - - # Initialize knots_v if not provided - if knots_v is None and control_points is not None: - knots_v = { - "n": control_points.shape[1] + orders[1], - "min": 0, - "max": 1, - "mode": "auto", - } - - # Create two univariate b-splines - self.spline_u = Spline(order=orders[0], knots=knots_u) - self.spline_v = Spline(order=orders[1], knots=knots_v) - self.control_points = control_points - - # Delete unneeded parameters - delattr(self.spline_u, "_control_points") - delattr(self.spline_v, "_control_points") - - def forward(self, x): - """ - Forward pass for the :class:`SplineSurface` model. - - :param x: The input tensor. - :type x: torch.Tensor | LabelTensor - :return: The output tensor. - :rtype: torch.Tensor - """ - return torch.einsum( - "...bi, ...bj, ij -> ...b", - self.spline_u.basis(x.as_subclass(torch.Tensor)[..., 0]), - self.spline_v.basis(x.as_subclass(torch.Tensor)[..., 1]), - self.control_points, - ).unsqueeze(-1) - - @property - def knots(self): - """ - The knots of the univariate splines defining the spline surface. - - :return: The knots. - :rtype: tuple(torch.Tensor, torch.Tensor) - """ - return self.spline_u.knots, self.spline_v.knots - - @knots.setter - def knots(self, value): - """ - Set the knots of the spline surface. - - :param value: A tuple (knots_u, knots_v) containing the knots for both - parametric directions. - :type value: tuple(torch.Tensor | dict, torch.Tensor | dict) - :raises ValueError: If value is not a tuple of two elements. - """ - # Check value is a tuple of two elements - if not (isinstance(value, tuple) and len(value) == 2): - raise ValueError("Knots must be a tuple of two elements.") - - knots_u, knots_v = value - self.spline_u.knots = knots_u - self.spline_v.knots = knots_v - - @property - def control_points(self): - """ - The control points of the spline. - - :return: The control points. - :rtype: torch.Tensor - """ - return self._control_points - - @control_points.setter - def control_points(self, control_points): - """ - Set the control points of the spline surface. - - :param torch.Tensor control_points: The bidimensional control points - tensor, where each dimension refers to a direction in the parameter - space. If None, control points are initialized to learnable - parameters with zero initial value. Default is None. - :raises ValueError: If in any direction there are not enough knots to - define the control points, due to the relation: - #knots = order + #control_points. - :raises ValueError: If ``control_points`` is not of the correct shape. - """ - # Save correct shape of control points - __valid_shape = ( - len(self.spline_u.knots) - self.spline_u.order, - len(self.spline_v.knots) - self.spline_v.order, - ) - - # If control points are not provided, initialize them - if control_points is None: - - # Check that there are enough knots to define control points - if ( - len(self.spline_u.knots) < self.spline_u.order + 1 - or len(self.spline_v.knots) < self.spline_v.order + 1 - ): - raise ValueError( - f"Not enough knots to define control points. Got " - f"{len(self.spline_u.knots)} knots along u and " - f"{len(self.spline_v.knots)} knots along v, but need at " - f"least {self.spline_u.order + 1} and " - f"{self.spline_v.order + 1}, respectively." - ) - - # Initialize control points to zero - control_points = torch.zeros(__valid_shape) - - # Check control points - if control_points.shape != __valid_shape: - raise ValueError( - "control_points must be of the correct shape. ", - f"Expected {__valid_shape}, got {control_points.shape}.", - ) - - # Register control points as a learnable parameter - self._control_points = torch.nn.Parameter( - control_points, requires_grad=True - ) diff --git a/tests/test_model/test_spline.py b/tests/test_model/test_spline.py index d22de9f26..d38b1610b 100644 --- a/tests/test_model/test_spline.py +++ b/tests/test_model/test_spline.py @@ -1,175 +1,81 @@ import torch import pytest -from scipy.interpolate import BSpline -from pina.model import Spline -from pina import LabelTensor +from pina.model import Spline -# Utility quantities for testing -order = torch.randint(1, 8, (1,)).item() -n_ctrl_pts = torch.randint(order, order + 5, (1,)).item() -n_knots = order + n_ctrl_pts +data = torch.rand((20, 3)) +input_vars = 3 +output_vars = 4 -# Input tensor -points = [ - LabelTensor(torch.rand(100, 1), ["x"]), - LabelTensor(torch.rand(2, 100, 1), ["x"]), +valid_args = [ + { + "knots": torch.tensor([0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 3.0, 3.0]), + "control_points": torch.tensor([0.0, 0.0, 1.0, 0.0, 0.0]), + "order": 3, + }, + { + "knots": torch.tensor( + [-2.0, -2.0, -2.0, -2.0, -1.0, 0.0, 1.0, 2.0, 2.0, 2.0, 2.0] + ), + "control_points": torch.tensor([0.0, 0.0, 0.0, 6.0, 0.0, 0.0, 0.0]), + "order": 4, + }, + # {'control_points': {'n': 5, 'dim': 1}, 'order': 2}, + # {'control_points': {'n': 7, 'dim': 1}, 'order': 3} ] -# Function to compare with scipy implementation -def check_scipy_spline(model, x, output_): +def scipy_check(model, x, y): + from scipy.interpolate._bsplines import BSpline + import numpy as np - # Define scipy spline - scipy_spline = BSpline( + spline = BSpline( t=model.knots.detach().numpy(), c=model.control_points.detach().numpy(), k=model.order - 1, ) - - # Compare outputs - torch.allclose( - output_, - torch.tensor(scipy_spline(x), dtype=output_.dtype), - atol=1e-5, - rtol=1e-5, - ) - - -# Define all possible combinations of valid arguments for Spline class -valid_args = [ - { - "order": order, - "control_points": torch.rand(n_ctrl_pts), - "knots": torch.linspace(0, 1, n_knots), - }, - { - "order": order, - "control_points": torch.rand(n_ctrl_pts), - "knots": {"n": n_knots, "min": 0, "max": 1, "mode": "auto"}, - }, - { - "order": order, - "control_points": torch.rand(n_ctrl_pts), - "knots": {"n": n_knots, "min": 0, "max": 1, "mode": "uniform"}, - }, - { - "order": order, - "control_points": None, - "knots": torch.linspace(0, 1, n_knots), - }, - { - "order": order, - "control_points": None, - "knots": {"n": n_knots, "min": 0, "max": 1, "mode": "auto"}, - }, - { - "order": order, - "control_points": None, - "knots": {"n": n_knots, "min": 0, "max": 1, "mode": "uniform"}, - }, - { - "order": order, - "control_points": torch.rand(n_ctrl_pts), - "knots": None, - }, -] + y_scipy = spline(x).flatten() + y = y.detach().numpy() + np.testing.assert_allclose(y, y_scipy, atol=1e-5) @pytest.mark.parametrize("args", valid_args) def test_constructor(args): Spline(**args) - # Should fail if order is not a positive integer - with pytest.raises(AssertionError): - Spline( - order=-1, control_points=args["control_points"], knots=args["knots"] - ) - - # Should fail if control_points is not None or a torch.Tensor - with pytest.raises(ValueError): - Spline( - order=args["order"], control_points=[1, 2, 3], knots=args["knots"] - ) - - # Should fail if knots is not None, a torch.Tensor, or a dict - with pytest.raises(ValueError): - Spline( - order=args["order"], control_points=args["control_points"], knots=5 - ) - - # Should fail if both knots and control_points are None - with pytest.raises(ValueError): - Spline(order=args["order"], control_points=None, knots=None) - - # Should fail if knots is not one-dimensional - with pytest.raises(ValueError): - Spline( - order=args["order"], - control_points=args["control_points"], - knots=torch.rand(n_knots, 4), - ) - - # Should fail if control_points is not one-dimensional - with pytest.raises(ValueError): - Spline( - order=args["order"], - control_points=torch.rand(n_ctrl_pts, 4), - knots=args["knots"], - ) - - # Should fail if the number of knots != order + number of control points - # If control points are None, they are initialized to fulfill this condition - if args["control_points"] is not None: - with pytest.raises(ValueError): - Spline( - order=args["order"], - control_points=args["control_points"], - knots=torch.linspace(0, 1, n_knots + 1), - ) - - # Should fail if the knot dict is missing required keys - with pytest.raises(ValueError): - Spline( - order=args["order"], - control_points=args["control_points"], - knots={"n": n_knots, "min": 0, "max": 1}, - ) - # Should fail if the knot dict has invalid 'mode' key +def test_constructor_wrong(): with pytest.raises(ValueError): - Spline( - order=args["order"], - control_points=args["control_points"], - knots={"n": n_knots, "min": 0, "max": 1, "mode": "invalid"}, - ) + Spline() @pytest.mark.parametrize("args", valid_args) -@pytest.mark.parametrize("pts", points) -def test_forward(args, pts): - - # Define the model +def test_forward(args): + min_x = args["knots"][0] + max_x = args["knots"][-1] + xi = torch.linspace(min_x, max_x, 1000) model = Spline(**args) - - # Evaluate the model - output_ = model(pts) - assert output_.shape == pts.shape - - # Compare with scipy implementation only for interpolant knots (mode: auto) - if isinstance(args["knots"], dict) and args["knots"]["mode"] == "auto": - check_scipy_spline(model, pts, output_) + yi = model(xi).squeeze() + scipy_check(model, xi, yi) + return @pytest.mark.parametrize("args", valid_args) -@pytest.mark.parametrize("pts", points) -def test_backward(args, pts): - - # Define the model +def test_backward(args): + min_x = args["knots"][0] + max_x = args["knots"][-1] + xi = torch.linspace(min_x, max_x, 100) model = Spline(**args) - - # Evaluate the model - output_ = model(pts) - loss = torch.mean(output_) - loss.backward() - assert model.control_points.grad.shape == model.control_points.shape + yi = model(xi) + fake_loss = torch.sum(yi) + assert model.control_points.grad is None + fake_loss.backward() + assert model.control_points.grad is not None + + # dim_in, dim_out = 3, 2 + # fnn = FeedForward(dim_in, dim_out) + # data.requires_grad = True + # output_ = fnn(data) + # l=torch.mean(output_) + # l.backward() + # assert data._grad.shape == torch.Size([20,3]) diff --git a/tests/test_model/test_spline_surface.py b/tests/test_model/test_spline_surface.py deleted file mode 100644 index feab587b5..000000000 --- a/tests/test_model/test_spline_surface.py +++ /dev/null @@ -1,180 +0,0 @@ -import torch -import random -import pytest -from pina.model import SplineSurface -from pina import LabelTensor - - -# Utility quantities for testing -orders = [random.randint(1, 8) for _ in range(2)] -n_ctrl_pts = random.randint(max(orders), max(orders) + 5) -n_knots = [orders[i] + n_ctrl_pts for i in range(2)] - -# Input tensor -points = [ - LabelTensor(torch.rand(100, 2), ["x", "y"]), - LabelTensor(torch.rand(2, 100, 2), ["x", "y"]), -] - - -@pytest.mark.parametrize( - "knots_u", - [ - torch.rand(n_knots[0]), - {"n": n_knots[0], "min": 0, "max": 1, "mode": "auto"}, - {"n": n_knots[0], "min": 0, "max": 1, "mode": "uniform"}, - None, - ], -) -@pytest.mark.parametrize( - "knots_v", - [ - torch.rand(n_knots[1]), - {"n": n_knots[1], "min": 0, "max": 1, "mode": "auto"}, - {"n": n_knots[1], "min": 0, "max": 1, "mode": "uniform"}, - None, - ], -) -@pytest.mark.parametrize( - "control_points", [torch.rand(n_ctrl_pts, n_ctrl_pts), None] -) -def test_constructor(knots_u, knots_v, control_points): - - # Skip if knots_u, knots_v, and control_points are all None - if (knots_u is None or knots_v is None) and control_points is None: - return - - SplineSurface( - orders=orders, - knots_u=knots_u, - knots_v=knots_v, - control_points=control_points, - ) - - # Should fail if orders is not list of two elements - with pytest.raises(ValueError): - SplineSurface( - orders=[orders[0]], - knots_u=knots_u, - knots_v=knots_v, - control_points=control_points, - ) - - # Should fail if both knots and control_points are None - with pytest.raises(ValueError): - SplineSurface( - orders=orders, - knots_u=None, - knots_v=None, - control_points=None, - ) - - # Should fail if control_points is not a torch.Tensor when provided - with pytest.raises(ValueError): - SplineSurface( - orders=orders, - knots_u=knots_u, - knots_v=knots_v, - control_points=[[0.0] * n_ctrl_pts] * n_ctrl_pts, - ) - - # Should fail if control_points is not of the correct shape when provided - # It assumes that at least one among knots_u and knots_v is not None - if knots_u is not None or knots_v is not None: - with pytest.raises(ValueError): - SplineSurface( - orders=orders, - knots_u=knots_u, - knots_v=knots_v, - control_points=torch.rand(n_ctrl_pts + 1, n_ctrl_pts + 1), - ) - - # Should fail if there are not enough knots_u to define the control points - with pytest.raises(ValueError): - SplineSurface( - orders=orders, - knots_u=torch.linspace(0, 1, orders[0]), - knots_v=knots_v, - control_points=None, - ) - - # Should fail if there are not enough knots_v to define the control points - with pytest.raises(ValueError): - SplineSurface( - orders=orders, - knots_u=knots_u, - knots_v=torch.linspace(0, 1, orders[1]), - control_points=None, - ) - - -@pytest.mark.parametrize( - "knots_u", - [ - torch.rand(n_knots[0]), - {"n": n_knots[0], "min": 0, "max": 1, "mode": "auto"}, - {"n": n_knots[0], "min": 0, "max": 1, "mode": "uniform"}, - ], -) -@pytest.mark.parametrize( - "knots_v", - [ - torch.rand(n_knots[1]), - {"n": n_knots[1], "min": 0, "max": 1, "mode": "auto"}, - {"n": n_knots[1], "min": 0, "max": 1, "mode": "uniform"}, - ], -) -@pytest.mark.parametrize( - "control_points", [torch.rand(n_ctrl_pts, n_ctrl_pts), None] -) -@pytest.mark.parametrize("pts", points) -def test_forward(knots_u, knots_v, control_points, pts): - - # Define the model - model = SplineSurface( - orders=orders, - knots_u=knots_u, - knots_v=knots_v, - control_points=control_points, - ) - - # Evaluate the model - output_ = model(pts) - assert output_.shape == (*pts.shape[:-1], 1) - - -@pytest.mark.parametrize( - "knots_u", - [ - torch.rand(n_knots[0]), - {"n": n_knots[0], "min": 0, "max": 1, "mode": "auto"}, - {"n": n_knots[0], "min": 0, "max": 1, "mode": "uniform"}, - ], -) -@pytest.mark.parametrize( - "knots_v", - [ - torch.rand(n_knots[1]), - {"n": n_knots[1], "min": 0, "max": 1, "mode": "auto"}, - {"n": n_knots[1], "min": 0, "max": 1, "mode": "uniform"}, - ], -) -@pytest.mark.parametrize( - "control_points", [torch.rand(n_ctrl_pts, n_ctrl_pts), None] -) -@pytest.mark.parametrize("pts", points) -def test_backward(knots_u, knots_v, control_points, pts): - - # Define the model - model = SplineSurface( - orders=orders, - knots_u=knots_u, - knots_v=knots_v, - control_points=control_points, - ) - - # Evaluate the model - output_ = model(pts) - loss = torch.mean(output_) - loss.backward() - assert model.control_points.grad.shape == model.control_points.shape