
Commit 6da4072 (parent 9381017)

feat: scalar_objective decorator to ensure autograd-compatibility of objective functions

File tree: 7 files changed, +262 -54 lines


CHANGELOG.md

Lines changed: 3 additions & 0 deletions

@@ -11,6 +11,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Autograd support for local field projections using `FieldProjectionKSpaceMonitor`.
 - Function `components.geometry.utils.flatten_groups` now also flattens transformed groups when requested.
 - Differentiable `smooth_min`, `smooth_max`, and `least_squares` functions in `tidy3d.plugins.autograd`.
+- Differential operators `grad` and `value_and_grad` in `tidy3d.plugins.autograd` that behave similarly to the autograd operators but additionally support auxiliary data via `has_aux=True` and differentiation with respect to `DataArray`.
+- `@scalar_objective` decorator in `tidy3d.plugins.autograd` that wraps objective functions to ensure they return a scalar value and performs additional checks for compatibility with autograd. Applied by default in `tidy3d.plugins.autograd.value_and_grad` and `tidy3d.plugins.autograd.grad`.
+

 ### Changed

 - `CustomMedium` design regions require far less data when performing inverse design by reducing adjoint field monitor size for dims with one pixel.
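
A rough usage sketch of the two entries above, mirroring the tests added in this commit (the objective, input values, and printed results are illustrative only; `value_and_grad` applies the `scalar_objective` wrapper internally, so the objective may return a one-element `DataArray` instead of a bare float):

import autograd.numpy as np
import xarray as xr

from tidy3d.plugins.autograd import value_and_grad


def objective(x):
    # Returns a one-element DataArray plus auxiliary data; the scalar_objective
    # wrapper applied inside value_and_grad reduces it to a plain scalar.
    da = xr.DataArray(x)
    return da.sum(), "aux"


x = np.array([1.0, 2.0, 3.0])
(value, gradient), aux = value_and_grad(objective, has_aux=True)(x)
# value -> 6.0, gradient -> [1. 1. 1.], aux -> "aux"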
Lines changed: 53 additions & 10 deletions

@@ -1,23 +1,66 @@
 import autograd.numpy as np
+import pytest
+from autograd import grad as grad_ag
 from autograd import value_and_grad as value_and_grad_ag
 from numpy.testing import assert_allclose
-from tidy3d.plugins.autograd.differential_operators import value_and_grad
+from tidy3d.components.data.data_array import DataArray
+from tidy3d.plugins.autograd import grad, value_and_grad


-def test_value_and_grad(rng):
+@pytest.mark.parametrize("argnum", [0, 1])
+@pytest.mark.parametrize("has_aux", [True, False])
+def test_grad(rng, argnum, has_aux):
     """Test the custom value_and_grad function against autograd's implementation"""
     x = rng.random(10)
+    y = rng.random(10)
     aux_val = "aux"

-    vg_fun = value_and_grad(lambda x: (np.linalg.norm(x), aux_val), has_aux=True)
-    vg_fun_ag = value_and_grad_ag(lambda x: np.linalg.norm(x))
+    def f(x, y):
+        ret = DataArray(x * y).sum()  # still DataArray
+        if has_aux:
+            return ret, aux_val
+        return ret

-    (v, g), aux = vg_fun(x)
-    v_ag, g_ag = vg_fun_ag(x)
+    grad_fun = grad(f, argnum=argnum, has_aux=has_aux)
+    grad_fun_ag = grad_ag(
+        lambda x, y: f(x, y)[0].item() if has_aux else f(x, y).item(), argnum=argnum
+    )
+
+    if has_aux:
+        g, aux = grad_fun(x, y)
+        assert aux == aux_val
+    else:
+        g = grad_fun(x, y)
+    g_ag = grad_fun_ag(x, y)

-    # assert that values and gradients match
-    assert_allclose(v, v_ag)
     assert_allclose(g, g_ag)

-    # check that auxiliary output is correctly returned
-    assert aux == aux_val
+
+@pytest.mark.parametrize("argnum", [0, 1])
+@pytest.mark.parametrize("has_aux", [True, False])
+def test_value_and_grad(rng, argnum, has_aux):
+    """Test the custom value_and_grad function against autograd's implementation"""
+    x = rng.random(10)
+    y = rng.random(10)
+    aux_val = "aux"
+
+    def f(x, y):
+        ret = DataArray(np.linalg.norm(x * y)).sum()  # still DataArray
+        if has_aux:
+            return ret, aux_val
+        return ret
+
+    vg_fun = value_and_grad(f, argnum=argnum, has_aux=has_aux)
+    vg_fun_ag = value_and_grad_ag(
+        lambda x, y: f(x, y)[0].item() if has_aux else f(x, y).item(), argnum=argnum
+    )
+
+    if has_aux:
+        (v, g), aux = vg_fun(x, y)
+        assert aux == aux_val
+    else:
+        v, g = vg_fun(x, y)
+    v_ag, g_ag = vg_fun_ag(x, y)
+
+    assert_allclose(v, v_ag)
+    assert_allclose(g, g_ag)

tests/test_plugins/autograd/test_utilities.py

Lines changed: 61 additions & 1 deletion

@@ -1,7 +1,15 @@
 import numpy as np
 import numpy.testing as npt
 import pytest
-from tidy3d.plugins.autograd.utilities import chain, get_kernel_size_px, make_kernel
+import xarray as xr
+from tidy3d.exceptions import Tidy3dError
+from tidy3d.plugins.autograd import (
+    chain,
+    get_kernel_size_px,
+    make_kernel,
+    scalar_objective,
+    value_and_grad,
+)


 @pytest.mark.parametrize("size", [(3, 3), (4, 4), (5, 5)])
@@ -104,3 +112,55 @@ def add_one(x):
     funcs = [add_one, "not_a_function"]
     with pytest.raises(TypeError, match="All elements in funcs must be callable"):
         chain(funcs)
+
+
+class TestScalarObjective:
+    def test_scalar_objective_no_aux(self):
+        """Test scalar_objective decorator without auxiliary data."""
+
+        @scalar_objective
+        def objective(x):
+            da = xr.DataArray(x)
+            return da.sum()
+
+        x = np.array([1.0, 2.0, 3.0])
+        result, grad = value_and_grad(objective)(x)
+        assert np.allclose(grad, np.ones_like(grad))
+        assert np.isclose(result, 6.0)
+
+    def test_scalar_objective_with_aux(self):
+        """Test scalar_objective decorator with auxiliary data."""
+
+        @scalar_objective(has_aux=True)
+        def objective(x):
+            da = xr.DataArray(x)
+            return da.sum(), "auxiliary_data"
+
+        x = np.array([1.0, 2.0, 3.0])
+        (result, grad), aux_data = value_and_grad(objective, has_aux=True)(x)
+        assert np.allclose(grad, np.ones_like(grad))
+        assert np.isclose(result, 6.0)
+        assert aux_data == "auxiliary_data"
+
+    def test_scalar_objective_invalid_return(self):
+        """Test scalar_objective decorator with invalid return value."""
+
+        @scalar_objective
+        def objective(x):
+            da = xr.DataArray(x)
+            return da  # Returning the array directly, not a scalar
+
+        x = np.array([1, 2, 3])
+        with pytest.raises(Tidy3dError, match="must be a scalar"):
+            objective(x)
+
+    def test_scalar_objective_float(self):
+        """Test scalar_objective decorator with a Python float return value."""
+
+        @scalar_objective
+        def objective(x):
+            return x**2
+
+        result, grad = value_and_grad(objective)(3.0)
+        assert np.isclose(grad, 6.0)
+        assert np.isclose(result, 9.0)

tidy3d/plugins/autograd/__init__.py

Lines changed: 16 additions & 2 deletions

@@ -1,16 +1,22 @@
-from .differential_operators import value_and_grad
+from .differential_operators import grad, value_and_grad
 from .functions import (
+    add_at,
     convolve,
     grey_closing,
     grey_dilation,
     grey_erosion,
     grey_opening,
+    interpn,
+    least_squares,
     morphological_gradient,
     morphological_gradient_external,
     morphological_gradient_internal,
     pad,
     rescale,
+    smooth_max,
+    smooth_min,
     threshold,
+    trapz,
 )
 from .invdes import (
     CircularFilter,
@@ -28,7 +34,7 @@
     tanh_projection,
 )
 from .primitives import gaussian_filter
-from .utilities import chain, get_kernel_size_px, make_kernel
+from .utilities import chain, get_kernel_size_px, make_kernel, scalar_objective

 __all__ = [
     "CircularFilter",
@@ -60,4 +66,12 @@
     "rescale",
     "threshold",
     "value_and_grad",
+    "smooth_min",
+    "smooth_max",
+    "add_at",
+    "interpn",
+    "least_squares",
+    "grad",
+    "scalar_objective",
+    "trapz",
 ]
tidy3d/plugins/autograd/differential_operators.py

Lines changed: 45 additions & 38 deletions

@@ -1,65 +1,72 @@
-from typing import Any, Callable
+from typing import Callable

-from autograd import value_and_grad as value_and_grad_ag
 from autograd.builtins import tuple as atuple
 from autograd.core import make_vjp
 from autograd.extend import vspace
 from autograd.wrap_util import unary_to_nary
 from numpy.typing import ArrayLike

+from .utilities import scalar_objective
+
 __all__ = [
     "value_and_grad",
+    "grad",
 ]


 @unary_to_nary
-def value_and_grad(
-    fun: Callable, x: ArrayLike, *, has_aux: bool = False
-) -> tuple[tuple[float, ArrayLike], Any]:
-    """Returns a function that returns both value and gradient.
-
-    This function wraps and extends autograd's 'value_and_grad' function by adding
-    support for auxiliary data.
+def grad(fun: Callable, x: ArrayLike, *, has_aux: bool = False) -> Callable:
+    """Returns a function that computes the gradient of `fun` with respect to `x`.

     Parameters
     ----------
     fun : Callable
-        The function to differentiate. Should take a single argument and return
-        a scalar value, or a tuple where the first element is a scalar value if has_aux is True.
+        The function to differentiate. Should return a scalar value, or a tuple of
+        (scalar_value, auxiliary_data) if `has_aux` is True.
     x : ArrayLike
-        The point at which to evaluate the function and its gradient.
+        The point at which to evaluate the gradient.
     has_aux : bool = False
-        If True, the function returns auxiliary data as the second element of a tuple.
+        If True, `fun` returns auxiliary data as the second element of a tuple.

     Returns
     -------
-    tuple[tuple[float, ArrayLike], Any]
-        A tuple containing:
-        - A tuple with the function value (float) and its gradient (ArrayLike)
-        - The auxiliary data returned by the function (if has_aux is True)
+    Callable
+        A function that takes the same arguments as `fun` and returns its gradient at `x`.
+    """
+    wrapped_fun = scalar_objective(fun, has_aux=has_aux)
+    vjp, result = make_vjp(lambda x: atuple(wrapped_fun(x)) if has_aux else wrapped_fun(x), x)

-    Raises
-    ------
-    TypeError
-        If the function does not return a scalar value.
+    if has_aux:
+        ans, aux = result
+        return vjp((vspace(ans).ones(), None)), aux
+    ans = result
+    return vjp(vspace(ans).ones())

-    Notes
-    -----
-    This function uses autograd for automatic differentiation. If the function
-    does not return auxiliary data (has_aux is False), it delegates to autograd's
-    value_and_grad function. The main extension is the support for auxiliary data
-    when has_aux is True.
-    """
-    if not has_aux:
-        return value_and_grad_ag(fun)(x)

-    vjp, (ans, aux) = make_vjp(lambda x: atuple(fun(x)), x)
+@unary_to_nary
+def value_and_grad(fun: Callable, x: ArrayLike, *, has_aux: bool = False) -> Callable:
+    """Returns a function that computes both the value and gradient of `fun` with respect to `x`.
+
+    Parameters
+    ----------
+    fun : Callable
+        The function to differentiate. Should return a scalar value, or a tuple of
+        (scalar_value, auxiliary_data) if `has_aux` is True.
+    x : ArrayLike
+        The point at which to evaluate the function and its gradient.
+    has_aux : bool = False
+        If True, `fun` returns auxiliary data as the second element of a tuple.

-    if not vspace(ans).size == 1:
-        raise TypeError(
-            "value_and_grad only applies to real scalar-output "
-            "functions. Try jacobian, elementwise_grad or "
-            "holomorphic_grad."
-        )
+    Returns
+    -------
+    Callable
+        A function that takes the same arguments as `fun` and returns its value and gradient at `x`.
+    """
+    wrapped_fun = scalar_objective(fun, has_aux=has_aux)
+    vjp, result = make_vjp(lambda x: atuple(wrapped_fun(x)) if has_aux else wrapped_fun(x), x)

-    return (ans, vjp((vspace(ans).ones(), None))), aux
+    if has_aux:
+        ans, aux = result
+        return (ans, vjp((vspace(ans).ones(), None))), aux
+    ans = result
+    return ans, vjp(vspace(ans).ones())
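
For readers unfamiliar with the `make_vjp`/`vspace` pattern shared by both operators above, here is a minimal, self-contained illustration using only the autograd package (the function `f` and the data are arbitrary examples, not part of this commit): `make_vjp` returns a vector-Jacobian product closure together with the forward value, and seeding that closure with `vspace(ans).ones()`, i.e. 1.0 for a scalar output, yields the gradient.

import autograd.numpy as anp
from autograd.core import make_vjp
from autograd.extend import vspace


def f(x):
    return anp.sum(x**2)


x = anp.array([1.0, 2.0, 3.0])
vjp, ans = make_vjp(f, x)           # forward pass: ans == 14.0, vjp closes over the trace
gradient = vjp(vspace(ans).ones())  # backward pass seeded with 1.0 -> array([2., 4., 6.])
print(ans, gradient)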

tidy3d/plugins/autograd/functions.py

Lines changed: 1 addition & 1 deletion

@@ -670,7 +670,7 @@ def least_squares(
     >>> initial_guess = (0.0, 0.0)
     >>> params = least_squares(linear_model, x_data, y_data, initial_guess)
     >>> print(params)
-    array([2.0, -3.0])
+    [ 2. -3.]
     """
     params = np.array(initial_guess, dtype="f8")
     jac = jacobian(lambda params: func(x, *params))
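
A quick note on the doctest fix above: `print` uses numpy's `str` formatting rather than its `repr`, so the expected output is the bracket form. A minimal check with plain numpy, independent of tidy3d:

import numpy as np

params = np.array([2.0, -3.0])
print(params)        # [ 2. -3.]
print(repr(params))  # array([ 2., -3.])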
