Skip to content

Commit 733c31c

Browse files
Gregory RobertsGregory Roberts
authored andcommitted
feat(adjoint): add numerical_structures hook to custom run functions for user-defined structure gradients
1 parent 70ddbbb commit 733c31c

File tree

9 files changed

+1642
-129
lines changed

9 files changed

+1642
-129
lines changed

changelog.d/3308.added.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
- Add ``numerical_structures`` hook to custom autograd run paths for user-defined structure creation and gradients in simulation and component-modeler workflows.
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from __future__ import annotations
2+
3+
from typing import Callable
4+
5+
import numpy as np
6+
import xarray as xr
7+
8+
import tidy3d as td
9+
10+
11+
def compute_ring_vjp(
12+
parameters: np.ndarray,
13+
derivative_info,
14+
create_ring_fn: Callable[[np.ndarray], td.Structure],
15+
) -> dict[tuple[int], float]:
16+
"""Compute finite-difference VJP values for ring parameter paths."""
17+
max_frequency = np.max(derivative_info.frequencies)
18+
min_wvl = td.C_0 / max_frequency
19+
step_size = min_wvl / 20.0
20+
21+
update_kwargs = {"paths": [("permittivity",)], "deep": False}
22+
derivative_info_custom_medium = derivative_info.updated_copy(**update_kwargs)
23+
24+
params_np = np.array(parameters)
25+
26+
vjps = {}
27+
for path in derivative_info.paths:
28+
param_idx = path[0]
29+
params_up = params_np.copy()
30+
params_down = params_np.copy()
31+
params_up[param_idx] += step_size
32+
params_down[param_idx] -= step_size
33+
34+
ring_up = create_ring_fn(params_up)
35+
ring_down = create_ring_fn(params_down)
36+
37+
eps_up = derivative_info.updated_epsilon(ring_up.geometry)
38+
eps_down = derivative_info.updated_epsilon(ring_down.geometry)
39+
eps_grad = (eps_up - eps_down) / (2 * step_size)
40+
41+
custom_medium = td.CustomMedium(permittivity=xr.ones_like(eps_grad.isel(f=0, drop=True)))
42+
vjps_custom_medium = custom_medium._compute_derivatives(derivative_info_custom_medium)
43+
total_grad = np.real(np.sum(eps_grad.sum("f").data * vjps_custom_medium[("permittivity",)]))
44+
vjps[path] = total_grad
45+
46+
return vjps

0 commit comments

Comments
 (0)