Skip to content

Commit ad41ba0

Browse files
ajacoby9FilippoOlivo
authored andcommitted
vectorize Cox - de Boor recursion
Co-authored-by: Filippo Olivo <[email protected]> Co-authored-by: ajacoby9 <[email protected]>
1 parent 4a3748c commit ad41ba0

File tree

1 file changed

+149
-44
lines changed

1 file changed

+149
-44
lines changed

pina/model/spline.py

Lines changed: 149 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@ class Spline(torch.nn.Module):
99
Spline model class.
1010
"""
1111

12-
def __init__(self, order=4, knots=None, control_points=None) -> None:
12+
def __init__(
13+
self, order=4, knots=None, control_points=None, grid_extension=True
14+
):
1315
"""
1416
Initialization of the :class:`Spline` class.
1517
@@ -20,7 +22,7 @@ def __init__(self, order=4, knots=None, control_points=None) -> None:
2022
``None``.
2123
:raises ValueError: If the order is negative.
2224
:raises ValueError: If both knots and control points are ``None``.
23-
:raises ValueError: If the knot tensor is not one-dimensional.
25+
:raises ValueError: If the knot tensor is not one or two dimensional.
2426
"""
2527
super().__init__()
2628

@@ -33,6 +35,10 @@ def __init__(self, order=4, knots=None, control_points=None) -> None:
3335

3436
self.order = order
3537
self.k = order - 1
38+
self.grid_extension = grid_extension
39+
40+
# Cache for performance optimization
41+
self._boundary_interval_idx = None
3642

3743
if knots is not None and control_points is not None:
3844
self.knots = knots
@@ -65,45 +71,154 @@ def __init__(self, order=4, knots=None, control_points=None) -> None:
6571
else:
6672
raise ValueError("Knots and control points cannot be both None.")
6773

68-
if self.knots.ndim != 1:
69-
raise ValueError("Knot vector must be one-dimensional.")
74+
if self.knots.ndim > 2:
75+
raise ValueError("Knot vector must be one or two-dimensional.")
76+
77+
# Precompute boundary interval index for performance
78+
self._compute_boundary_interval()
7079

71-
def basis(self, x, k, i, t):
80+
def _compute_boundary_interval(self):
81+
"""
82+
Precompute the rightmost non-degenerate interval index for performance.
83+
This avoids the search loop in the basis function on every call.
7284
"""
73-
Recursive method to compute the basis functions of the spline.
85+
# Handle multi-dimensional knots
86+
if self.knots.ndim > 1:
87+
# For multi-dimensional knots, we'll handle boundary detection in
88+
# the basis function
89+
self._boundary_interval_idx = None
90+
return
91+
92+
# For 1D knots, find the rightmost non-degenerate interval
93+
for i in range(len(self.knots) - 2, -1, -1):
94+
if self.knots[i] < self.knots[i + 1]: # Non-degenerate interval found
95+
self._boundary_interval_idx = i
96+
return
97+
98+
self._boundary_interval_idx = len(self.knots) - 2 if len(self.knots) > 1 else 0
99+
100+
def basis(self, x, k, knots):
101+
"""
102+
Compute the basis functions for the spline using an iterative approach.
103+
This is a vectorized implementation based on the Cox-de Boor recursion.
74104
75105
:param torch.Tensor x: The points to be evaluated.
76106
:param int k: The spline degree.
77-
:param int i: The index of the interval.
78-
:param torch.Tensor t: The tensor of knots.
107+
:param torch.Tensor knots: The tensor of knots.
79108
:return: The basis functions evaluated at x
80109
:rtype: torch.Tensor
81110
"""
82111

83-
if k == 0:
84-
a = torch.where(
85-
torch.logical_and(t[i] <= x, x < t[i + 1]), 1.0, 0.0
112+
if x.ndim == 1:
113+
x = x.unsqueeze(1) # (batch_size, 1)
114+
if x.ndim == 2:
115+
x = x.unsqueeze(2) # (batch_size, in_dim, 1)
116+
117+
if knots.ndim == 1:
118+
knots = knots.unsqueeze(0) # (1, n_knots)
119+
if knots.ndim == 2:
120+
knots = knots.unsqueeze(0) # (1, in_dim, n_knots)
121+
122+
# Base case: k=0
123+
basis = (x >= knots[..., :-1]) & (x < knots[..., 1:])
124+
basis = basis.to(x.dtype)
125+
126+
if self._boundary_interval_idx is not None:
127+
i = self._boundary_interval_idx
128+
tolerance = 1e-10
129+
x_squeezed = x.squeeze(-1)
130+
knot_left = knots[..., i]
131+
knot_right = knots[..., i + 1]
132+
133+
at_right_boundary = torch.abs(x_squeezed - knot_right) <= tolerance
134+
in_rightmost_interval = (
135+
x_squeezed >= knot_left
136+
) & at_right_boundary
137+
138+
if torch.any(in_rightmost_interval):
139+
# For points at the boundary, ensure they're included in the
140+
# rightmost interval
141+
basis[..., i] = torch.logical_or(
142+
basis[..., i].bool(), in_rightmost_interval
143+
).to(basis.dtype)
144+
145+
# Iterative step (Cox-de Boor recursion)
146+
for i in range(1, k + 1):
147+
# First term of the recursion
148+
denom1 = knots[..., i:-1] - knots[..., : -(i + 1)]
149+
denom1 = torch.where(
150+
torch.abs(denom1) < 1e-8, torch.ones_like(denom1), denom1
86151
)
87-
if i == len(t) - self.order - 1:
88-
a = torch.where(x == t[-1], 1.0, a)
89-
a.requires_grad_(True)
90-
return a
152+
numer1 = x - knots[..., : -(i + 1)]
153+
term1 = (numer1 / denom1) * basis[..., :-1]
91154

92-
if t[i + k] == t[i]:
93-
c1 = torch.tensor([0.0] * len(x), requires_grad=True)
94-
else:
95-
c1 = (x - t[i]) / (t[i + k] - t[i]) * self.basis(x, k - 1, i, t)
155+
denom2 = knots[..., i + 1 :] - knots[..., 1:-i]
156+
denom2 = torch.where(
157+
torch.abs(denom2) < 1e-8, torch.ones_like(denom2), denom2
158+
)
159+
numer2 = knots[..., i + 1 :] - x
160+
term2 = (numer2 / denom2) * basis[..., 1:]
161+
162+
basis = term1 + term2
163+
164+
return basis
165+
166+
def compute_control_points(self, x_eval, y_eval):
167+
"""
168+
Compute control points from given evaluations using least squares.
169+
This method fits the control points to match the target y_eval values.
170+
"""
171+
# (batch, in_dim)
172+
A = self.basis(x_eval, self.k, self.knots)
173+
# (batch, in_dim, n_basis)
96174

97-
if t[i + k + 1] == t[i + 1]:
98-
c2 = torch.tensor([0.0] * len(x), requires_grad=True)
175+
in_dim = A.shape[1]
176+
out_dim = y_eval.shape[2]
177+
n_basis = A.shape[2]
178+
c = torch.zeros(in_dim, out_dim, n_basis).to(A.device)
179+
180+
for i in range(in_dim):
181+
# A_i is (batch, n_basis)
182+
# y_i is (batch, out_dim)
183+
A_i = A[:, i, :]
184+
y_i = y_eval[:, i, :]
185+
c_i = torch.linalg.lstsq(A_i, y_i).solution # (n_basis, out_dim)
186+
c[i, :, :] = c_i.T # (out_dim, n_basis)
187+
188+
self.control_points = torch.nn.Parameter(c)
189+
190+
def forward(self, x):
191+
"""
192+
Forward pass for the :class:`Spline` model.
193+
194+
:param torch.Tensor x: The input tensor.
195+
:return: The output tensor.
196+
:rtype: torch.Tensor
197+
"""
198+
t = self.knots
199+
k = self.k
200+
c = self.control_points
201+
202+
# Create the basis functions
203+
# B will have shape (batch, in_dim, n_basis)
204+
B = self.basis(x, k, t)
205+
206+
# KAN case where control points are (in_dim, out_dim, n_basis)
207+
if c.ndim == 3:
208+
y_ij = torch.einsum(
209+
"bil,iol->bio", B, c
210+
) # (batch, in_dim, out_dim)
211+
# sum over input dimensions
212+
y = torch.sum(y_ij, dim=1) # (batch, out_dim)
213+
# Original test case
99214
else:
100-
c2 = (
101-
(t[i + k + 1] - x)
102-
/ (t[i + k + 1] - t[i + 1])
103-
* self.basis(x, k - 1, i + 1, t)
104-
)
215+
B = B.squeeze(1) # (batch, n_basis)
216+
if c.ndim == 1:
217+
y = torch.einsum("bi,i->b", B, c)
218+
else:
219+
y = torch.einsum("bi,ij->bj", B, c)
105220

106-
return c1 + c2
221+
return y
107222

108223
@property
109224
def control_points(self):
@@ -131,9 +246,12 @@ def control_points(self, value):
131246
dim = value.get("dim", 1)
132247
value = torch.zeros(n, dim)
133248

249+
if not isinstance(value, torch.nn.Parameter):
250+
value = torch.nn.Parameter(value)
251+
134252
if not isinstance(value, torch.Tensor):
135253
raise ValueError("Invalid value for control_points")
136-
self._control_points = torch.nn.Parameter(value, requires_grad=True)
254+
self._control_points = value
137255

138256
@property
139257
def knots(self):
@@ -181,19 +299,6 @@ def knots(self, value):
181299

182300
self._knots = value
183301

184-
def forward(self, x):
185-
"""
186-
Forward pass for the :class:`Spline` model.
187-
188-
:param torch.Tensor x: The input tensor.
189-
:return: The output tensor.
190-
:rtype: torch.Tensor
191-
"""
192-
t = self.knots
193-
k = self.k
194-
c = self.control_points
195-
196-
basis = map(lambda i: self.basis(x, k, i, t)[:, None], range(len(c)))
197-
y = (torch.cat(list(basis), dim=1) * c).sum(axis=1)
198-
199-
return y
302+
# Recompute boundary interval when knots change
303+
if hasattr(self, "_boundary_interval_idx"):
304+
self._compute_boundary_interval()

0 commit comments

Comments
 (0)