Skip to content

Commit 7b53ee9

Browse files
committed
feat[adjoint]: allow traceable kwargs in primitives interpolate_spline and add_at
1 parent e8ec6e0 commit 7b53ee9

File tree

5 files changed

+124
-60
lines changed

5 files changed

+124
-60
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2929
- Adjoint source frequency width is adjusted to decay sufficiently before zero frequency when possible to improve accuracy of simulation normalization when using custom current sources.
3030
- Change `VisualizationSpec` validator for checking validity of user specified colors to only issue a warning if matplotlib is not installed instead of an error.
3131

32+
### Changed
33+
- `tidy3d.plugins.autograd.interpolate_spline()` and `tidy3d.plugins.autograd.add_at()` can now be called with keyword arguments during tracing.
34+
3235
## [2.8.4] - 2025-05-15
3336

3437
### Added

tests/test_plugins/autograd/primitives/test_interpolate.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import autograd.numpy as np
22
import numpy.testing as npt
33
import pytest
4+
from autograd import grad
45
from autograd.test_util import check_grads
56
from tidy3d.plugins.autograd import interpolate_spline
67

@@ -36,6 +37,22 @@ def test_interpolate_spline_grads(rng, order, num_points, endpoint_derivs, x_dis
3637
)
3738

3839

40+
def test_interpolate_spline_grads_kwargs(rng):
41+
"""Test interpolate_spline function can be called with kwargs."""
42+
x = np.linspace(0, 1, 10)
43+
y = rng.random(x.size)
44+
# this should not error
45+
grad(
46+
lambda y_: interpolate_spline(
47+
x_points=x,
48+
y_points=y_,
49+
num_points=10,
50+
order=3,
51+
endpoint_derivatives=(None, None),
52+
)[1][0]
53+
)(y)
54+
55+
3956
@pytest.mark.parametrize("order", [1, 2, 3])
4057
@pytest.mark.parametrize("x", [np.linspace(0, 1, 10), np.linspace(1, 0, 10)])
4158
def test_interpolate_spline_vals(rng, order, x):

