Skip to content

Commit 8b0fc2b

Browse files
add b-spline surface
1 parent f90ca66 commit 8b0fc2b

File tree

6 files changed

+309
-1
lines changed

6 files changed

+309
-1
lines changed

docs/source/_rst/_code.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ Models
9595
MultiFeedForward <model/multi_feed_forward.rst>
9696
ResidualFeedForward <model/residual_feed_forward.rst>
9797
Spline <model/spline.rst>
98+
SplineSurface <model/spline_surface.rst>
9899
DeepONet <model/deeponet.rst>
99100
MIONet <model/mionet.rst>
100101
KernelNeuralOperator <model/kernel_neural_operator.rst>
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
Spline Surface
2+
================
3+
.. currentmodule:: pina.model.spline_surface
4+
5+
.. autoclass:: SplineSurface
6+
:members:
7+
:show-inheritance:

pina/model/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from .average_neural_operator import AveragingNeuralOperator
2626
from .low_rank_neural_operator import LowRankNeuralOperator
2727
from .spline import Spline
28+
from .spline_surface import SplineSurface
2829
from .graph_neural_operator import GraphNeuralOperator
2930
from .pirate_network import PirateNet
3031
from .equivariant_graph_neural_operator import EquivariantGraphNeuralOperator

pina/model/spline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Module for the B-Spline model class."""
22

3-
import torch
43
import warnings
4+
import torch
55
from ..utils import check_positive_integer
66

77

pina/model/spline_surface.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
"""Module for the bivariate B-Spline surface model class."""
2+
3+
import torch
4+
from .spline import Spline
5+
from ..utils import check_consistency
6+
7+
8+
class SplineSurface(torch.nn.Module):
9+
r"""
10+
The bivariate B-Spline surface model class.
11+
12+
A bivariate B-spline surface is a parametric surface defined as the tensor
13+
product of two univariate B-spline curves:
14+
15+
.. math::
16+
17+
S(x, y) = \sum_{i,j=1}^{n_x, n_y} B_{i,k}(x) B_{j,s}(y) C_{i,j},
18+
\quad x \in [x_1, x_m], y \in [y_1, y_l]
19+
20+
where:
21+
22+
- :math:`C_{i,j} \in \mathbb{R}^2` are the control points. These fixed
23+
points influence the shape of the surface but are not generally
24+
interpolated, except at the boundaries under certain knot multiplicities.
25+
- :math:`B_{i,k}(x)` and :math:`B_{j,s}(y)` are the B-spline basis functions
26+
defined over two orthogonal directions, with orders :math:`k` and
27+
:math:`s`, respectively.
28+
- :math:`X = \{ x_1, x_2, \dots, x_m \}` and
29+
:math:`Y = \{ y_1, y_2, \dots, y_l \}` are the non-decreasing knot
30+
vectors along the two directions.
31+
"""
32+
33+
def __init__(self, knots_u, knots_v, orders=[4, 4], control_points=None):
34+
"""
35+
Initialization of the :class:`SplineSurface` class.
36+
37+
:param knots_u: The knots of the spline along the first direction.
38+
Unlike the univariate case, this must be explicitly provided.
39+
For details on valid formats and initialization modes, see the
40+
:class:`Spline` class.
41+
:type knots_u: torch.Tensor | dict
42+
:param knots_v: The knots of the spline along the second direction.
43+
Unlike the univariate case, this must be explicitly provided.
44+
For details on valid formats and initialization modes, see the
45+
:class:`Spline` class.
46+
:type knots_v: torch.Tensor | dict
47+
:param list[int] orders: The orders of the spline along each parametric
48+
direction. Each order defines the degree of the corresponding basis
49+
as ``degree = order - 1``. Default is ``[4, 4]``.
50+
:param torch.Tensor control_points: The control points defining the
51+
surface geometry. It must be a two-dimensional tensor of shape
52+
``[len(knots_u) - orders[0], len(knots_v) - orders[1]]``.
53+
If None, they are initialized as learnable parameters with zero
54+
values. Default is None.
55+
:raises ValueError: If ``orders`` is not a list of two elements.
56+
:raises ValueError: If ``knots_u`` or ``knots_v`` is None.
57+
:raises ValueError: If ``control_points`` is not a torch.Tensor when
58+
provided.
59+
:raises ValueError: If ``control_points`` is not of the correct shape
60+
when provided.
61+
"""
62+
super().__init__()
63+
64+
# Check consistency
65+
check_consistency(orders, int)
66+
check_consistency(control_points, (type(None), torch.Tensor))
67+
68+
# Check orders is a list of two elements
69+
if len(orders) != 2:
70+
raise ValueError("orders must be a list of two elements.")
71+
72+
# Check knots_u and knots_v are not None
73+
if knots_u is None or knots_v is None:
74+
raise ValueError("knots_u and knots_v must cannot be None.")
75+
76+
# Create two univariate b-splines
77+
self.spline_u = Spline(order=orders[0], knots=knots_u)
78+
self.spline_v = Spline(order=orders[1], knots=knots_v)
79+
80+
# Delete unneeded parameters
81+
delattr(self.spline_u, "_control_points")
82+
delattr(self.spline_v, "_control_points")
83+
84+
# Save correct shape of control points
85+
__valid_shape = (
86+
len(self.spline_u.knots) - self.spline_u.order,
87+
len(self.spline_v.knots) - self.spline_v.order,
88+
)
89+
90+
# Initialize control points, if not provided
91+
if control_points is None:
92+
control_points = torch.zeros(__valid_shape)
93+
94+
# Check control points
95+
if control_points.shape != __valid_shape:
96+
raise ValueError(
97+
"control_points must be of the correct shape. ",
98+
f"Expected {__valid_shape}, got {control_points.shape}.",
99+
)
100+
101+
# Register control points as a learnable parameter
102+
self._control_points = torch.nn.Parameter(
103+
control_points, requires_grad=True
104+
)
105+
106+
def forward(self, x):
107+
"""
108+
Forward pass for the :class:`SplineSurface` model.
109+
110+
:param x: The input tensor.
111+
:type x: torch.Tensor | LabelTensor
112+
:return: The output tensor.
113+
:rtype: torch.Tensor
114+
"""
115+
return torch.einsum(
116+
"bi, bj, ij -> b",
117+
self.spline_u.basis(x.as_subclass(torch.Tensor)[:, 0]),
118+
self.spline_v.basis(x.as_subclass(torch.Tensor)[:, 1]),
119+
self.control_points,
120+
).reshape(-1, 1)
121+
122+
@property
123+
def knots(self):
124+
"""
125+
The knots of the univariate splines defining the spline surface.
126+
127+
:return: The knots.
128+
:rtype: tuple(torch.Tensor, torch.Tensor)
129+
"""
130+
return self.spline_u.knots, self.spline_v.knots
131+
132+
@property
133+
def control_points(self):
134+
"""
135+
The control points of the spline.
136+
137+
:return: The control points.
138+
:rtype: torch.Tensor
139+
"""
140+
return self._control_points
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
import torch
2+
import random
3+
import pytest
4+
from pina.model import SplineSurface
5+
from pina import LabelTensor
6+
7+
8+
# Utility quantities for testing
9+
orders = [random.randint(1, 8) for _ in range(2)]
10+
n_ctrl_pts = random.randint(max(orders), max(orders) + 5)
11+
n_knots = [orders[i] + n_ctrl_pts for i in range(2)]
12+
13+
# Input tensor
14+
x = torch.linspace(0, 1, 100).reshape(-1, 1)
15+
y = torch.linspace(0, 1, 100).reshape(-1, 1)
16+
pts = LabelTensor(torch.cat((x, y), dim=1), labels=["x", "y"])
17+
18+
19+
@pytest.mark.parametrize(
20+
"knots_u",
21+
[
22+
torch.rand(n_knots[0]),
23+
{"n": n_knots[0], "min": 0, "max": 1, "mode": "auto"},
24+
{"n": n_knots[0], "min": 0, "max": 1, "mode": "uniform"},
25+
],
26+
)
27+
@pytest.mark.parametrize(
28+
"knots_v",
29+
[
30+
torch.rand(n_knots[1]),
31+
{"n": n_knots[1], "min": 0, "max": 1, "mode": "auto"},
32+
{"n": n_knots[1], "min": 0, "max": 1, "mode": "uniform"},
33+
],
34+
)
35+
@pytest.mark.parametrize(
36+
"control_points", [torch.rand(n_ctrl_pts, n_ctrl_pts), None]
37+
)
38+
def test_constructor(knots_u, knots_v, control_points):
39+
SplineSurface(
40+
knots_u=knots_u,
41+
knots_v=knots_v,
42+
control_points=control_points,
43+
orders=orders,
44+
)
45+
46+
# Should fail if orders is not list of two elements
47+
with pytest.raises(ValueError):
48+
SplineSurface(
49+
knots_u=knots_u,
50+
knots_v=knots_v,
51+
control_points=control_points,
52+
orders=[orders[0]],
53+
)
54+
55+
# Should fail if knots_u is None
56+
with pytest.raises(ValueError):
57+
SplineSurface(
58+
knots_u=None,
59+
knots_v=knots_v,
60+
control_points=control_points,
61+
orders=orders,
62+
)
63+
64+
# Should fail if knots_v is None
65+
with pytest.raises(ValueError):
66+
SplineSurface(
67+
knots_u=knots_u,
68+
knots_v=None,
69+
control_points=control_points,
70+
orders=orders,
71+
)
72+
73+
# Should fail if control_points is not a torch.Tensor when provided
74+
with pytest.raises(ValueError):
75+
SplineSurface(
76+
knots_u=knots_u,
77+
knots_v=knots_v,
78+
control_points=[[0.0] * n_ctrl_pts] * n_ctrl_pts,
79+
orders=orders,
80+
)
81+
82+
# Should fail if control_points is not of the correct shape when provided
83+
with pytest.raises(ValueError):
84+
SplineSurface(
85+
knots_u=knots_u,
86+
knots_v=knots_v,
87+
control_points=torch.rand(n_ctrl_pts + 1, n_ctrl_pts),
88+
orders=orders,
89+
)
90+
91+
92+
@pytest.mark.parametrize(
93+
"knots_u",
94+
[
95+
torch.rand(n_knots[0]),
96+
{"n": n_knots[0], "min": 0, "max": 1, "mode": "auto"},
97+
{"n": n_knots[0], "min": 0, "max": 1, "mode": "uniform"},
98+
],
99+
)
100+
@pytest.mark.parametrize(
101+
"knots_v",
102+
[
103+
torch.rand(n_knots[1]),
104+
{"n": n_knots[1], "min": 0, "max": 1, "mode": "auto"},
105+
{"n": n_knots[1], "min": 0, "max": 1, "mode": "uniform"},
106+
],
107+
)
108+
@pytest.mark.parametrize(
109+
"control_points", [torch.rand(n_ctrl_pts, n_ctrl_pts), None]
110+
)
111+
def test_forward(knots_u, knots_v, control_points):
112+
113+
# Define the model
114+
model = SplineSurface(
115+
knots_u=knots_u,
116+
knots_v=knots_v,
117+
control_points=control_points,
118+
orders=orders,
119+
)
120+
121+
# Evaluate the model
122+
output_ = model(pts)
123+
assert output_.shape == (pts.shape[0], 1)
124+
125+
126+
@pytest.mark.parametrize(
127+
"knots_u",
128+
[
129+
torch.rand(n_knots[0]),
130+
{"n": n_knots[0], "min": 0, "max": 1, "mode": "auto"},
131+
{"n": n_knots[0], "min": 0, "max": 1, "mode": "uniform"},
132+
],
133+
)
134+
@pytest.mark.parametrize(
135+
"knots_v",
136+
[
137+
torch.rand(n_knots[1]),
138+
{"n": n_knots[1], "min": 0, "max": 1, "mode": "auto"},
139+
{"n": n_knots[1], "min": 0, "max": 1, "mode": "uniform"},
140+
],
141+
)
142+
@pytest.mark.parametrize(
143+
"control_points", [torch.rand(n_ctrl_pts, n_ctrl_pts), None]
144+
)
145+
def test_backward(knots_u, knots_v, control_points):
146+
147+
# Define the model
148+
model = SplineSurface(
149+
knots_u=knots_u,
150+
knots_v=knots_v,
151+
control_points=control_points,
152+
orders=orders,
153+
)
154+
155+
# Evaluate the model
156+
output_ = model(pts)
157+
loss = torch.mean(output_)
158+
loss.backward()
159+
assert model.control_points.grad.shape == model.control_points.shape

0 commit comments

Comments
 (0)