|
| 1 | +"""Module for the bivariate B-Spline surface model class.""" |
| 2 | + |
| 3 | +import torch |
| 4 | +from .spline import Spline |
| 5 | +from ..utils import check_consistency |
| 6 | + |
| 7 | + |
| 8 | +class SplineSurface(torch.nn.Module): |
| 9 | + r""" |
| 10 | + The bivariate B-Spline surface model class. |
| 11 | +
|
| 12 | + A bivariate B-spline surface is a parametric surface defined as the tensor |
| 13 | + product of two univariate B-spline curves: |
| 14 | +
|
| 15 | + .. math:: |
| 16 | +
|
| 17 | + S(x, y) = \sum_{i,j=1}^{n_x, n_y} B_{i,k}(x) B_{j,s}(y) C_{i,j}, |
| 18 | + \quad x \in [x_1, x_m], y \in [y_1, y_l] |
| 19 | +
|
| 20 | + where: |
| 21 | +
|
| 22 | + - :math:`C_{i,j} \in \mathbb{R}^2` are the control points. These fixed |
| 23 | + points influence the shape of the surface but are not generally |
| 24 | + interpolated, except at the boundaries under certain knot multiplicities. |
| 25 | + - :math:`B_{i,k}(x)` and :math:`B_{j,s}(y)` are the B-spline basis functions |
| 26 | + defined over two orthogonal directions, with orders :math:`k` and |
| 27 | + :math:`s`, respectively. |
| 28 | + - :math:`X = \{ x_1, x_2, \dots, x_m \}` and |
| 29 | + :math:`Y = \{ y_1, y_2, \dots, y_l \}` are the non-decreasing knot |
| 30 | + vectors along the two directions. |
| 31 | + """ |
| 32 | + |
| 33 | + def __init__(self, knots_u, knots_v, orders=[4, 4], control_points=None): |
| 34 | + """ |
| 35 | + Initialization of the :class:`SplineSurface` class. |
| 36 | +
|
| 37 | + :param knots_u: The knots of the spline along the first direction. |
| 38 | + Unlike the univariate case, this must be explicitly provided. |
| 39 | + For details on valid formats and initialization modes, see the |
| 40 | + :class:`Spline` class. |
| 41 | + :type knots_u: torch.Tensor | dict |
| 42 | + :param knots_v: The knots of the spline along the second direction. |
| 43 | + Unlike the univariate case, this must be explicitly provided. |
| 44 | + For details on valid formats and initialization modes, see the |
| 45 | + :class:`Spline` class. |
| 46 | + :type knots_v: torch.Tensor | dict |
| 47 | + :param list[int] orders: The orders of the spline along each parametric |
| 48 | + direction. Each order defines the degree of the corresponding basis |
| 49 | + as ``degree = order - 1``. Default is ``[4, 4]``. |
| 50 | + :param torch.Tensor control_points: The control points defining the |
| 51 | + surface geometry. It must be a two-dimensional tensor of shape |
| 52 | + ``[len(knots_u) - orders[0], len(knots_v) - orders[1]]``. |
| 53 | + If None, they are initialized as learnable parameters with zero |
| 54 | + values. Default is None. |
| 55 | + :raises ValueError: If ``orders`` is not a list of two elements. |
| 56 | + :raises ValueError: If ``knots_u`` or ``knots_v`` is None. |
| 57 | + :raises ValueError: If ``control_points`` is not a torch.Tensor when |
| 58 | + provided. |
| 59 | + :raises ValueError: If ``control_points`` is not of the correct shape |
| 60 | + when provided. |
| 61 | + """ |
| 62 | + super().__init__() |
| 63 | + |
| 64 | + # Check consistency |
| 65 | + check_consistency(orders, int) |
| 66 | + check_consistency(control_points, (type(None), torch.Tensor)) |
| 67 | + |
| 68 | + # Check orders is a list of two elements |
| 69 | + if len(orders) != 2: |
| 70 | + raise ValueError("orders must be a list of two elements.") |
| 71 | + |
| 72 | + # Check knots_u and knots_v are not None |
| 73 | + if knots_u is None or knots_v is None: |
| 74 | + raise ValueError("knots_u and knots_v must cannot be None.") |
| 75 | + |
| 76 | + # Create two univariate b-splines |
| 77 | + self.spline_u = Spline(order=orders[0], knots=knots_u) |
| 78 | + self.spline_v = Spline(order=orders[1], knots=knots_v) |
| 79 | + |
| 80 | + # Delete unneeded parameters |
| 81 | + delattr(self.spline_u, "_control_points") |
| 82 | + delattr(self.spline_v, "_control_points") |
| 83 | + |
| 84 | + # Save correct shape of control points |
| 85 | + __valid_shape = ( |
| 86 | + len(self.spline_u.knots) - self.spline_u.order, |
| 87 | + len(self.spline_v.knots) - self.spline_v.order, |
| 88 | + ) |
| 89 | + |
| 90 | + # Initialize control points, if not provided |
| 91 | + if control_points is None: |
| 92 | + control_points = torch.zeros(__valid_shape) |
| 93 | + |
| 94 | + # Check control points |
| 95 | + if control_points.shape != __valid_shape: |
| 96 | + raise ValueError( |
| 97 | + "control_points must be of the correct shape. ", |
| 98 | + f"Expected {__valid_shape}, got {control_points.shape}.", |
| 99 | + ) |
| 100 | + |
| 101 | + # Register control points as a learnable parameter |
| 102 | + self._control_points = torch.nn.Parameter( |
| 103 | + control_points, requires_grad=True |
| 104 | + ) |
| 105 | + |
| 106 | + def forward(self, x): |
| 107 | + """ |
| 108 | + Forward pass for the :class:`SplineSurface` model. |
| 109 | +
|
| 110 | + :param x: The input tensor. |
| 111 | + :type x: torch.Tensor | LabelTensor |
| 112 | + :return: The output tensor. |
| 113 | + :rtype: torch.Tensor |
| 114 | + """ |
| 115 | + return torch.einsum( |
| 116 | + "bi, bj, ij -> b", |
| 117 | + self.spline_u.basis(x.as_subclass(torch.Tensor)[:, 0]), |
| 118 | + self.spline_v.basis(x.as_subclass(torch.Tensor)[:, 1]), |
| 119 | + self.control_points, |
| 120 | + ).reshape(-1, 1) |
| 121 | + |
| 122 | + @property |
| 123 | + def knots(self): |
| 124 | + """ |
| 125 | + The knots of the univariate splines defining the spline surface. |
| 126 | +
|
| 127 | + :return: The knots. |
| 128 | + :rtype: tuple(torch.Tensor, torch.Tensor) |
| 129 | + """ |
| 130 | + return self.spline_u.knots, self.spline_v.knots |
| 131 | + |
| 132 | + @property |
| 133 | + def control_points(self): |
| 134 | + """ |
| 135 | + The control points of the spline. |
| 136 | +
|
| 137 | + :return: The control points. |
| 138 | + :rtype: torch.Tensor |
| 139 | + """ |
| 140 | + return self._control_points |
0 commit comments