Skip to content

Commit f812d87

Browse files
🎨 Format Python code with psf/black (#348)
1 parent 4c5cb8f commit f812d87

File tree

2 files changed

+56
-56
lines changed

2 files changed

+56
-56
lines changed

pina/model/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,4 @@
1919
from .base_no import KernelNeuralOperator
2020
from .avno import AveragingNeuralOperator
2121
from .lno import LowRankNeuralOperator
22-
from .spline import Spline
22+
from .spline import Spline

pina/model/spline.py

Lines changed: 55 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
import torch
44
import torch.nn as nn
55
from ..utils import check_consistency
6-
6+
7+
78
class Spline(torch.nn.Module):
89

910
def __init__(self, order=4, knots=None, control_points=None) -> None:
@@ -31,38 +32,37 @@ def __init__(self, order=4, knots=None, control_points=None) -> None:
3132
self.control_points = control_points
3233

3334
elif knots is not None:
34-
print('Warning: control points will be initialized automatically.')
35-
print(' experimental feature')
35+
print("Warning: control points will be initialized automatically.")
36+
print(" experimental feature")
3637

3738
self.knots = knots
3839
n = len(knots) - order
3940
self.control_points = torch.nn.Parameter(
40-
torch.zeros(n), requires_grad=True)
41-
41+
torch.zeros(n), requires_grad=True
42+
)
43+
4244
elif control_points is not None:
43-
print('Warning: knots will be initialized automatically.')
44-
print(' experimental feature')
45-
45+
print("Warning: knots will be initialized automatically.")
46+
print(" experimental feature")
47+
4648
self.control_points = control_points
4749

48-
n = len(self.control_points)-1
50+
n = len(self.control_points) - 1
4951
self.knots = {
50-
'type': 'auto',
51-
'min': 0,
52-
'max': 1,
53-
'n': n+2+self.order}
52+
"type": "auto",
53+
"min": 0,
54+
"max": 1,
55+
"n": n + 2 + self.order,
56+
}
5457

5558
else:
56-
raise ValueError(
57-
"Knots and control points cannot be both None."
58-
)
59-
59+
raise ValueError("Knots and control points cannot be both None.")
6060

6161
if self.knots.ndim != 1:
6262
raise ValueError("Knot vector must be one-dimensional.")
6363

