Skip to content

Commit df4ea64

Browse files
add b-spline surface
1 parent 71ce8c5 commit df4ea64

File tree

7 files changed

+425
-30
lines changed

7 files changed

+425
-30
lines changed

docs/source/_rst/_code.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ Models
9595
MultiFeedForward <model/multi_feed_forward.rst>
9696
ResidualFeedForward <model/residual_feed_forward.rst>
9797
Spline <model/spline.rst>
98+
SplineSurface <model/spline_surface.rst>
9899
DeepONet <model/deeponet.rst>
99100
MIONet <model/mionet.rst>
100101
KernelNeuralOperator <model/kernel_neural_operator.rst>
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
Spline Surface
2+
================
3+
.. currentmodule:: pina.model.spline_surface
4+
5+
.. autoclass:: SplineSurface
6+
:members:
7+
:show-inheritance:

pina/model/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from .average_neural_operator import AveragingNeuralOperator
2727
from .low_rank_neural_operator import LowRankNeuralOperator
2828
from .spline import Spline
29+
from .spline_surface import SplineSurface
2930
from .graph_neural_operator import GraphNeuralOperator
3031
from .pirate_network import PirateNet
3132
from .equivariant_graph_neural_operator import EquivariantGraphNeuralOperator

pina/model/spline.py

Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
"""Module for the B-Spline model class."""
22

3-
import torch
43
import warnings
5-
from ..utils import check_positive_integer
4+
import torch
5+
from ..utils import check_positive_integer, check_consistency
66

77

88
class Spline(torch.nn.Module):
@@ -75,6 +75,10 @@ def __init__(self, order=4, knots=None, control_points=None):
7575
If None, they are initialized as learnable parameters with an
7676
initial value of zero. Default is None.
7777
:raises AssertionError: If ``order`` is not a positive integer.
78+
:raises ValueError: If ``knots`` is neither a torch.Tensor nor a
79+
dictionary, when provided.
80+
:raises ValueError: If ``control_points`` is not a torch.Tensor,
81+
when provided.
7882
:raises ValueError: If both ``knots`` and ``control_points`` are None.
7983
:raises ValueError: If ``knots`` is not one-dimensional.
8084
:raises ValueError: If ``control_points`` is not one-dimensional.
@@ -87,6 +91,8 @@ def __init__(self, order=4, knots=None, control_points=None):
8791

8892
# Check consistency
8993
check_positive_integer(value=order, strict=True)
94+
check_consistency(knots, (type(None), torch.Tensor, dict))
95+
check_consistency(control_points, (type(None), torch.Tensor))
9096

9197
# Raise error if neither knots nor control points are provided
9298
if knots is None and control_points is None:
@@ -229,10 +235,10 @@ def forward(self, x):
229235
:rtype: torch.Tensor
230236
"""
231237
return torch.einsum(
232-
"bi, i -> b",
233-
self.basis(x.as_subclass(torch.Tensor)).squeeze(1),
238+
"...bi, i -> ...b",
239+
self.basis(x.as_subclass(torch.Tensor)).squeeze(-1),
234240
self.control_points,
235-
).reshape(-1, 1)
241+
)
236242

237243
@property
238244
def control_points(self):
@@ -254,7 +260,6 @@ def control_points(self, control_points):
254260
initial value. Default is None.
255261
:raises ValueError: If there are not enough knots to define the control
256262
points, due to the relation: #knots = order + #control_points.
257-
:raises ValueError: If control_points is not a torch.Tensor.
258263
"""
259264
# If control points are not provided, initialize them
260265
if control_points is None:
@@ -270,13 +275,6 @@ def control_points(self, control_points):
270275
# Initialize control points to zero
271276
control_points = torch.zeros(len(self.knots) - self.order)
272277

