Skip to content

Commit 4c5cb8f

Browse files
ndem0AleDinvedario-coscia
authored
add spline model (#321)
* add spline model * add tests for splines * rst files for splines --------- Co-authored-by: AleDinve <[email protected]> Co-authored-by: dario-coscia <[email protected]>
1 parent fefba81 commit 4c5cb8f

File tree

6 files changed

+255
-1
lines changed

6 files changed

+255
-1
lines changed

docs/source/_rst/_code.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ Models
5959
FeedForward <models/fnn.rst>
6060
MultiFeedForward <models/multifeedforward.rst>
6161
ResidualFeedForward <models/fnn_residual.rst>
62+
Spline <models/spline.rst>
6263
DeepONet <models/deeponet.rst>
6364
MIONet <models/mionet.rst>
6465
FourierIntegralKernel <models/fourier_kernel.rst>

docs/source/_rst/models/spline.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
Spline
2+
========
3+
.. currentmodule:: pina.model.spline
4+
5+
.. autoclass:: Spline
6+
:members:
7+
:show-inheritance:

pina/model/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
"KernelNeuralOperator",
1010
"AveragingNeuralOperator",
1111
"LowRankNeuralOperator",
12+
"Spline",
1213
]
1314

1415
from .feed_forward import FeedForward, ResidualFeedForward
@@ -18,3 +19,4 @@
1819
from .base_no import KernelNeuralOperator
1920
from .avno import AveragingNeuralOperator
2021
from .lno import LowRankNeuralOperator
22+
from .spline import Spline

pina/model/spline.py

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
"""Module for Spline model"""
2+
3+
import torch
4+
import torch.nn as nn
5+
from ..utils import check_consistency
6+
7+
class Spline(torch.nn.Module):
8+
9+
def __init__(self, order=4, knots=None, control_points=None) -> None:
10+
"""
11+
Spline model.
12+
13+
:param int order: the order of the spline.
14+
:param torch.Tensor knots: the knot vector.
15+
:param torch.Tensor control_points: the control points.
16+
"""
17+
super().__init__()
18+
19+
check_consistency(order, int)
20+
21+
if order < 0:
22+
raise ValueError("Spline order cannot be negative.")
23+
if knots is None and control_points is None:
24+
raise ValueError("Knots and control points cannot be both None.")
25+
26+
self.order = order
27+
self.k = order - 1
28+
29+
if knots is not None and control_points is not None:
30+
self.knots = knots
31+
self.control_points = control_points
32+
33+
elif knots is not None:
34+
print('Warning: control points will be initialized automatically.')
35+
print(' experimental feature')
36+
37+
self.knots = knots
38+
n = len(knots) - order
39+
self.control_points = torch.nn.Parameter(
40+
torch.zeros(n), requires_grad=True)
41+
42+
elif control_points is not None:
43+
print('Warning: knots will be initialized automatically.')
44+
print(' experimental feature')
45+
46+
self.control_points = control_points
47+
48+
n = len(self.control_points)-1
49+
self.knots = {
50+
'type': 'auto',
51+
'min': 0,
52+
'max': 1,
53+
'n': n+2+self.order}
54+
55+
else:
56+
raise ValueError(
57+
"Knots and control points cannot be both None."
58+
)
59+
60+
61+
if self.knots.ndim != 1:
62+
raise ValueError("Knot vector must be one-dimensional.")
63+
64+
def basis(self, x, k, i, t):
65+
'''
66+
Recursive function to compute the basis functions of the spline.
67+
68+
:param torch.Tensor x: points to be evaluated.
69+
:param int k: spline degree
70+
:param int i: the index of the interval
71+
:param torch.Tensor t: vector of knots
72+
:return: the basis functions evaluated at x
73+
:rtype: torch.Tensor
74+
'''
75+
76+
if k == 0:
77+
a = torch.where(torch.logical_and(t[i] <= x, x < t[i+1]), 1.0, 0.0)
78+
if i == len(t) - self.order - 1:
79+
a = torch.where(x == t[-1], 1.0, a)
80+
a.requires_grad_(True)
81+
return a
82+
83+
84+
if t[i+k] == t[i]:
85+
c1 = torch.tensor([0.0]*len(x), requires_grad=True)
86+
else:
87+
c1 = (x - t[i])/(t[i+k] - t[i]) * self.basis(x, k-1, i, t)
88+
89+
if t[i+k+1] == t[i+1]:
90+
c2 = torch.tensor([0.0]*len(x), requires_grad=True)
91+
else:
92+
c2 = (t[i+k+1] - x)/(t[i+k+1] - t[i+1]) * self.basis(x, k-1, i+1, t)
93+
94+
return c1 + c2
95+
96+
97+
@property
98+
def control_points(self):
99+
return self._control_points
100+
101+
@control_points.setter
102+
def control_points(self, value):
103+
if isinstance(value, dict):
104+
if 'n' not in value:
105+
raise ValueError('Invalid value for control_points')
106+
n = value['n']
107+
dim = value.get('dim', 1)
108+
value = torch.zeros(n, dim)
109+
110+
if not isinstance(value, torch.Tensor):
111+
raise ValueError('Invalid value for control_points')
112+
self._control_points = torch.nn.Parameter(value, requires_grad=True)
113+
114+
@property
115+
def knots(self):
116+
return self._knots
117+
118+
@knots.setter
119+
def knots(self, value):
120+
if isinstance(value, dict):
121+
122+
type_ = value.get('type', 'auto')
123+
min_ = value.get('min', 0)
124+
max_ = value.get('max', 1)
125+
n = value.get('n', 10)
126+
127+
if type_ == 'uniform':
128+
value = torch.linspace(min_, max_, n + self.k + 1)
129+
elif type_ == 'auto':
130+
initial_knots = torch.ones(self.order+1)*min_
131+
final_knots = torch.ones(self.order+1)*max_
132+
133+
if n < self.order + 1:
134+
value = torch.concatenate((initial_knots, final_knots))
135+
elif n - 2*self.order + 1 == 1:
136+
value = torch.Tensor([(max_ + min_)/2])
137+
else:
138+
value = torch.linspace(min_, max_, n - 2*self.order - 1)
139+
140+
value = torch.concatenate(
141+
(
142+
initial_knots, value, final_knots
143+
)
144+
)
145+
146+
if not isinstance(value, torch.Tensor):
147+
raise ValueError('Invalid value for knots')
148+
149+
self._knots = value
150+
151+
def forward(self, x_):
152+
"""
153+
Forward pass of the spline model.
154+
155+
:param torch.Tensor x_: points to be evaluated.
156+
:return: the spline evaluated at x_
157+
:rtype: torch.Tensor
158+
"""
159+
t = self.knots
160+
k = self.k
161+
c = self.control_points
162+
163+
basis = map(lambda i: self.basis(x_, k, i, t)[:, None], range(len(c)))
164+
y = (torch.cat(list(basis), dim=1) * c).sum(axis=1)
165+
166+
return y