6464
def basis(self, x, k, i, t):
65-
'''
65+
"""
6666
Recursive function to compute the basis functions of the spline.
6767
6868
:param torch.Tensor x: points to be evaluated.
@@ -71,28 +71,32 @@ def basis(self, x, k, i, t):
7171
:param torch.Tensor t: vector of knots
7272
:return: the basis functions evaluated at x
7373
:rtype: torch.Tensor
74-
'''
75-
74+
"""
75+
7676
if k == 0:
77-
a = torch.where(torch.logical_and(t[i] <= x, x < t[i+1]), 1.0, 0.0)
77+
a = torch.where(
78+
torch.logical_and(t[i] <= x, x < t[i + 1]), 1.0, 0.0
79+
)
7880
if i == len(t) - self.order - 1:
79-
a = torch.where(x == t[-1], 1.0, a)
81+
a = torch.where(x == t[-1], 1.0, a)
8082
a.requires_grad_(True)
8183
return a
8284

83-
84-
if t[i+k] == t[i]:
85-
c1 = torch.tensor([0.0]*len(x), requires_grad=True)
85+
if t[i + k] == t[i]:
86+
c1 = torch.tensor([0.0] * len(x), requires_grad=True)
8687
else:
87-
c1 = (x - t[i])/(t[i+k] - t[i]) * self.basis(x, k-1, i, t)
88+
c1 = (x - t[i]) / (t[i + k] - t[i]) * self.basis(x, k - 1, i, t)
8889

89-
if t[i+k+1] == t[i+1]:
90-
c2 = torch.tensor([0.0]*len(x), requires_grad=True)
90+
if t[i + k + 1] == t[i + 1]:
91+
c2 = torch.tensor([0.0] * len(x), requires_grad=True)
9192
else:
92-
c2 = (t[i+k+1] - x)/(t[i+k+1] - t[i+1]) * self.basis(x, k-1, i+1, t)
93+
c2 = (
94+
(t[i + k + 1] - x)
95+
/ (t[i + k + 1] - t[i + 1])
96+
* self.basis(x, k - 1, i + 1, t)
97+
)
9398

9499
return c1 + c2
95-
96100

97101
@property
98102
def control_points(self):
@@ -101,50 +105,46 @@ def control_points(self):
101105
@control_points.setter
102106
def control_points(self, value):
103107
if isinstance(value, dict):
104-
if 'n' not in value:
105-
raise ValueError('Invalid value for control_points')
106-
n = value['n']
107-
dim = value.get('dim', 1)
108+
if "n" not in value:
109+
raise ValueError("Invalid value for control_points")
110+
n = value["n"]
111+
dim = value.get("dim", 1)
108112
value = torch.zeros(n, dim)
109113

110114
if not isinstance(value, torch.Tensor):
111-
raise ValueError('Invalid value for control_points')
115+
raise ValueError("Invalid value for control_points")
112116
self._control_points = torch.nn.Parameter(value, requires_grad=True)
113117

114118
@property
115119
def knots(self):
116120
return self._knots
117-
121+
118122
@knots.setter
119123
def knots(self, value):
120124
if isinstance(value, dict):
121125

122-
type_ = value.get('type', 'auto')
123-
min_ = value.get('min', 0)
124-
max_ = value.get('max', 1)
125-
n = value.get('n', 10)
126+
type_ = value.get("type", "auto")
127+
min_ = value.get("min", 0)
128+
max_ = value.get("max", 1)
129+
n = value.get("n", 10)
126130

127-
if type_ == 'uniform':
131+
if type_ == "uniform":
128132
value = torch.linspace(min_, max_, n + self.k + 1)
129-
elif type_ == 'auto':
130-
initial_knots = torch.ones(self.order+1)*min_
131-
final_knots = torch.ones(self.order+1)*max_
133+
elif type_ == "auto":
134+
initial_knots = torch.ones(self.order + 1) * min_
135+
final_knots = torch.ones(self.order + 1) * max_
132136

133137
if n < self.order + 1:
134138
value = torch.concatenate((initial_knots, final_knots))
135-
elif n - 2*self.order + 1 == 1:
136-
value = torch.Tensor([(max_ + min_)/2])
139+
elif n - 2 * self.order + 1 == 1:
140+
value = torch.Tensor([(max_ + min_) / 2])
137141
else:
138-
value = torch.linspace(min_, max_, n - 2*self.order - 1)
142+
value = torch.linspace(min_, max_, n - 2 * self.order - 1)
139143

140-
value = torch.concatenate(
141-
(
142-
initial_knots, value, final_knots
143-
)
144-
)
144+
value = torch.concatenate((initial_knots, value, final_knots))
145145

146146
if not isinstance(value, torch.Tensor):
147-
raise ValueError('Invalid value for knots')
147+
raise ValueError("Invalid value for knots")
148148

149149
self._knots = value
150150

@@ -154,7 +154,7 @@ def forward(self, x_):
154154
155155
:param torch.Tensor x_: points to be evaluated.
156156
:return: the spline evaluated at x_
157-
:rtype: torch.Tensor
157+
:rtype: torch.Tensor
158158
"""
159159
t = self.knots
160160
k = self.k
@@ -163,4 +163,4 @@ def forward(self, x_):
163163
basis = map(lambda i: self.basis(x_, k, i, t)[:, None], range(len(c)))
164164
y = (torch.cat(list(basis), dim=1) * c).sum(axis=1)
165165

166-
return y
166+
return y

0 commit comments

Comments
 (0)