Skip to content

Commit 856f237

Browse files
committed
feat: autograd support for rotated Box
1 parent 54a74b4 commit 856f237

File tree

5 files changed

+566
-116
lines changed

5 files changed

+566
-116
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1313
- A property `interior_angle` in `PolySlab` that stores angles formed inside polygon by two adjacent edges.
1414
- `eps_component` argument in `td.Simulation.plot_eps()` to optionally select a specific permittivity component to plot (eg. `"xx"`).
1515
- Monitor `AuxFieldTimeMonitor` for aux fields like the free carrier density in `TwoPhotonAbsorption`.
16+
- Gradient computation for rotated boxes in Transformed.
1617

1718
### Fixed
1819
- Compatibility with `xarray>=2025.03`.

tests/test_components/test_autograd.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2150,3 +2150,123 @@ def objective(args):
21502150
# model is called without a frequency
21512151
with AssertLogLevel("INFO"):
21522152
grad = ag.grad(objective)(params0)
2153+
2154+
2155+
def make_sim_rotation(center: tuple, size: tuple, angle: float, axis: int):
2156+
wavelength = 1.5
2157+
L = 10 * wavelength
2158+
freq0 = td.C_0 / wavelength
2159+
buffer = 1.0 * wavelength
2160+
2161+
# Source
2162+
src = td.PointDipole(
2163+
center=(-L / 2 + buffer, 0, 0),
2164+
source_time=td.GaussianPulse(freq0=freq0, fwidth=freq0 / 10.0),
2165+
polarization="Ez",
2166+
)
2167+
# Monitor
2168+
mnt = td.FieldMonitor(
2169+
center=(
2170+
+L / 2 - buffer,
2171+
0.5 * buffer,
2172+
0.5 * buffer,
2173+
),
2174+
size=(0.0, 0.0, 0.0),
2175+
freqs=[freq0],
2176+
name="point",
2177+
)
2178+
# The box geometry
2179+
base_box = td.Box(center=center, size=size)
2180+
if angle is not None:
2181+
base_box = base_box.rotated(angle, axis)
2182+
2183+
scatterer = td.Structure(
2184+
geometry=base_box,
2185+
medium=td.Medium(permittivity=2.0),
2186+
)
2187+
2188+
sim = td.Simulation(
2189+
size=(L, L, L),
2190+
grid_spec=td.GridSpec.auto(min_steps_per_wvl=50),
2191+
structures=[scatterer],
2192+
sources=[src],
2193+
monitors=[mnt],
2194+
run_time=120 / freq0,
2195+
)
2196+
return sim
2197+
2198+
2199+
def objective_fn(center, size, angle, axis):
2200+
sim = make_sim_rotation(center, size, angle, axis)
2201+
sim_data = web.run(sim, task_name="emulated_rot_test", local_gradient=True, verbose=False)
2202+
return anp.sum(sim_data.get_intensity("point").values)
2203+
2204+
2205+
def get_grad(center, size, angle, axis):
2206+
def wrapped(c, s):
2207+
return objective_fn(c, s, angle, axis)
2208+
2209+
val, (grad_c, grad_s) = ag.value_and_grad(wrapped, argnum=(0, 1))(center, size)
2210+
return val, grad_c, grad_s
2211+
2212+
2213+
@pytest.mark.numerical
2214+
@pytest.mark.parametrize(
2215+
"angle_deg, axis",
2216+
[
2217+
(0.0, 1),
2218+
(180.0, 1),
2219+
(90.0, 1),
2220+
(270.0, 1),
2221+
],
2222+
)
2223+
def test_box_rotation_gradients(use_emulated_run, angle_deg, axis):
2224+
center0 = (0.0, 0.0, 0.0)
2225+
size0 = (2.0, 2.0, 2.0)
2226+
2227+
angle_rad = np.deg2rad(angle_deg)
2228+
val, grad_c, grad_s = get_grad(center0, size0, angle=None, axis=None)
2229+
npx, npy, npz = grad_c
2230+
sSx, sSy, sSz = grad_s
2231+
2232+
assert not np.allclose(grad_c, 0.0), "center gradient is all zero."
2233+
assert not np.allclose(grad_s, 0.0), "size gradient is all zero."
2234+
2235+
if angle_deg == 180.0:
2236+
# rotating 180° about y => (x,z) become negated, y stays same
2237+
_, grad_c_ref, grad_s_ref = get_grad(center0, size0, angle_rad, axis)
2238+
rSx, rSy, rSz = grad_s_ref
2239+
rx, ry, rz = grad_c_ref
2240+
2241+
assert np.allclose(grad_c[0], -grad_c_ref[0], atol=1e-6), "center_x sign mismatch"
2242+
assert np.allclose(grad_c[1], grad_c_ref[1], atol=1e-6), "center_y mismatch"
2243+
assert np.allclose(grad_c[2], -grad_c_ref[2], atol=1e-6), "center_z sign mismatch"
2244+
assert np.allclose(grad_s, grad_s_ref, atol=1e-6), "size grads changed unexpectedly"
2245+
2246+
elif angle_deg == 90.0:
2247+
# rotating 90° about y => new x= old z, new z=- old x, y stays same
2248+
_, grad_c_ref, grad_s_ref = get_grad(center0, size0, angle_rad, axis)
2249+
rSx, rSy, rSz = grad_s_ref
2250+
rx, ry, rz = grad_c_ref
2251+
2252+
assert np.allclose(npx, rz, atol=1e-6), "center_x != old center_z"
2253+
assert np.allclose(npy, ry, atol=1e-6), "center_y changed unexpectedly"
2254+
assert np.allclose(npz, -rx, atol=1e-6), "center_z != - old center_x"
2255+
2256+
assert np.allclose(sSx, rSz, atol=1e-6), "size_x != old size_z"
2257+
assert np.allclose(sSy, rSy, atol=1e-6), "size_y changed unexpectedly"
2258+
assert np.allclose(sSz, rSx, atol=1e-6), "size_z != old size_x"
2259+
2260+
elif angle_deg == 270.0:
2261+
# rotating 270° about y => new x= - old z, new z= old x, y stays same
2262+
_, grad_c_ref, grad_s_ref = get_grad(center0, size0, angle_rad, axis)
2263+
rSx, rSy, rSz = grad_s_ref
2264+
rx, ry, rz = grad_c_ref
2265+
2266+
assert np.allclose(npx, -rz, atol=1e-6), "center_x != - old center_z"
2267+
assert np.allclose(npy, ry, atol=1e-6), "center_y changed unexpectedly"
2268+
assert np.allclose(npz, rx, atol=1e-6), "center_z != old center_x"
2269+
2270+
assert np.allclose(sSx, rSz, atol=1e-6), "size_x != old size_z"
2271+
assert np.allclose(sSy, rSy, atol=1e-6), "size_y changed unexpectedly"
2272+
assert np.allclose(sSz, rSx, atol=1e-6), "size_z != old size_x"
Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
import atexit
2+
import os
3+
from collections import defaultdict
4+
5+
import autograd
6+
import autograd.numpy as anp
7+
import matplotlib.pyplot as plt
8+
import numpy as np
9+
import pytest
10+
import tidy3d as td
11+
import tidy3d.web as web
12+
13+
SAVE_RESULTS = False
14+
PLOT_RESULTS = False
15+
RESULTS_DIR = "./fd_ad_results"
16+
results_collector = defaultdict(list)
17+
18+
wavelength = 1.5
19+
freq0 = td.C_0 / wavelength
20+
L = 10 * wavelength
21+
buffer = 1.0 * wavelength
22+
run_time = 120 / freq0
23+
24+
25+
SCENARIOS = [
26+
{
27+
"name": "(1) normal",
28+
"has_background": False,
29+
"background_eps": 3.0,
30+
"box_eps": 2.0,
31+
"rotation_deg": None,
32+
"rotation_axis": None,
33+
},
34+
{
35+
"name": "(2) perm=1.5",
36+
"has_background": True,
37+
"background_eps": 1.5,
38+
"box_eps": 2.0,
39+
"rotation_deg": None,
40+
"rotation_axis": None,
41+
},
42+
{
43+
"name": "(3) rotation=0 deg about z",
44+
"has_background": False,
45+
"background_eps": 1.5,
46+
"box_eps": 2.0,
47+
"rotation_deg": 0.0,
48+
"rotation_axis": 2,
49+
},
50+
{
51+
"name": "(4) rotation=90 deg about z",
52+
"has_background": False,
53+
"background_eps": 1.5,
54+
"box_eps": 2.0,
55+
"rotation_deg": 90.0,
56+
"rotation_axis": 2,
57+
},
58+
{
59+
"name": "(5) rotation=45 deg about y",
60+
"has_background": False,
61+
"background_eps": 1.5,
62+
"box_eps": 2.0,
63+
"rotation_deg": 45.0,
64+
"rotation_axis": 1,
65+
},
66+
{
67+
"name": "(6) rotation=45 deg about x",
68+
"has_background": False,
69+
"background_eps": 1.5,
70+
"box_eps": 2.0,
71+
"rotation_deg": 45.0,
72+
"rotation_axis": 0,
73+
},
74+
{
75+
"name": "(7) rotation=45 deg about z",
76+
"has_background": False,
77+
"background_eps": 1.5,
78+
"box_eps": 2.0,
79+
"rotation_deg": 45.0,
80+
"rotation_axis": 2,
81+
},
82+
]
83+
84+
PARAM_LABELS = ["center_x", "center_x", "center_y", "center_z", "size_x", "size_y", "size_z"]
85+
86+
87+
def make_sim(center: tuple, size: tuple, scenario: dict):
88+
source = td.PointDipole(
89+
center=(-L / 2 + buffer, 0.0, 0.0),
90+
source_time=td.GaussianPulse(freq0=freq0, fwidth=freq0 / 10.0),
91+
polarization="Ez",
92+
)
93+
94+
monitor = td.FieldMonitor(
95+
center=(+L / 2 - buffer, 0.5 * buffer, 0.5 * buffer),
96+
size=(0, 0, 0),
97+
freqs=[freq0],
98+
name="point_out",
99+
)
100+
101+
structures = []
102+
if scenario["has_background"]:
103+
back_box = td.Box(center=(0.0, 0.0, 0.0), size=(4.0, 1.6, 1.6))
104+
background_box = td.Structure(
105+
geometry=back_box,
106+
medium=td.Medium(permittivity=scenario["background_eps"]),
107+
)
108+
structures.append(background_box)
109+
110+
scatter_box = td.Box(center=center, size=size)
111+
112+
if scenario["rotation_deg"] is not None:
113+
angle_rad = np.deg2rad(scenario["rotation_deg"])
114+
rotated_geom = scatter_box.rotated(angle_rad, scenario["rotation_axis"])
115+
else:
116+
rotated_geom = scatter_box
117+
118+
scatter_struct = td.Structure(
119+
geometry=rotated_geom,
120+
medium=td.Medium(permittivity=scenario["box_eps"]),
121+
)
122+
structures.append(scatter_struct)
123+
124+
sim = td.Simulation(
125+
size=(L, L, L),
126+
run_time=run_time,
127+
grid_spec=td.GridSpec.auto(min_steps_per_wvl=50),
128+
sources=[source],
129+
monitors=[monitor],
130+
structures=structures,
131+
)
132+
return sim
133+
134+
135+
def objective_fn(center, size, scenario):
136+
sim = make_sim(center, size, scenario)
137+
sim_data = web.run(sim, task_name="autograd_vs_fd_scenario", local_gradient=True, verbose=False)
138+
return anp.sum(sim_data.get_intensity("point_out").values)
139+
140+
141+
def fd_vs_ad_param(center, size, scenario, param_label, delta=1e-3):
142+
val_and_grad_fn = autograd.value_and_grad(
143+
lambda c, s: objective_fn(c, s, scenario), argnum=(0, 1)
144+
)
145+
_, (grad_center, grad_size) = val_and_grad_fn(center, size)
146+
147+
param_map = {
148+
"center_x": (0, "center"),
149+
"center_y": (1, "center"),
150+
"center_z": (2, "center"),
151+
"size_x": (0, "size"),
152+
"size_y": (1, "size"),
153+
"size_z": (2, "size"),
154+
}
155+
idx, which = param_map[param_label]
156+
if which == "center":
157+
ad_val = grad_center[idx]
158+
else:
159+
ad_val = grad_size[idx]
160+
161+
center_arr = np.array(center, dtype=float)
162+
size_arr = np.array(size, dtype=float)
163+
164+
if which == "center":
165+
cplus = center_arr.copy()
166+
cminus = center_arr.copy()
167+
cplus[idx] += delta
168+
cminus[idx] -= delta
169+
p_plus = objective_fn(tuple(cplus), tuple(size_arr), scenario)
170+
p_minus = objective_fn(tuple(cminus), tuple(size_arr), scenario)
171+
else:
172+
splus = size_arr.copy()
173+
sminus = size_arr.copy()
174+
splus[idx] += delta
175+
sminus[idx] -= delta
176+
p_plus = objective_fn(tuple(center_arr), tuple(splus), scenario)
177+
p_minus = objective_fn(tuple(center_arr), tuple(sminus), scenario)
178+
179+
fd_val = (p_plus - p_minus) / (2.0 * delta)
180+
return fd_val, ad_val, p_plus, p_minus
181+
182+
183+
@pytest.mark.numerical
184+
@pytest.mark.parametrize("scenario", SCENARIOS, ids=[s["name"] for s in SCENARIOS])
185+
@pytest.mark.parametrize(
186+
"param_label", ["center_x", "center_y", "center_z", "size_x", "size_y", "size_z"]
187+
)
188+
def test_autograd_vs_fd_scenarios(scenario, param_label):
189+
center0 = (0.0, 0.0, 0.0)
190+
size0 = (2.0, 2.0, 2.0)
191+
delta = 0.03
192+
193+
fd_val, ad_val, p_plus, p_minus = fd_vs_ad_param(center0, size0, scenario, param_label, delta)
194+
195+
assert np.isfinite(fd_val), f"FD derivative is not finite for param={param_label}"
196+
assert np.isfinite(ad_val), f"AD derivative is not finite for param={param_label}"
197+
198+
denom = max(abs(fd_val), 1e-12)
199+
rel_diff = abs(fd_val - ad_val) / denom
200+
assert rel_diff < 0.3, f"Autograd vs FD mismatch: param={param_label}, diff={rel_diff:.1%}"
201+
202+
results_collector[param_label].append((scenario["name"], rel_diff))
203+
204+
if SAVE_RESULTS:
205+
os.makedirs(RESULTS_DIR, exist_ok=True)
206+
results_data = {
207+
"scenario_name": scenario["name"],
208+
"param_label": param_label,
209+
"delta": float(delta),
210+
"fd_val": float(fd_val),
211+
"ad_val": float(ad_val),
212+
"p_plus": float(p_plus),
213+
"p_minus": float(p_minus),
214+
"rel_diff": float(rel_diff),
215+
}
216+
filename_npy = f"fd_ad_{scenario['name'].replace(' ', '_')}_{param_label}.npy"
217+
np.save(os.path.join(RESULTS_DIR, filename_npy), results_data)
218+
219+
220+
def finalize_plotting():
221+
if not PLOT_RESULTS:
222+
return
223+
224+
os.makedirs(RESULTS_DIR, exist_ok=True)
225+
226+
for param_label in PARAM_LABELS:
227+
scenario_data = results_collector[param_label]
228+
if not scenario_data:
229+
continue
230+
scenario_names = [sd[0] for sd in scenario_data]
231+
rel_diffs = [sd[1] for sd in scenario_data]
232+
233+
plt.figure(figsize=(6, 4))
234+
plt.bar(scenario_names, rel_diffs, color="blue")
235+
plt.xticks(rotation=45, ha="right")
236+
plt.title(f"Relative Error for param='{param_label}'\n(FD vs AD)")
237+
plt.ylabel("Relative Error")
238+
plt.tight_layout()
239+
240+
filename_png = f"rel_error_{param_label.replace('_', '-')}.png"
241+
plt.savefig(os.path.join(RESULTS_DIR, filename_png))
242+
plt.close()
243+
print(f"Saved bar chart => {filename_png}")
244+
245+
246+
atexit.register(finalize_plotting)

0 commit comments

Comments
 (0)