setup.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,11 @@
2626
'sphinx_design',
2727
'pydata_sphinx_theme'
2828
],
29-
'test': ['pytest', 'pytest-cov'],
29+
'test': [
30+
'pytest',
31+
'pytest-cov',
32+
'scipy'
33+
],
3034
}
3135

3236
LDESCRIPTION = (

tests/test_model/test_spline.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import torch
2+
import pytest
3+
4+
from pina.model import Spline
5+
6+
data = torch.rand((20, 3))
7+
input_vars = 3
8+
output_vars = 4
9+
10+
valid_args = [
11+
{
12+
'knots': torch.tensor([0., 0., 0., 1., 2., 3., 3., 3.]),
13+
'control_points': torch.tensor([0., 0., 1., 0., 0.]),
14+
'order': 3
15+
},
16+
{
17+
'knots': torch.tensor([-2., -2., -2., -2., -1., 0., 1., 2., 2., 2., 2.]),
18+
'control_points': torch.tensor([0., 0., 0., 6., 0., 0., 0.]),
19+
'order': 4
20+
},
21+
# {'control_points': {'n': 5, 'dim': 1}, 'order': 2},
22+
# {'control_points': {'n': 7, 'dim': 1}, 'order': 3}
23+
]
24+
25+
def scipy_check(model, x, y):
26+
from scipy.interpolate._bsplines import BSpline
27+
import numpy as np
28+
spline = BSpline(
29+
t=model.knots.detach().numpy(),
30+
c=model.control_points.detach().numpy(),
31+
k=model.order-1
32+
)
33+
y_scipy = spline(x).flatten()
34+
y = y.detach().numpy()
35+
np.testing.assert_allclose(y, y_scipy, atol=1e-5)
36+
37+
@pytest.mark.parametrize("args", valid_args)
38+
def test_constructor(args):
39+
Spline(**args)
40+
41+
def test_constructor_wrong():
42+
with pytest.raises(ValueError):
43+
Spline()
44+
45+
@pytest.mark.parametrize("args", valid_args)
46+
def test_forward(args):
47+
min_x = args['knots'][0]
48+
max_x = args['knots'][-1]
49+
xi = torch.linspace(min_x, max_x, 1000)
50+
model = Spline(**args)
51+
yi = model(xi).squeeze()
52+
scipy_check(model, xi, yi)
53+
return
54+
55+
56+
@pytest.mark.parametrize("args", valid_args)
57+
def test_backward(args):
58+
min_x = args['knots'][0]
59+
max_x = args['knots'][-1]
60+
xi = torch.linspace(min_x, max_x, 100)
61+
model = Spline(**args)
62+
yi = model(xi)
63+
fake_loss = torch.sum(yi)
64+
assert model.control_points.grad is None
65+
fake_loss.backward()
66+
assert model.control_points.grad is not None
67+
68+
# dim_in, dim_out = 3, 2
69+
# fnn = FeedForward(dim_in, dim_out)
70+
# data.requires_grad = True
71+
# output_ = fnn(data)
72+
# l=torch.mean(output_)
73+
# l.backward()
74+
# assert data._grad.shape == torch.Size([20,3])

0 commit comments

Comments
 (0)