@@ -9,7 +9,9 @@ class Spline(torch.nn.Module):
99 Spline model class.
1010 """
1111
12- def __init__ (self , order = 4 , knots = None , control_points = None ) -> None :
12+ def __init__ (
13+ self , order = 4 , knots = None , control_points = None , grid_extension = True
14+ ):
1315 """
1416 Initialization of the :class:`Spline` class.
1517
@@ -20,7 +22,7 @@ def __init__(self, order=4, knots=None, control_points=None) -> None:
2022 ``None``.
2123 :raises ValueError: If the order is negative.
2224 :raises ValueError: If both knots and control points are ``None``.
23- :raises ValueError: If the knot tensor is not one- dimensional.
25+ :raises ValueError: If the knot tensor is not one or two dimensional.
2426 """
2527 super ().__init__ ()
2628
@@ -33,6 +35,10 @@ def __init__(self, order=4, knots=None, control_points=None) -> None:
3335
3436 self .order = order
3537 self .k = order - 1
38+ self .grid_extension = grid_extension
39+
40+ # Cache for performance optimization
41+ self ._boundary_interval_idx = None
3642
3743 if knots is not None and control_points is not None :
3844 self .knots = knots
@@ -65,45 +71,154 @@ def __init__(self, order=4, knots=None, control_points=None) -> None:
6571 else :
6672 raise ValueError ("Knots and control points cannot be both None." )
6773
68- if self .knots .ndim != 1 :
69- raise ValueError ("Knot vector must be one-dimensional." )
74+ if self .knots .ndim > 2 :
75+ raise ValueError ("Knot vector must be one or two-dimensional." )
76+
77+ # Precompute boundary interval index for performance
78+ self ._compute_boundary_interval ()
7079
71- def basis (self , x , k , i , t ):
80+ def _compute_boundary_interval (self ):
81+ """
82+ Precompute the rightmost non-degenerate interval index for performance.
83+ This avoids the search loop in the basis function on every call.
7284 """
73- Recursive method to compute the basis functions of the spline.
85+ # Handle multi-dimensional knots
86+ if self .knots .ndim > 1 :
87+ # For multi-dimensional knots, we'll handle boundary detection in
88+ # the basis function
89+ self ._boundary_interval_idx = None
90+ return
91+
92+ # For 1D knots, find the rightmost non-degenerate interval
93+ for i in range (len (self .knots ) - 2 , - 1 , - 1 ):
94+ if self .knots [i ] < self .knots [i + 1 ]: # Non-degenerate interval found
95+ self ._boundary_interval_idx = i
96+ return
97+
98+ self ._boundary_interval_idx = len (self .knots ) - 2 if len (self .knots ) > 1 else 0
99+
100+ def basis (self , x , k , knots ):
101+ """
102+ Compute the basis functions for the spline using an iterative approach.
103+ This is a vectorized implementation based on the Cox-de Boor recursion.
74104
75105 :param torch.Tensor x: The points to be evaluated.
76106 :param int k: The spline degree.
77- :param int i: The index of the interval.
78- :param torch.Tensor t: The tensor of knots.
107+ :param torch.Tensor knots: The tensor of knots.
79108 :return: The basis functions evaluated at x
80109 :rtype: torch.Tensor
81110 """
82111
83- if k == 0 :
84- a = torch .where (
85- torch .logical_and (t [i ] <= x , x < t [i + 1 ]), 1.0 , 0.0
112+ if x .ndim == 1 :
113+ x = x .unsqueeze (1 ) # (batch_size, 1)
114+ if x .ndim == 2 :
115+ x = x .unsqueeze (2 ) # (batch_size, in_dim, 1)
116+
117+ if knots .ndim == 1 :
118+ knots = knots .unsqueeze (0 ) # (1, n_knots)
119+ if knots .ndim == 2 :
120+ knots = knots .unsqueeze (0 ) # (1, in_dim, n_knots)
121+
122+ # Base case: k=0
123+ basis = (x >= knots [..., :- 1 ]) & (x < knots [..., 1 :])
124+ basis = basis .to (x .dtype )
125+
126+ if self ._boundary_interval_idx is not None :
127+ i = self ._boundary_interval_idx
128+ tolerance = 1e-10
129+ x_squeezed = x .squeeze (- 1 )
130+ knot_left = knots [..., i ]
131+ knot_right = knots [..., i + 1 ]
132+
133+ at_right_boundary = torch .abs (x_squeezed - knot_right ) <= tolerance
134+ in_rightmost_interval = (
135+ x_squeezed >= knot_left
136+ ) & at_right_boundary
137+
138+ if torch .any (in_rightmost_interval ):
139+ # For points at the boundary, ensure they're included in the
140+ # rightmost interval
141+ basis [..., i ] = torch .logical_or (
142+ basis [..., i ].bool (), in_rightmost_interval
143+ ).to (basis .dtype )
144+
145+ # Iterative step (Cox-de Boor recursion)
146+ for i in range (1 , k + 1 ):
147+ # First term of the recursion
148+ denom1 = knots [..., i :- 1 ] - knots [..., : - (i + 1 )]
149+ denom1 = torch .where (
150+ torch .abs (denom1 ) < 1e-8 , torch .ones_like (denom1 ), denom1
86151 )
87- if i == len (t ) - self .order - 1 :
88- a = torch .where (x == t [- 1 ], 1.0 , a )
89- a .requires_grad_ (True )
90- return a
152+ numer1 = x - knots [..., : - (i + 1 )]
153+ term1 = (numer1 / denom1 ) * basis [..., :- 1 ]
91154
92- if t [i + k ] == t [i ]:
93- c1 = torch .tensor ([0.0 ] * len (x ), requires_grad = True )
94- else :
95- c1 = (x - t [i ]) / (t [i + k ] - t [i ]) * self .basis (x , k - 1 , i , t )
155+ denom2 = knots [..., i + 1 :] - knots [..., 1 :- i ]
156+ denom2 = torch .where (
157+ torch .abs (denom2 ) < 1e-8 , torch .ones_like (denom2 ), denom2
158+ )
159+ numer2 = knots [..., i + 1 :] - x
160+ term2 = (numer2 / denom2 ) * basis [..., 1 :]
161+
162+ basis = term1 + term2
163+
164+ return basis
165+
166+ def compute_control_points (self , x_eval , y_eval ):
167+ """
168+ Compute control points from given evaluations using least squares.
169+ This method fits the control points to match the target y_eval values.
170+ """
171+ # (batch, in_dim)
172+ A = self .basis (x_eval , self .k , self .knots )
173+ # (batch, in_dim, n_basis)
96174
97- if t [i + k + 1 ] == t [i + 1 ]:
98- c2 = torch .tensor ([0.0 ] * len (x ), requires_grad = True )
175+ in_dim = A .shape [1 ]
176+ out_dim = y_eval .shape [2 ]
177+ n_basis = A .shape [2 ]
178+ c = torch .zeros (in_dim , out_dim , n_basis ).to (A .device )
179+
180+ for i in range (in_dim ):
181+ # A_i is (batch, n_basis)
182+ # y_i is (batch, out_dim)
183+ A_i = A [:, i , :]
184+ y_i = y_eval [:, i , :]
185+ c_i = torch .linalg .lstsq (A_i , y_i ).solution # (n_basis, out_dim)
186+ c [i , :, :] = c_i .T # (out_dim, n_basis)
187+
188+ self .control_points = torch .nn .Parameter (c )
189+
190+ def forward (self , x ):
191+ """
192+ Forward pass for the :class:`Spline` model.
193+
194+ :param torch.Tensor x: The input tensor.
195+ :return: The output tensor.
196+ :rtype: torch.Tensor
197+ """
198+ t = self .knots
199+ k = self .k
200+ c = self .control_points
201+
202+ # Create the basis functions
203+ # B will have shape (batch, in_dim, n_basis)
204+ B = self .basis (x , k , t )
205+
206+ # KAN case where control points are (in_dim, out_dim, n_basis)
207+ if c .ndim == 3 :
208+ y_ij = torch .einsum (
209+ "bil,iol->bio" , B , c
210+ ) # (batch, in_dim, out_dim)
211+ # sum over input dimensions
212+ y = torch .sum (y_ij , dim = 1 ) # (batch, out_dim)
213+ # Original test case
99214 else :
100- c2 = (
101- ( t [ i + k + 1 ] - x )
102- / ( t [ i + k + 1 ] - t [ i + 1 ] )
103- * self . basis ( x , k - 1 , i + 1 , t )
104- )
215+ B = B . squeeze ( 1 ) # (batch, n_basis)
216+ if c . ndim == 1 :
217+ y = torch . einsum ( "bi,i->b" , B , c )
218+ else :
219+ y = torch . einsum ( "bi,ij->bj" , B , c )
105220
106- return c1 + c2
221+ return y
107222
108223 @property
109224 def control_points (self ):
@@ -131,9 +246,12 @@ def control_points(self, value):
131246 dim = value .get ("dim" , 1 )
132247 value = torch .zeros (n , dim )
133248
249+ if not isinstance (value , torch .nn .Parameter ):
250+ value = torch .nn .Parameter (value )
251+
134252 if not isinstance (value , torch .Tensor ):
135253 raise ValueError ("Invalid value for control_points" )
136- self ._control_points = torch . nn . Parameter ( value , requires_grad = True )
254+ self ._control_points = value
137255
138256 @property
139257 def knots (self ):
@@ -181,19 +299,6 @@ def knots(self, value):
181299
182300 self ._knots = value
183301
184- def forward (self , x ):
185- """
186- Forward pass for the :class:`Spline` model.
187-
188- :param torch.Tensor x: The input tensor.
189- :return: The output tensor.
190- :rtype: torch.Tensor
191- """
192- t = self .knots
193- k = self .k
194- c = self .control_points
195-
196- basis = map (lambda i : self .basis (x , k , i , t )[:, None ], range (len (c )))
197- y = (torch .cat (list (basis ), dim = 1 ) * c ).sum (axis = 1 )
198-
199- return y
302+ # Recompute boundary interval when knots change
303+ if hasattr (self , "_boundary_interval_idx" ):
304+ self ._compute_boundary_interval ()
0 commit comments