Skip to content

Commit 3b1a6be

Browse files
committed
feat: autograd support for rotated Box
1 parent 35469b0 commit 3b1a6be

File tree

5 files changed

+595
-132
lines changed

5 files changed

+595
-132
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10+
### Added
11+
- Gradient computation for rotated boxes in Transformed.
12+
1013
### Changed
1114
- Supplying autograd-traced values to geometric fields (`center`, `size`) of simulations, monitors, and sources now logs a warning and falls back to the static value instead of erroring.
1215
- Attempting to differentiate server-side field projections now raises a clear error instead of silently failing.

tests/test_components/test_autograd.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2185,3 +2185,123 @@ def objective(center, size):
21852185

21862186
with AssertLogLevel("WARNING", contains_str="autograd tracer"):
21872187
grad = ag.grad(objective, argnum=1)(base_sim.center, base_sim.size)
2188+
2189+
2190+
def make_sim_rotation(center: tuple, size: tuple, angle: float, axis: int):
2191+
wavelength = 1.5
2192+
L = 10 * wavelength
2193+
freq0 = td.C_0 / wavelength
2194+
buffer = 1.0 * wavelength
2195+
2196+
# Source
2197+
src = td.PointDipole(
2198+
center=(-L / 2 + buffer, 0, 0),
2199+
source_time=td.GaussianPulse(freq0=freq0, fwidth=freq0 / 10.0),
2200+
polarization="Ez",
2201+
)
2202+
# Monitor
2203+
mnt = td.FieldMonitor(
2204+
center=(
2205+
+L / 2 - buffer,
2206+
0.5 * buffer,
2207+
0.5 * buffer,
2208+
),
2209+
size=(0.0, 0.0, 0.0),
2210+
freqs=[freq0],
2211+
name="point",
2212+
)
2213+
# The box geometry
2214+
base_box = td.Box(center=center, size=size)
2215+
if angle is not None:
2216+
base_box = base_box.rotated(angle, axis)
2217+
2218+
scatterer = td.Structure(
2219+
geometry=base_box,
2220+
medium=td.Medium(permittivity=2.0),
2221+
)
2222+
2223+
sim = td.Simulation(
2224+
size=(L, L, L),
2225+
grid_spec=td.GridSpec.auto(min_steps_per_wvl=50),
2226+
structures=[scatterer],
2227+
sources=[src],
2228+
monitors=[mnt],
2229+
run_time=120 / freq0,
2230+
)
2231+
return sim
2232+
2233+
2234+
def objective_fn(center, size, angle, axis):
2235+
sim = make_sim_rotation(center, size, angle, axis)
2236+
sim_data = web.run(sim, task_name="emulated_rot_test", local_gradient=True, verbose=False)
2237+
return anp.sum(sim_data.get_intensity("point").values)
2238+
2239+
2240+
def get_grad(center, size, angle, axis):
2241+
def wrapped(c, s):
2242+
return objective_fn(c, s, angle, axis)
2243+
2244+
val, (grad_c, grad_s) = ag.value_and_grad(wrapped, argnum=(0, 1))(center, size)
2245+
return val, grad_c, grad_s
2246+
2247+
2248+
@pytest.mark.numerical
2249+
@pytest.mark.parametrize(
2250+
"angle_deg, axis",
2251+
[
2252+
(0.0, 1),
2253+
(180.0, 1),
2254+
(90.0, 1),
2255+
(270.0, 1),
2256+
],
2257+
)
2258+
def test_box_rotation_gradients(use_emulated_run, angle_deg, axis):
2259+
center0 = (0.0, 0.0, 0.0)
2260+
size0 = (2.0, 2.0, 2.0)
2261+
2262+
angle_rad = np.deg2rad(angle_deg)
2263+
val, grad_c, grad_s = get_grad(center0, size0, angle=None, axis=None)
2264+
npx, npy, npz = grad_c
2265+
sSx, sSy, sSz = grad_s
2266+
2267+
assert not np.allclose(grad_c, 0.0), "center gradient is all zero."
2268+
assert not np.allclose(grad_s, 0.0), "size gradient is all zero."
2269+
2270+
if angle_deg == 180.0:
2271+
# rotating 180° about y => (x,z) become negated, y stays same
2272+
_, grad_c_ref, grad_s_ref = get_grad(center0, size0, angle_rad, axis)
2273+
rSx, rSy, rSz = grad_s_ref
2274+
rx, ry, rz = grad_c_ref
2275+
2276+
assert np.allclose(grad_c[0], -grad_c_ref[0], atol=1e-6), "center_x sign mismatch"
2277+
assert np.allclose(grad_c[1], grad_c_ref[1], atol=1e-6), "center_y mismatch"
2278+
assert np.allclose(grad_c[2], -grad_c_ref[2], atol=1e-6), "center_z sign mismatch"
2279+
assert np.allclose(grad_s, grad_s_ref, atol=1e-6), "size grads changed unexpectedly"
2280+
2281+
elif angle_deg == 90.0:
2282+
# rotating 90° about y => new x= old z, new z=- old x, y stays same
2283+
_, grad_c_ref, grad_s_ref = get_grad(center0, size0, angle_rad, axis)
2284+
rSx, rSy, rSz = grad_s_ref
2285+
rx, ry, rz = grad_c_ref
2286+
2287+
assert np.allclose(npx, rz, atol=1e-6), "center_x != old center_z"
2288+
assert np.allclose(npy, ry, atol=1e-6), "center_y changed unexpectedly"
2289+
assert np.allclose(npz, -rx, atol=1e-6), "center_z != - old center_x"
2290+
2291+
assert np.allclose(sSx, rSz, atol=1e-6), "size_x != old size_z"
2292+
assert np.allclose(sSy, rSy, atol=1e-6), "size_y changed unexpectedly"
2293+
assert np.allclose(sSz, rSx, atol=1e-6), "size_z != old size_x"
2294+
2295+
elif angle_deg == 270.0:
2296+
# rotating 270° about y => new x= - old z, new z= old x, y stays same
2297+
_, grad_c_ref, grad_s_ref = get_grad(center0, size0, angle_rad, axis)
2298+
rSx, rSy, rSz = grad_s_ref
2299+
rx, ry, rz = grad_c_ref
2300+
2301+
assert np.allclose(npx, -rz, atol=1e-6), "center_x != - old center_z"
2302+
assert np.allclose(npy, ry, atol=1e-6), "center_y changed unexpectedly"
2303+
assert np.allclose(npz, rx, atol=1e-6), "center_z != old center_x"
2304+
2305+
assert np.allclose(sSx, rSz, atol=1e-6), "size_x != old size_z"
2306+
assert np.allclose(sSy, rSy, atol=1e-6), "size_y changed unexpectedly"
2307+
assert np.allclose(sSz, rSx, atol=1e-6), "size_z != old size_x"
Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
import atexit
2+
import os
3+
from collections import defaultdict
4+
5+
import autograd
6+
import autograd.numpy as anp
7+
import matplotlib.pyplot as plt
8+
import numpy as np
9+
import pytest
10+
import tidy3d as td
11+
import tidy3d.web as web
12+
13+
SAVE_RESULTS = False
14+
PLOT_RESULTS = False
15+
RESULTS_DIR = "./fd_ad_results"
16+
results_collector = defaultdict(list)
17+
18+
wavelength = 1.5
19+
freq0 = td.C_0 / wavelength
20+
L = 10 * wavelength
21+
buffer = 1.0 * wavelength
22+
run_time = 120 / freq0
23+
24+
25+
SCENARIOS = [
26+
{
27+
"name": "(1) normal",
28+
"has_background": False,
29+
"background_eps": 3.0,
30+
"box_eps": 2.0,
31+
"rotation_deg": None,
32+
"rotation_axis": None,
33+
},
34+
{
35+
"name": "(2) perm=1.5",
36+
"has_background": True,
37+
"background_eps": 1.5,
38+
"box_eps": 2.0,
39+
"rotation_deg": None,
40+
"rotation_axis": None,
41+
},
42+
{
43+
"name": "(3) rotation=0 deg about z",
44+
"has_background": False,
45+
"background_eps": 1.5,
46+
"box_eps": 2.0,
47+
"rotation_deg": 0.0,
48+
"rotation_axis": 2,
49+
},
50+
{
51+
"name": "(4) rotation=90 deg about z",
52+
"has_background": False,
53+
"background_eps": 1.5,
54+
"box_eps": 2.0,
55+
"rotation_deg": 90.0,
56+
"rotation_axis": 2,
57+
},
58+
{
59+
"name": "(5) rotation=45 deg about y",
60+
"has_background": False,
61+
"background_eps": 1.5,
62+
"box_eps": 2.0,
63+
"rotation_deg": 45.0,
64+
"rotation_axis": 1,
65+
},
66+
{
67+
"name": "(6) rotation=45 deg about x",
68+
"has_background": False,
69+
"background_eps": 1.5,
70+
"box_eps": 2.0,
71+
"rotation_deg": 45.0,
72+
"rotation_axis": 0,
73+
},
74+
{
75+
"name": "(7) rotation=45 deg about z",
76+
"has_background": False,
77+
"background_eps": 1.5,
78+
"box_eps": 2.0,
79+
"rotation_deg": 45.0,
80+
"rotation_axis": 2,
81+
},
82+
]
83+
84+
PARAM_LABELS = ["center_x", "center_x", "center_y", "center_z", "size_x", "size_y", "size_z"]
85+
86+
87+
def make_sim(center: tuple, size: tuple, scenario: dict):
88+
source = td.PointDipole(
89+
center=(-L / 2 + buffer, 0.0, 0.0),
90+
source_time=td.GaussianPulse(freq0=freq0, fwidth=freq0 / 10.0),
91+
polarization="Ez",
92+
)
93+
94+
monitor = td.FieldMonitor(
95+
center=(+L / 2 - buffer, 0.5 * buffer, 0.5 * buffer),
96+
size=(0, 0, 0),
97+
freqs=[freq0],
98+
name="point_out",
99+
)
100+
101+
structures = []
102+
if scenario["has_background"]:
103+
back_box = td.Box(center=(0.0, 0.0, 0.0), size=(4.0, 1.6, 1.6))
104+
background_box = td.Structure(
105+
geometry=back_box,
106+
medium=td.Medium(permittivity=scenario["background_eps"]),
107+
)
108+
structures.append(background_box)
109+
110+
scatter_box = td.Box(center=center, size=size)
111+
112+
if scenario["rotation_deg"] is not None:
113+
angle_rad = np.deg2rad(scenario["rotation_deg"])
114+
rotated_geom = scatter_box.rotated(angle_rad, scenario["rotation_axis"])
115+
else:
116+
rotated_geom = scatter_box
117+
118+
scatter_struct = td.Structure(
119+
geometry=rotated_geom,
120+
medium=td.Medium(permittivity=scenario["box_eps"]),
121+
)
122+
structures.append(scatter_struct)
123+
124+
sim = td.Simulation(
125+
size=(L, L, L),
126+
run_time=run_time,
127+
grid_spec=td.GridSpec.auto(min_steps_per_wvl=50),
128+
sources=[source],
129+
monitors=[monitor],
130+
structures=structures,
131+
)
132+
return sim
133+
134+
135+
def objective_fn(center, size, scenario):
136+
sim = make_sim(center, size, scenario)
137+
sim_data = web.run(sim, task_name="autograd_vs_fd_scenario", local_gradient=True, verbose=False)
138+
return anp.sum(sim_data.get_intensity("point_out").values)
139+
140+
141+
def fd_vs_ad_param(center, size, scenario, param_label, delta=1e-3):
142+
val_and_grad_fn = autograd.value_and_grad(
143+
lambda c, s: objective_fn(c, s, scenario), argnum=(0, 1)
144+
)
145+
_, (grad_center, grad_size) = val_and_grad_fn(center, size)
146+
147+
param_map = {
148+
"center_x": (0, "center"),
149+
"center_y": (1, "center"),
150+
"center_z": (2, "center"),
151+
"size_x": (0, "size"),
152+
"size_y": (1, "size"),
153+
"size_z": (2, "size"),
154+
}
155+
idx, which = param_map[param_label]
156+
if which == "center":
157+
ad_val = grad_center[idx]
158+
else:
159+
ad_val = grad_size[idx]
160+
161+
center_arr = np.array(center, dtype=float)
162+
size_arr = np.array(size, dtype=float)
163+
164+
if which == "center":
165+
cplus = center_arr.copy()
166+
cminus = center_arr.copy()
167+
cplus[idx] += delta
168+
cminus[idx] -= delta
169+
p_plus = objective_fn(tuple(cplus), tuple(size_arr), scenario)
170+
p_minus = objective_fn(tuple(cminus), tuple(size_arr), scenario)
171+
else:
172+
splus = size_arr.copy()
173+
sminus = size_arr.copy()
174+
splus[idx] += delta
175+
sminus[idx] -= delta
176+
p_plus = objective_fn(tuple(center_arr), tuple(splus), scenario)
177+
p_minus = objective_fn(tuple(center_arr), tuple(sminus), scenario)
178+
179+
fd_val = (p_plus - p_minus) / (2.0 * delta)
180+
return fd_val, ad_val, p_plus, p_minus
181+
182+
183+
@pytest.mark.numerical
184+
@pytest.mark.parametrize("scenario", SCENARIOS, ids=[s["name"] for s in SCENARIOS])
185+
@pytest.mark.parametrize(
186+
"param_label", ["center_x", "center_y", "center_z", "size_x", "size_y", "size_z"]
187+
)
188+
def test_autograd_vs_fd_scenarios(scenario, param_label):
189+
center0 = (0.0, 0.0, 0.0)
190+
size0 = (2.0, 2.0, 2.0)
191+
delta = 0.03
192+
193+
fd_val, ad_val, p_plus, p_minus = fd_vs_ad_param(center0, size0, scenario, param_label, delta)
194+
195+
assert np.isfinite(fd_val), f"FD derivative is not finite for param={param_label}"
196+
assert np.isfinite(ad_val), f"AD derivative is not finite for param={param_label}"
197+
198+
denom = max(abs(fd_val), 1e-12)
199+
rel_diff = abs(fd_val - ad_val) / denom
200+
assert rel_diff < 0.3, f"Autograd vs FD mismatch: param={param_label}, diff={rel_diff:.1%}"
201+
202+
results_collector[param_label].append((scenario["name"], rel_diff))
203+
204+
if SAVE_RESULTS:
205+
os.makedirs(RESULTS_DIR, exist_ok=True)
206+
results_data = {
207+
"scenario_name": scenario["name"],
208+
"param_label": param_label,
209+
"delta": float(delta),
210+
"fd_val": float(fd_val),
211+
"ad_val": float(ad_val),
212+
"p_plus": float(p_plus),
213+
"p_minus": float(p_minus),
214+
"rel_diff": float(rel_diff),
215+
}
216+
filename_npy = f"fd_ad_{scenario['name'].replace(' ', '_')}_{param_label}.npy"
217+
np.save(os.path.join(RESULTS_DIR, filename_npy), results_data)
218+
219+
220+
def finalize_plotting():
221+
if not PLOT_RESULTS:
222+
return
223+
224+
os.makedirs(RESULTS_DIR, exist_ok=True)
225+
226+
for param_label in PARAM_LABELS:
227+
scenario_data = results_collector[param_label]
228+
if not scenario_data:
229+
continue
230+
scenario_names = [sd[0] for sd in scenario_data]
231+
rel_diffs = [sd[1] for sd in scenario_data]
232+
233+
plt.figure(figsize=(6, 4))
234+
plt.bar(scenario_names, rel_diffs, color="blue")
235+
plt.xticks(rotation=45, ha="right")
236+
plt.title(f"Relative Error for param='{param_label}'\n(FD vs AD)")
237+
plt.ylabel("Relative Error")
238+
plt.tight_layout()
239+
240+
filename_png = f"rel_error_{param_label.replace('_', '-')}.png"
241+
plt.savefig(os.path.join(RESULTS_DIR, filename_png))
242+
plt.close()
243+
print(f"Saved bar chart => {filename_png}")
244+
245+
246+
atexit.register(finalize_plotting)

0 commit comments

Comments
 (0)