@@ -682,7 +682,7 @@ def get_spline_derivatives_wrt_y(
682
682
683
683
684
684
@primitive
685
- def interpolate_spline (
685
+ def _interpolate_spline (
686
686
x_points : NDArray ,
687
687
y_points : NDArray ,
688
688
num_points : int ,
@@ -692,46 +692,9 @@ def interpolate_spline(
692
692
"""Primitive function to perform spline interpolation of a given order
693
693
with optional endpoint derivatives.
694
694
695
- Parameters
696
- ----------
697
- x_points : np.ndarray
698
- X coordinates of the data points (must be strictly monotonic)
699
- y_points : np.ndarray
700
- Y coordinates of the data points
701
- num_points : int
702
- Number of points in the output interpolation
703
- order : int
704
- Order of the spline (1=linear, 2=quadratic, 3=cubic)
705
- endpoint_derivatives : tuple[float, float] = (None, None)
706
- Derivatives at the endpoints (left, right)
707
- Note: For order=1 (linear), all endpoint derivatives are ignored.
708
- For order=2 (quadratic), only the left endpoint derivative is used.
709
- For order=3 (cubic), both endpoint derivatives are used if provided.
710
-
711
- Returns
712
- -------
713
- tuple[np.ndarray, np.ndarray]
714
- Tuple of (x_interpolated, y_interpolated) values
715
-
716
- Examples
717
- --------
718
- >>> import numpy as np
719
- >>> x = np.array([0, 1, 2])
720
- >>> y = np.array([0, 1, 0])
721
- >>> # Linear interpolation
722
- >>> x_interp, y_interp = interpolate_spline(x, y, num_points=5, order=1)
723
- >>> print(y_interp)
724
- [0. 0.5 1. 0.5 0. ]
725
-
726
- >>> # Quadratic interpolation with left endpoint derivative
727
- >>> x_interp, y_interp = interpolate_spline(x, y, num_points=5, endpoint_derivatives=(0, None), order=2)
728
- >>> print(np.round(y_interp, 3))
729
- [0. 0.75 1. 0.5 0. ]
730
-
731
- >>> # Cubic interpolation with both endpoint derivatives
732
- >>> x_interp, y_interp = interpolate_spline(x, y, num_points=5, endpoint_derivatives=(0, 0), order=3)
733
- >>> print(np.round(y_interp, 3))
734
- [0. 0.75 1. 0.75 0. ]
695
+ Autograd requires that arguments to primitives are passed in positionally.
696
+ ``interpolate_spline`` is the public-facing wrapper for this function,
697
+ which allows keyword arguments in case users pass in kwargs.
735
698
"""
736
699
if order not in (1 , 2 , 3 ):
737
700
raise NotImplementedError (f"Spline order '{ order } ' not implemented." )
@@ -810,4 +773,64 @@ def vjp(g):
810
773
return vjp
811
774
812
775
813
- defvjp (interpolate_spline , None , interpolate_spline_y_vjp )
776
+ defvjp (_interpolate_spline , None , interpolate_spline_y_vjp )
777
+
778
+
779
+ def interpolate_spline (
780
+ x_points : NDArray ,
781
+ y_points : NDArray ,
782
+ num_points : int ,
783
+ order : int ,
784
+ endpoint_derivatives : tuple [Optional [float ], Optional [float ]] = (None , None ),
785
+ ) -> tuple [NDArray , NDArray ]:
786
+ """Differentiable spline interpolation of a given order
787
+ with optional endpoint derivatives.
788
+
789
+ Parameters
790
+ ----------
791
+ x_points : np.ndarray
792
+ X coordinates of the data points (must be strictly monotonic)
793
+ y_points : np.ndarray
794
+ Y coordinates of the data points
795
+ num_points : int
796
+ Number of points in the output interpolation
797
+ order : int
798
+ Order of the spline (1=linear, 2=quadratic, 3=cubic)
799
+ endpoint_derivatives : tuple[float, float] = (None, None)
800
+ Derivatives at the endpoints (left, right)
801
+ Note: For order=1 (linear), all endpoint derivatives are ignored.
802
+ For order=2 (quadratic), only the left endpoint derivative is used.
803
+ For order=3 (cubic), both endpoint derivatives are used if provided.
804
+
805
+ Returns
806
+ -------
807
+ tuple[np.ndarray, np.ndarray]
808
+ Tuple of (x_interpolated, y_interpolated) values
809
+
810
+ Examples
811
+ --------
812
+ >>> import numpy as np
813
+ >>> x = np.array([0, 1, 2])
814
+ >>> y = np.array([0, 1, 0])
815
+ >>> # Linear interpolation
816
+ >>> x_interp, y_interp = interpolate_spline(x, y, num_points=5, order=1)
817
+ >>> print(y_interp)
818
+ [0. 0.5 1. 0.5 0. ]
819
+
820
+ >>> # Quadratic interpolation with left endpoint derivative
821
+ >>> x_interp, y_interp = interpolate_spline(x, y, num_points=5, endpoint_derivatives=(0, None), order=2)
822
+ >>> print(np.round(y_interp, 3))
823
+ [0. 0.75 1. 0.5 0. ]
824
+
825
+ >>> # Cubic interpolation with both endpoint derivatives
826
+ >>> x_interp, y_interp = interpolate_spline(x, y, num_points=5, endpoint_derivatives=(0, 0), order=3)
827
+ >>> print(np.round(y_interp, 3))
828
+ [0. 0.75 1. 0.75 0. ]
829
+ """
830
+ return _interpolate_spline (
831
+ x_points ,
832
+ y_points ,
833
+ num_points ,
834
+ order ,
835
+ endpoint_derivatives ,
836
+ )
0 commit comments