@@ -139,6 +139,9 @@ def __init__(self, order=4, knots=None, control_points=None):
139139 # Precompute boundary interval index
140140 self ._boundary_interval_idx = self ._compute_boundary_interval ()
141141
142+ # Precompute denominators used in derivative formulas
143+ self ._compute_derivative_denominators ()
144+
142145 def _compute_boundary_interval (self ):
143146 """
144147 Precompute the index of the rightmost non-degenerate interval to improve
@@ -163,15 +166,49 @@ def _compute_boundary_interval(self):
163166 # Otherwise, return the last valid index
164167 return int (valid [- 1 ])
165168
166- def basis (self , x ):
169+ def _compute_derivative_denominators (self ):
170+ """
171+ Precompute the denominators used in the derivatives for all orders up to
172+ the spline order to avoid redundant calculations.
173+ """
174+ # Precompute for orders 2 to k
175+ for i in range (2 , self .order + 1 ):
176+
177+ # Denominators for the derivative recurrence relations
178+ left_den = self .knots [i - 1 : - 1 ] - self .knots [:- i ]
179+ right_den = self .knots [i :] - self .knots [1 : - i + 1 ]
180+
181+ # If consecutive knots are equal, set left and right factors to zero
182+ left_fac = torch .where (
183+ torch .abs (left_den ) > 1e-10 ,
184+ (i - 1 ) / left_den ,
185+ torch .zeros_like (left_den ),
186+ )
187+ right_fac = torch .where (
188+ torch .abs (right_den ) > 1e-10 ,
189+ (i - 1 ) / right_den ,
190+ torch .zeros_like (right_den ),
191+ )
192+
193+ # Register buffers
194+ self .register_buffer (f"_left_factor_order_{ i } " , left_fac )
195+ self .register_buffer (f"_right_factor_order_{ i } " , right_fac )
196+
197+ def basis (self , x , collection = False ):
167198 """
168199 Compute the basis functions for the spline using an iterative approach.
169200 This is a vectorized implementation based on the Cox-de Boor recursion.
170201
171202 :param torch.Tensor x: The points to be evaluated.
203+ :param bool collection: If True, returns a list of basis functions for
204+ all orders up to the spline order. Default is False.
205+ :raise ValueError: If ``collection`` is not a boolean.
172206 :return: The basis functions evaluated at x.
173- :rtype: torch.Tensor
207+ :rtype: torch.Tensor | list[torch.Tensor]
174208 """
209+ # Check consistency
210+ check_consistency (collection , bool )
211+
175212 # Add a final dimension to x
176213 x = x .unsqueeze (- 1 )
177214
@@ -201,6 +238,10 @@ def basis(self, x):
201238 at_rightmost_boundary ,
202239 ).to (basis .dtype )
203240
241+ # If returning the whole collection, initialize list
242+ if collection :
243+ basis_collection = [None , basis ]
244+
204245 # Iterative case of recursion
205246 for i in range (1 , self .order ):
206247
@@ -222,8 +263,10 @@ def basis(self, x):
222263
223264 # Combine terms to get the new basis
224265 basis = term1 + term2
266+ if collection :
267+ basis_collection .append (basis )
225268
226- return basis
269+ return basis_collection if collection else basis
227270
228271 def forward (self , x ):
229272 """
@@ -240,6 +283,72 @@ def forward(self, x):
240283 self .control_points ,
241284 )
242285
286+ def derivative (self , x , degree ):
287+ """
288+ Compute the ``degree``-th derivative of the spline at given points.
289+
290+ :param x: The input tensor.
291+ :type x: torch.Tensor | LabelTensor
292+ :param int degree: The derivative degree to compute.
293+ :raise ValueError: If ``degree`` is not an integer.
294+ :return: The derivative tensor.
295+ :rtype: torch.Tensor
296+ """
297+ # Check consistency
298+ check_positive_integer (degree , strict = False )
299+
300+ # Compute basis derivative
301+ der = self ._basis_derivative (x .as_subclass (torch .Tensor ), degree = degree )
302+
303+ return torch .einsum ("...bi, i -> ...b" , der , self .control_points )
304+
305+ def _basis_derivative (self , x , degree ):
306+ """
307+ Compute the ``degree``-th derivative of the spline basis functions at
308+ given points using an iterative approach.
309+
310+ :param torch.Tensor x: The points to be evaluated.
311+ :param int degree: The derivative degree to compute.
312+ :return: The basis functions evaluated at x.
313+ :rtype: torch.Tensor
314+ """
315+ # Compute the whole basis collection
316+ basis = self .basis (x , collection = True )
317+
318+ # Derivatives initialization (with dummy at index 0 for convenience)
319+ derivatives = [None ] + [basis [o ] for o in range (1 , self .order + 1 )]
320+
321+ # Iterate over derivative degrees
322+ for _ in range (1 , degree + 1 ):
323+
324+ # Current degree derivatives (with dummy at index 0 for convenience)
325+ current_der = [None ] * (self .order + 1 )
326+ current_der [1 ] = torch .zeros_like (derivatives [1 ])
327+
328+ # Iterate over basis orders
329+ for o in range (2 , self .order + 1 ):
330+
331+ # Retrieve precomputed factors
332+ left_fac = getattr (self , f"_left_factor_order_{ o } " )
333+ right_fac = getattr (self , f"_right_factor_order_{ o } " )
334+
335+ # Slice previous derivatives to align
336+ left_part = derivatives [o - 1 ][..., :- 1 ]
337+ right_part = derivatives [o - 1 ][..., 1 :]
338+
339+ # Broadcast factors over batch dims
340+ view_shape = (1 ,) * (left_part .ndim - 1 ) + (- 1 ,)
341+ left_fac = left_fac .reshape (* view_shape )
342+ right_fac = right_fac .reshape (* view_shape )
343+
344+ # Compute current derivatives
345+ current_der [o ] = left_fac * left_part - right_fac * right_part
346+
347+ # Update derivatives for next degree
348+ derivatives = current_der
349+
350+ return derivatives [self .order ].squeeze (- 1 )
351+
243352 @property
244353 def control_points (self ):
245354 """
@@ -364,3 +473,6 @@ def knots(self, value):
364473 # Recompute boundary interval when knots change
365474 if hasattr (self , "_boundary_interval_idx" ):
366475 self ._boundary_interval_idx = self ._compute_boundary_interval ()
476+
477+ # Recompute derivative denominators when knots change
478+ self ._compute_derivative_denominators ()
0 commit comments