Skip to content

Commit 35469b0

Browse files
committed
maint[autograd]: warn instead of error if autograd tracers enter center/size of simulation, monitor, or source
1 parent 04e212a commit 35469b0

File tree

6 files changed

+55
-2
lines changed

6 files changed

+55
-2
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
88
## [Unreleased]
99

1010
### Changed
11+
- Supplying autograd-traced values to geometric fields (`center`, `size`) of simulations, monitors, and sources now logs a warning and falls back to the static value instead of erroring.
1112
- Attempting to differentiate server-side field projections now raises a clear error instead of silently failing.
1213

1314
## [2.8.3] - 2025-04-24

tests/test_components/test_autograd.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2167,3 +2167,21 @@ def objective(args):
21672167
# model is called without a frequency
21682168
with AssertLogLevel("INFO"):
21692169
grad = ag.grad(objective)(params0)
2170+
2171+
2172+
def test_sim_traced_center_size(use_emulated_run):
2173+
fn_dict = get_functions(args[0][0], args[0][1])
2174+
make_sim = fn_dict["sim"]
2175+
postprocess = fn_dict["postprocess"]
2176+
base_sim = make_sim(params0)
2177+
2178+
def objective(center, size):
2179+
sim = base_sim.updated_copy(center=center, size=size)
2180+
sim_data = run_emulated(sim, task_name="adjoint_test")
2181+
return postprocess(sim_data)
2182+
2183+
with AssertLogLevel("WARNING", contains_str="autograd tracer"):
2184+
grad = ag.grad(objective, argnum=0)(base_sim.center, base_sim.size)
2185+
2186+
with AssertLogLevel("WARNING", contains_str="autograd tracer"):
2187+
grad = ag.grad(objective, argnum=1)(base_sim.center, base_sim.size)

tidy3d/components/base_sim/monitor.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from ..base import cached_property
1010
from ..geometry.base import Box
1111
from ..types import ArrayFloat1D, Axis, Numpy
12+
from ..validators import _warn_unsupported_traced_argument
1213
from ..viz import PlotParams, plot_params_monitor
1314

1415

@@ -22,6 +23,9 @@ class AbstractMonitor(Box, ABC):
2223
min_length=1,
2324
)
2425

26+
_warn_traced_center = _warn_unsupported_traced_argument("center")
27+
_warn_traced_size = _warn_unsupported_traced_argument("size")
28+
2529
@cached_property
2630
def plot_params(self) -> PlotParams:
2731
"""Default parameters for plotting a Monitor object."""

tidy3d/components/base_sim/simulation.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,11 @@
1717
from ..scene import Scene
1818
from ..structure import Structure
1919
from ..types import TYPE_TAG_STR, Ax, Axis, Bound, LengthUnit, Symmetry
20-
from ..validators import assert_objects_in_sim_bounds, assert_unique_names
20+
from ..validators import (
21+
_warn_unsupported_traced_argument,
22+
assert_objects_in_sim_bounds,
23+
assert_unique_names,
24+
)
2125
from ..viz import (
2226
PlotParams,
2327
add_ax_if_none,
@@ -137,6 +141,9 @@ def _update_simulation(cls, values):
137141
_monitors_in_bounds = assert_objects_in_sim_bounds("monitors", strict_inequality=True)
138142
_structures_in_bounds = assert_objects_in_sim_bounds("structures", error=False)
139143

144+
_warn_traced_center = _warn_unsupported_traced_argument("center")
145+
_warn_traced_size = _warn_unsupported_traced_argument("size")
146+
140147
@pd.validator("structures", always=True)
141148
@skip_if_fields_missing(["size", "center"])
142149
def _structures_not_at_edges(cls, val, values):

tidy3d/components/source/base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from ..base_sim.source import AbstractSource
1212
from ..geometry.base import Box
1313
from ..types import TYPE_TAG_STR, Ax
14-
from ..validators import _assert_min_freq
14+
from ..validators import _assert_min_freq, _warn_unsupported_traced_argument
1515
from ..viz import (
1616
ARROW_ALPHA,
1717
ARROW_COLOR_POLARIZATION,
@@ -58,6 +58,9 @@ def _pol_vector(self) -> Tuple[float, float, float]:
5858
"""Returns a vector indicating the source polarization for arrow plotting, if not None."""
5959
return None
6060

61+
_warn_traced_center = _warn_unsupported_traced_argument("center")
62+
_warn_traced_size = _warn_unsupported_traced_argument("size")
63+
6164
@pydantic.validator("source_time", always=True)
6265
def _freqs_lower_bound(cls, val):
6366
"""Raise validation error if central frequency is too low."""

tidy3d/components/validators.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22

33
import numpy as np
44
import pydantic.v1 as pydantic
5+
from autograd.tracer import isbox
56

67
from ..exceptions import SetupError, ValidationError
78
from ..log import log
9+
from .autograd.utils import get_static
810
from .base import DATA_ARRAY_MAP, skip_if_fields_missing
911
from .data.dataset import Dataset, FieldDataset
1012
from .geometry.base import Box
@@ -464,3 +466,21 @@ def validate_mode_plane_radius(mode_spec: ModeSpec, plane: Box, msg_prefix: str
464466
f"{msg_prefix} bend radius is smaller than half the mode plane size "
465467
"along the radial axis, which can produce wrong results."
466468
)
469+
470+
471+
def _warn_unsupported_traced_argument(name: str):
472+
@pydantic.validator(name, always=True, allow_reuse=True)
473+
def _warn_traced_arg(cls, val, values):
474+
if isbox(val):
475+
log.warning(
476+
f"Field '{name}' of '{cls.__name__}' received an autograd tracer "
477+
f"(i.e., a value being tracked for automatic differentiation). "
478+
f"Automatic differentiation through this field is unsupported, "
479+
f"so the tracer has been converted to its static value. "
480+
f"If you want to avoid this warning, you manually unbox the value "
481+
f"using the 'autograd.tracer.getval' function before passing it to Tidy3D."
482+
)
483+
return get_static(val)
484+
return val
485+
486+
return _warn_traced_arg

0 commit comments

Comments
 (0)