273-
# Check validity of control points
274-
elif not isinstance(control_points, torch.Tensor):
275-
raise ValueError(
276-
"control_points must be a torch.Tensor,"
277-
f" got {type(control_points)}"
278-
)
279-
280278
# Set control points
281279
self._control_points = torch.nn.Parameter(
282280
control_points, requires_grad=True
@@ -308,18 +306,10 @@ def knots(self, value):
308306
last control points. In this case, the number of knots is inferred
309307
and the ``"n"`` key is ignored.
310308
:type value: torch.Tensor | dict
311-
:raises ValueError: If value is not a torch.Tensor or a dictionary.
312309
:raises ValueError: If a dictionary is provided but does not contain
313310
the required keys.
314311
:raises ValueError: If the mode specified in the dictionary is invalid.
315312
"""
316-
# Check validity of knots
317-
if not isinstance(value, (torch.Tensor, dict)):
318-
raise ValueError(
319-
"Knots must be a torch.Tensor or a dictionary,"
320-
f" got {type(value)}."
321-
)
322-
323313
# If a dictionary is provided, initialize knots accordingly
324314
if isinstance(value, dict):
325315

pina/model/spline_surface.py

Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
"""Module for the bivariate B-Spline surface model class."""
2+
3+
import torch
4+
from .spline import Spline
5+
from ..utils import check_consistency
6+
7+
8+
class SplineSurface(torch.nn.Module):
9+
r"""
10+
The bivariate B-Spline surface model class.
11+
12+
A bivariate B-spline surface is a parametric surface defined as the tensor
13+
product of two univariate B-spline curves:
14+
15+
.. math::
16+
17+
S(x, y) = \sum_{i,j=1}^{n_x, n_y} B_{i,k}(x) B_{j,s}(y) C_{i,j},
18+
\quad x \in [x_1, x_m], y \in [y_1, y_l]
19+
20+
where:
21+
22+
- :math:`C_{i,j} \in \mathbb{R}^2` are the control points. These fixed
23+
points influence the shape of the surface but are not generally
24+
interpolated, except at the boundaries under certain knot multiplicities.
25+
- :math:`B_{i,k}(x)` and :math:`B_{j,s}(y)` are the B-spline basis functions
26+
defined over two orthogonal directions, with orders :math:`k` and
27+
:math:`s`, respectively.
28+
- :math:`X = \{ x_1, x_2, \dots, x_m \}` and
29+
:math:`Y = \{ y_1, y_2, \dots, y_l \}` are the non-decreasing knot
30+
vectors along the two directions.
31+
"""
32+
33+
def __init__(self, orders, knots_u=None, knots_v=None, control_points=None):
34+
"""
35+
Initialization of the :class:`SplineSurface` class.
36+
37+
:param list[int] orders: The orders of the spline along each parametric
38+
direction. Each order defines the degree of the corresponding basis
39+
as ``degree = order - 1``.
40+
:param knots_u: The knots of the spline along the first direction.
41+
For details on valid formats and initialization modes, see the
42+
:class:`Spline` class. Default is None.
43+
:type knots_u: torch.Tensor | dict
44+
:param knots_v: The knots of the spline along the second direction.
45+
For details on valid formats and initialization modes, see the
46+
:class:`Spline` class. Default is None.
47+
:type knots_v: torch.Tensor | dict
48+
:param torch.Tensor control_points: The control points defining the
49+
surface geometry. It must be a two-dimensional tensor of shape
50+
``[len(knots_u) - orders[0], len(knots_v) - orders[1]]``.
51+
If None, they are initialized as learnable parameters with zero
52+
values. Default is None.
53+
:raises ValueError: If ``orders`` is not a list of integers.
54+
:raises ValueError: If ``knots_u`` is neither a torch.Tensor nor a
55+
dictionary, when provided.
56+
:raises ValueError: If ``knots_v`` is neither a torch.Tensor nor a
57+
dictionary, when provided.
58+
:raises ValueError: If ``control_points`` is not a torch.Tensor,
59+
when provided.
60+
:raises ValueError: If ``orders`` is not a list of two elements.
61+
:raises ValueError: If ``knots_u``, ``knots_v``, and ``control_points``
62+
are all None.
63+
"""
64+
super().__init__()
65+
66+
# Check consistency
67+
check_consistency(orders, int)
68+
check_consistency(control_points, (type(None), torch.Tensor))
69+
check_consistency(knots_u, (type(None), torch.Tensor, dict))
70+
check_consistency(knots_v, (type(None), torch.Tensor, dict))
71+
72+
# Check orders is a list of two elements
73+
if len(orders) != 2:
74+
raise ValueError("orders must be a list of two elements.")
75+
76+
# Raise error if neither knots nor control points are provided
77+
if (knots_u is None or knots_v is None) and control_points is None:
78+
raise ValueError(
79+
"control_points cannot be None if knots_u or knots_v is None."
80+
)
81+
82+
# Initialize knots_u if not provided
83+
if knots_u is None and control_points is not None:
84+
knots_u = {
85+
"n": control_points.shape[0] + orders[0],
86+
"min": 0,
87+
"max": 1,
88+
"mode": "auto",
89+
}
90+
91+
# Initialize knots_v if not provided
92+
if knots_v is None and control_points is not None:
93+
knots_v = {
94+
"n": control_points.shape[1] + orders[1],
95+
"min": 0,
96+
"max": 1,
97+
"mode": "auto",
98+
}
99+
100+
# Create two univariate b-splines
101+
self.spline_u = Spline(order=orders[0], knots=knots_u)
102+
self.spline_v = Spline(order=orders[1], knots=knots_v)
103+
self.control_points = control_points
104+
105+
# Delete unneeded parameters
106+
delattr(self.spline_u, "_control_points")
107+
delattr(self.spline_v, "_control_points")
108+
109+
def forward(self, x):
110+
"""
111+
Forward pass for the :class:`SplineSurface` model.
112+
113+
:param x: The input tensor.
114+
:type x: torch.Tensor | LabelTensor
115+
:return: The output tensor.
116+
:rtype: torch.Tensor
117+
"""
118+
return torch.einsum(
119+
"...bi, ...bj, ij -> ...b",
120+
self.spline_u.basis(x.as_subclass(torch.Tensor)[..., 0]),
121+
self.spline_v.basis(x.as_subclass(torch.Tensor)[..., 1]),
122+
self.control_points,
123+
).unsqueeze(-1)
124+
125+
@property
126+
def knots(self):
127+
"""
128+
The knots of the univariate splines defining the spline surface.
129+
130+
:return: The knots.
131+
:rtype: tuple(torch.Tensor, torch.Tensor)
132+
"""
133+
return self.spline_u.knots, self.spline_v.knots
134+
135+
@knots.setter
136+
def knots(self, value):
137+
"""
138+
Set the knots of the spline surface.
139+
140+
:param value: A tuple (knots_u, knots_v) containing the knots for both
141+
parametric directions.
142+
:type value: tuple(torch.Tensor | dict, torch.Tensor | dict)
143+
:raises ValueError: If value is not a tuple of two elements.
144+
"""
145+
# Check value is a tuple of two elements
146+
if not (isinstance(value, tuple) and len(value) == 2):
147+
raise ValueError("Knots must be a tuple of two elements.")
148+
149+
knots_u, knots_v = value
150+
self.spline_u.knots = knots_u
151+
self.spline_v.knots = knots_v
152+
153+
@property
154+
def control_points(self):
155+
"""
156+
The control points of the spline.
157+
158+
:return: The control points.
159+
:rtype: torch.Tensor
160+
"""
161+
return self._control_points
162+
163+
@control_points.setter
164+
def control_points(self, control_points):
165+
"""
166+
Set the control points of the spline surface.
167+
168+
:param torch.Tensor control_points: The bidimensional control points
169+
tensor, where each dimension refers to a direction in the parameter
170+
space. If None, control points are initialized to learnable
171+
parameters with zero initial value. Default is None.
172+
:raises ValueError: If in any direction there are not enough knots to
173+
define the control points, due to the relation:
174+
#knots = order + #control_points.
175+
:raises ValueError: If ``control_points`` is not of the correct shape.
176+
"""
177+
# Save correct shape of control points
178+
__valid_shape = (
179+
len(self.spline_u.knots) - self.spline_u.order,
180+
len(self.spline_v.knots) - self.spline_v.order,
181+
)
182+
183+
# If control points are not provided, initialize them
184+
if control_points is None:
185+
186+
# Check that there are enough knots to define control points
187+
if (
188+
len(self.spline_u.knots) < self.spline_u.order + 1
189+
or len(self.spline_v.knots) < self.spline_v.order + 1
190+
):
191+
raise ValueError(
192+
f"Not enough knots to define control points. Got "
193+
f"{len(self.spline_u.knots)} knots along u and "
194+
f"{len(self.spline_v.knots)} knots along v, but need at "
195+
f"least {self.spline_u.order + 1} and "
196+
f"{self.spline_v.order + 1}, respectively."
197+
)
198+
199+
# Initialize control points to zero
200+
control_points = torch.zeros(__valid_shape)
201+
202+
# Check control points
203+
if control_points.shape != __valid_shape:
204+
raise ValueError(
205+
"control_points must be of the correct shape. ",
206+
f"Expected {__valid_shape}, got {control_points.shape}.",
207+
)
208+
209+
# Register control points as a learnable parameter
210+
self._control_points = torch.nn.Parameter(
211+
control_points, requires_grad=True
212+
)

tests/test_model/test_spline.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import torch
22
import pytest
3-
import numpy as np
43
from scipy.interpolate import BSpline
54
from pina.model import Spline
65
from pina import LabelTensor
@@ -12,7 +11,10 @@
1211
n_knots = order + n_ctrl_pts
1312

1413
# Input tensor
15-
pts = LabelTensor(torch.linspace(0, 1, 100).reshape(-1, 1), ["x"])
14+
points = [
15+
LabelTensor(torch.rand(100, 1), ["x"]),
16+
LabelTensor(torch.rand(2, 100, 1), ["x"]),
17+
]
1618

1719

1820
# Function to compare with scipy implementation
@@ -26,15 +28,15 @@ def check_scipy_spline(model, x, output_):
2628
)
2729

2830
# Compare outputs
29-
np.testing.assert_allclose(
30-
output_.squeeze().detach().numpy(),
31-
scipy_spline(x).flatten(),
31+
torch.allclose(
32+
output_,
33+
torch.tensor(scipy_spline(x), dtype=output_.dtype),
3234
atol=1e-5,
3335
rtol=1e-5,
3436
)
3537

3638

37-
# Define all possible combinations of valid arguments for the Spline class
39+
# Define all possible combinations of valid arguments for Spline class
3840
valid_args = [
3941
{
4042
"order": order,
@@ -144,22 +146,24 @@ def test_constructor(args):
144146

145147

146148
@pytest.mark.parametrize("args", valid_args)
147-
def test_forward(args):
149+
@pytest.mark.parametrize("pts", points)
150+
def test_forward(args, pts):
148151

149152
# Define the model
150153
model = Spline(**args)
151154

152155
# Evaluate the model
153156
output_ = model(pts)
154-
assert output_.shape == (pts.shape[0], 1)
157+
assert output_.shape == pts.shape
155158

156159
# Compare with scipy implementation only for interpolant knots (mode: auto)
157160
if isinstance(args["knots"], dict) and args["knots"]["mode"] == "auto":
158161
check_scipy_spline(model, pts, output_)
159162

160163

161164
@pytest.mark.parametrize("args", valid_args)
162-
def test_backward(args):
165+
@pytest.mark.parametrize("pts", points)
166+
def test_backward(args, pts):
163167

164168
# Define the model
165169
model = Spline(**args)

0 commit comments

Comments
 (0)