|
| 1 | +"""Module for the bivariate B-Spline surface model class.""" |
| 2 | + |
| 3 | +import torch |
| 4 | +from .spline import Spline |
| 5 | +from ..utils import check_consistency |
| 6 | + |
| 7 | + |
| 8 | +class SplineSurface(torch.nn.Module): |
| 9 | + r""" |
| 10 | + The bivariate B-Spline surface model class. |
| 11 | +
|
| 12 | + A bivariate B-spline surface is a parametric surface defined as the tensor |
| 13 | + product of two univariate B-spline curves: |
| 14 | +
|
| 15 | + .. math:: |
| 16 | +
|
| 17 | + S(x, y) = \sum_{i,j=1}^{n_x, n_y} B_{i,k}(x) B_{j,s}(y) C_{i,j}, |
| 18 | + \quad x \in [x_1, x_m], y \in [y_1, y_l] |
| 19 | +
|
| 20 | + where: |
| 21 | +
|
| 22 | + - :math:`C_{i,j} \in \mathbb{R}^2` are the control points. These fixed |
| 23 | + points influence the shape of the surface but are not generally |
| 24 | + interpolated, except at the boundaries under certain knot multiplicities. |
| 25 | + - :math:`B_{i,k}(x)` and :math:`B_{j,s}(y)` are the B-spline basis functions |
| 26 | + defined over two orthogonal directions, with orders :math:`k` and |
| 27 | + :math:`s`, respectively. |
| 28 | + - :math:`X = \{ x_1, x_2, \dots, x_m \}` and |
| 29 | + :math:`Y = \{ y_1, y_2, \dots, y_l \}` are the non-decreasing knot |
| 30 | + vectors along the two directions. |
| 31 | + """ |
| 32 | + |
| 33 | + def __init__(self, orders, knots_u=None, knots_v=None, control_points=None): |
| 34 | + """ |
| 35 | + Initialization of the :class:`SplineSurface` class. |
| 36 | +
|
| 37 | + :param list[int] orders: The orders of the spline along each parametric |
| 38 | + direction. Each order defines the degree of the corresponding basis |
| 39 | + as ``degree = order - 1``. |
| 40 | + :param knots_u: The knots of the spline along the first direction. |
| 41 | + For details on valid formats and initialization modes, see the |
| 42 | + :class:`Spline` class. Default is None. |
| 43 | + :type knots_u: torch.Tensor | dict |
| 44 | + :param knots_v: The knots of the spline along the second direction. |
| 45 | + For details on valid formats and initialization modes, see the |
| 46 | + :class:`Spline` class. Default is None. |
| 47 | + :type knots_v: torch.Tensor | dict |
| 48 | + :param torch.Tensor control_points: The control points defining the |
| 49 | + surface geometry. It must be a two-dimensional tensor of shape |
| 50 | + ``[len(knots_u) - orders[0], len(knots_v) - orders[1]]``. |
| 51 | + If None, they are initialized as learnable parameters with zero |
| 52 | + values. Default is None. |
| 53 | + :raises ValueError: If ``orders`` is not a list of integers. |
| 54 | + :raises ValueError: If ``knots_u`` is neither a torch.Tensor nor a |
| 55 | + dictionary, when provided. |
| 56 | + :raises ValueError: If ``knots_v`` is neither a torch.Tensor nor a |
| 57 | + dictionary, when provided. |
| 58 | + :raises ValueError: If ``control_points`` is not a torch.Tensor, |
| 59 | + when provided. |
| 60 | + :raises ValueError: If ``orders`` is not a list of two elements. |
| 61 | + :raises ValueError: If ``knots_u``, ``knots_v``, and ``control_points`` |
| 62 | + are all None. |
| 63 | + """ |
| 64 | + super().__init__() |
| 65 | + |
| 66 | + # Check consistency |
| 67 | + check_consistency(orders, int) |
| 68 | + check_consistency(control_points, (type(None), torch.Tensor)) |
| 69 | + check_consistency(knots_u, (type(None), torch.Tensor, dict)) |
| 70 | + check_consistency(knots_v, (type(None), torch.Tensor, dict)) |
| 71 | + |
| 72 | + # Check orders is a list of two elements |
| 73 | + if len(orders) != 2: |
| 74 | + raise ValueError("orders must be a list of two elements.") |
| 75 | + |
| 76 | + # Raise error if neither knots nor control points are provided |
| 77 | + if (knots_u is None or knots_v is None) and control_points is None: |
| 78 | + raise ValueError( |
| 79 | + "control_points cannot be None if knots_u or knots_v is None." |
| 80 | + ) |
| 81 | + |
| 82 | + # Initialize knots_u if not provided |
| 83 | + if knots_u is None and control_points is not None: |
| 84 | + knots_u = { |
| 85 | + "n": control_points.shape[0] + orders[0], |
| 86 | + "min": 0, |
| 87 | + "max": 1, |
| 88 | + "mode": "auto", |
| 89 | + } |
| 90 | + |
| 91 | + # Initialize knots_v if not provided |
| 92 | + if knots_v is None and control_points is not None: |
| 93 | + knots_v = { |
| 94 | + "n": control_points.shape[1] + orders[1], |
| 95 | + "min": 0, |
| 96 | + "max": 1, |
| 97 | + "mode": "auto", |
| 98 | + } |
| 99 | + |
| 100 | + # Create two univariate b-splines |
| 101 | + self.spline_u = Spline(order=orders[0], knots=knots_u) |
| 102 | + self.spline_v = Spline(order=orders[1], knots=knots_v) |
| 103 | + self.control_points = control_points |
| 104 | + |
| 105 | + # Delete unneeded parameters |
| 106 | + delattr(self.spline_u, "_control_points") |
| 107 | + delattr(self.spline_v, "_control_points") |
| 108 | + |
| 109 | + def forward(self, x): |
| 110 | + """ |
| 111 | + Forward pass for the :class:`SplineSurface` model. |
| 112 | +
|
| 113 | + :param x: The input tensor. |
| 114 | + :type x: torch.Tensor | LabelTensor |
| 115 | + :return: The output tensor. |
| 116 | + :rtype: torch.Tensor |
| 117 | + """ |
| 118 | + return torch.einsum( |
| 119 | + "...bi, ...bj, ij -> ...b", |
| 120 | + self.spline_u.basis(x.as_subclass(torch.Tensor)[..., 0]), |
| 121 | + self.spline_v.basis(x.as_subclass(torch.Tensor)[..., 1]), |
| 122 | + self.control_points, |
| 123 | + ).unsqueeze(-1) |
| 124 | + |
| 125 | + @property |
| 126 | + def knots(self): |
| 127 | + """ |
| 128 | + The knots of the univariate splines defining the spline surface. |
| 129 | +
|
| 130 | + :return: The knots. |
| 131 | + :rtype: tuple(torch.Tensor, torch.Tensor) |
| 132 | + """ |
| 133 | + return self.spline_u.knots, self.spline_v.knots |
| 134 | + |
| 135 | + @knots.setter |
| 136 | + def knots(self, value): |
| 137 | + """ |
| 138 | + Set the knots of the spline surface. |
| 139 | +
|
| 140 | + :param value: A tuple (knots_u, knots_v) containing the knots for both |
| 141 | + parametric directions. |
| 142 | + :type value: tuple(torch.Tensor | dict, torch.Tensor | dict) |
| 143 | + :raises ValueError: If value is not a tuple of two elements. |
| 144 | + """ |
| 145 | + # Check value is a tuple of two elements |
| 146 | + if not (isinstance(value, tuple) and len(value) == 2): |
| 147 | + raise ValueError("Knots must be a tuple of two elements.") |
| 148 | + |
| 149 | + knots_u, knots_v = value |
| 150 | + self.spline_u.knots = knots_u |
| 151 | + self.spline_v.knots = knots_v |
| 152 | + |
| 153 | + @property |
| 154 | + def control_points(self): |
| 155 | + """ |
| 156 | + The control points of the spline. |
| 157 | +
|
| 158 | + :return: The control points. |
| 159 | + :rtype: torch.Tensor |
| 160 | + """ |
| 161 | + return self._control_points |
| 162 | + |
| 163 | + @control_points.setter |
| 164 | + def control_points(self, control_points): |
| 165 | + """ |
| 166 | + Set the control points of the spline surface. |
| 167 | +
|
| 168 | + :param torch.Tensor control_points: The bidimensional control points |
| 169 | + tensor, where each dimension refers to a direction in the parameter |
| 170 | + space. If None, control points are initialized to learnable |
| 171 | + parameters with zero initial value. Default is None. |
| 172 | + :raises ValueError: If in any direction there are not enough knots to |
| 173 | + define the control points, due to the relation: |
| 174 | + #knots = order + #control_points. |
| 175 | + :raises ValueError: If ``control_points`` is not of the correct shape. |
| 176 | + """ |
| 177 | + # Save correct shape of control points |
| 178 | + __valid_shape = ( |
| 179 | + len(self.spline_u.knots) - self.spline_u.order, |
| 180 | + len(self.spline_v.knots) - self.spline_v.order, |
| 181 | + ) |
| 182 | + |
| 183 | + # If control points are not provided, initialize them |
| 184 | + if control_points is None: |
| 185 | + |
| 186 | + # Check that there are enough knots to define control points |
| 187 | + if ( |
| 188 | + len(self.spline_u.knots) < self.spline_u.order + 1 |
| 189 | + or len(self.spline_v.knots) < self.spline_v.order + 1 |
| 190 | + ): |
| 191 | + raise ValueError( |
| 192 | + f"Not enough knots to define control points. Got " |
| 193 | + f"{len(self.spline_u.knots)} knots along u and " |
| 194 | + f"{len(self.spline_v.knots)} knots along v, but need at " |
| 195 | + f"least {self.spline_u.order + 1} and " |
| 196 | + f"{self.spline_v.order + 1}, respectively." |
| 197 | + ) |
| 198 | + |
| 199 | + # Initialize control points to zero |
| 200 | + control_points = torch.zeros(__valid_shape) |
| 201 | + |
| 202 | + # Check control points |
| 203 | + if control_points.shape != __valid_shape: |
| 204 | + raise ValueError( |
| 205 | + "control_points must be of the correct shape. ", |
| 206 | + f"Expected {__valid_shape}, got {control_points.shape}.", |
| 207 | + ) |
| 208 | + |
| 209 | + # Register control points as a learnable parameter |
| 210 | + self._control_points = torch.nn.Parameter( |
| 211 | + control_points, requires_grad=True |
| 212 | + ) |
0 commit comments