tests/test_plugins/autograd/test_functions.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import pytest
44
import scipy.interpolate
55
import scipy.ndimage
6+
from autograd import grad
67
from autograd.test_util import check_grads
78
from scipy.signal import convolve as convolve_sp
89
from tidy3d.plugins.autograd import (
@@ -387,6 +388,15 @@ def test_add_at_grad(self, rng, shape, indices):
387388
check_grads(lambda y: add_at(x, indices, y), modes=["fwd", "rev"], order=2)(y)
388389

389390

391+
def test_add_at_grad_kwargs(rng):
392+
"""Test add_at function for different array dimensions and indices, with kwargs."""
393+
indices = (0,)
394+
x = rng.uniform(-1, 1, (10,))
395+
y = rng.uniform(-1, 1, x[tuple(indices)].shape)
396+
# this should not error
397+
grad(lambda y_: add_at(x=x, y=y_, indices_x=indices)[0])(y)
398+
399+
390400
@pytest.mark.parametrize("shape", [(5,), (5, 5), (5, 5, 5)])
391401
@pytest.mark.parametrize("tau", [1e-3, 1.0])
392402
@pytest.mark.parametrize("axis", [None, 0, 1, -1])

tidy3d/components/autograd/functions.py

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,34 @@ def trapz(y: NDArray, x: NDArray = None, dx: float = 1.0, axis: int = -1) -> flo
198198

199199

200200
@primitive
201+
def _add_at(x: NDArray, indices_x: tuple, y: NDArray) -> NDArray:
202+
"""
203+
Add values to specified indices of an array.
204+
205+
Autograd requires that arguments to primitives are passed in positionally.
206+
``add_at`` is the public-facing wrapper for this function,
207+
which allows keyword arguments in case users pass in kwargs.
208+
"""
209+
out = np.copy(x) # Copy to preserve 'x' for gradient computation
210+
out[tuple(indices_x)] += y
211+
return out
212+
213+
214+
defvjp(
215+
_add_at,
216+
lambda ans, x, indices_x, y: unbroadcast_f(x, lambda g: g),
217+
lambda ans, x, indices_x, y: lambda g: g[tuple(indices_x)],
218+
argnums=(0, 2),
219+
)
220+
221+
defjvp(
222+
_add_at,
223+
lambda g, ans, x, indices_x, y: broadcast(g, ans),
224+
lambda g, ans, x, indices_x, y: _add_at(anp.zeros_like(ans), indices_x, g),
225+
argnums=(0, 2),
226+
)
227+
228+
201229
def add_at(x: NDArray, indices_x: tuple, y: NDArray) -> NDArray:
202230
"""
203231
Add values to specified indices of an array.
@@ -219,24 +247,7 @@ def add_at(x: NDArray, indices_x: tuple, y: NDArray) -> NDArray:
219247
np.ndarray
220248
The modified array with values added at the specified indices.
221249
"""
222-
out = np.copy(x) # Copy to preserve 'x' for gradient computation
223-
out[tuple(indices_x)] += y
224-
return out
225-
226-
227-
defvjp(
228-
add_at,
229-
lambda ans, x, indices_x, y: unbroadcast_f(x, lambda g: g),
230-
lambda ans, x, indices_x, y: lambda g: g[tuple(indices_x)],
231-
argnums=(0, 2),
232-
)
233-
234-
defjvp(
235-
add_at,
236-
lambda g, ans, x, indices_x, y: broadcast(g, ans),
237-
lambda g, ans, x, indices_x, y: add_at(anp.zeros_like(ans), indices_x, g),
238-
argnums=(0, 2),
239-
)
250+
return _add_at(x, indices_x, y)
240251

241252

242253
__all__ = [

tidy3d/plugins/autograd/primitives/interpolate.py

Lines changed: 65 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -682,7 +682,7 @@ def get_spline_derivatives_wrt_y(
682682

683683

684684
@primitive
685-
def interpolate_spline(
685+
def _interpolate_spline(
686686
x_points: NDArray,
687687
y_points: NDArray,
688688
num_points: int,
@@ -692,46 +692,9 @@ def interpolate_spline(
692692
"""Primitive function to perform spline interpolation of a given order
693693
with optional endpoint derivatives.
694694
695-
Parameters
696-
----------
697-
x_points : np.ndarray
698-
X coordinates of the data points (must be strictly monotonic)
699-
y_points : np.ndarray
700-
Y coordinates of the data points
701-
num_points : int
702-
Number of points in the output interpolation
703-
order : int
704-
Order of the spline (1=linear, 2=quadratic, 3=cubic)
705-
endpoint_derivatives : tuple[float, float] = (None, None)
706-
Derivatives at the endpoints (left, right)
707-
Note: For order=1 (linear), all endpoint derivatives are ignored.
708-
For order=2 (quadratic), only the left endpoint derivative is used.
709-
For order=3 (cubic), both endpoint derivatives are used if provided.
710-
711-
Returns
712-
-------
713-
tuple[np.ndarray, np.ndarray]
714-
Tuple of (x_interpolated, y_interpolated) values
715-
716-
Examples
717-
--------
718-
>>> import numpy as np
719-
>>> x = np.array([0, 1, 2])
720-
>>> y = np.array([0, 1, 0])
721-
>>> # Linear interpolation
722-
>>> x_interp, y_interp = interpolate_spline(x, y, num_points=5, order=1)
723-
>>> print(y_interp)
724-
[0. 0.5 1. 0.5 0. ]
725-
726-
>>> # Quadratic interpolation with left endpoint derivative
727-
>>> x_interp, y_interp = interpolate_spline(x, y, num_points=5, endpoint_derivatives=(0, None), order=2)
728-
>>> print(np.round(y_interp, 3))
729-
[0. 0.75 1. 0.5 0. ]
730-
731-
>>> # Cubic interpolation with both endpoint derivatives
732-
>>> x_interp, y_interp = interpolate_spline(x, y, num_points=5, endpoint_derivatives=(0, 0), order=3)
733-
>>> print(np.round(y_interp, 3))
734-
[0. 0.75 1. 0.75 0. ]
695+
Autograd requires that arguments to primitives are passed in positionally.
696+
``interpolate_spline`` is the public-facing wrapper for this function,
697+
which allows keyword arguments in case users pass in kwargs.
735698
"""
736699
if order not in (1, 2, 3):
737700
raise NotImplementedError(f"Spline order '{order}' not implemented.")
@@ -810,4 +773,64 @@ def vjp(g):
810773
return vjp
811774

812775

813-
defvjp(interpolate_spline, None, interpolate_spline_y_vjp)
776+
defvjp(_interpolate_spline, None, interpolate_spline_y_vjp)
777+
778+
779+
def interpolate_spline(
780+
x_points: NDArray,
781+
y_points: NDArray,
782+
num_points: int,
783+
order: int,
784+
endpoint_derivatives: tuple[Optional[float], Optional[float]] = (None, None),
785+
) -> tuple[NDArray, NDArray]:
786+
"""Differentiable spline interpolation of a given order
787+
with optional endpoint derivatives.
788+
789+
Parameters
790+
----------
791+
x_points : np.ndarray
792+
X coordinates of the data points (must be strictly monotonic)
793+
y_points : np.ndarray
794+
Y coordinates of the data points
795+
num_points : int
796+
Number of points in the output interpolation
797+
order : int
798+
Order of the spline (1=linear, 2=quadratic, 3=cubic)
799+
endpoint_derivatives : tuple[float, float] = (None, None)
800+
Derivatives at the endpoints (left, right)
801+
Note: For order=1 (linear), all endpoint derivatives are ignored.
802+
For order=2 (quadratic), only the left endpoint derivative is used.
803+
For order=3 (cubic), both endpoint derivatives are used if provided.
804+
805+
Returns
806+
-------
807+
tuple[np.ndarray, np.ndarray]
808+
Tuple of (x_interpolated, y_interpolated) values
809+
810+
Examples
811+
--------
812+
>>> import numpy as np
813+
>>> x = np.array([0, 1, 2])
814+
>>> y = np.array([0, 1, 0])
815+
>>> # Linear interpolation
816+
>>> x_interp, y_interp = interpolate_spline(x, y, num_points=5, order=1)
817+
>>> print(y_interp)
818+
[0. 0.5 1. 0.5 0. ]
819+
820+
>>> # Quadratic interpolation with left endpoint derivative
821+
>>> x_interp, y_interp = interpolate_spline(x, y, num_points=5, endpoint_derivatives=(0, None), order=2)
822+
>>> print(np.round(y_interp, 3))
823+
[0. 0.75 1. 0.5 0. ]
824+
825+
>>> # Cubic interpolation with both endpoint derivatives
826+
>>> x_interp, y_interp = interpolate_spline(x, y, num_points=5, endpoint_derivatives=(0, 0), order=3)
827+
>>> print(np.round(y_interp, 3))
828+
[0. 0.75 1. 0.75 0. ]
829+
"""
830+
return _interpolate_spline(
831+
x_points,
832+
y_points,
833+
num_points,
834+
order,
835+
endpoint_derivatives,
836+
)

0 commit comments

Comments
 (0)