Skip to content

Commit 00cce22

Browse files
committed
feat[adjoint]: Add conductivity gradient for CustomMedium
Introduces support for computing gradients with respect to the `conductivity` field in `CustomMedium`. To achieve this, the derivative computation logic was generalized: - The `_derivative_field_cmp` method is refactored to accept a `component` parameter ('real', 'imag', 'complex') to compute the VJP for different parts of the complex permittivity. - For the 'imag' component, the derivative is scaled by `1 / (omega * epsilon_0)` to convert from derivative w.r.t. complex permittivity to derivative w.r.t. conductivity. - The `_compute_derivatives` dispatcher now handles the new `"conductivity"` parameter path. - A new autograd test is added to validate the conductivity gradient calculation on a `CustomMedium` with constant permittivity.
1 parent 6e5dd73 commit 00cce22

File tree

3 files changed

+126
-31
lines changed

3 files changed

+126
-31
lines changed

CHANGELOG.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
99

1010
### Added
1111
- Add support for `np.unwrap` in `tidy3d.plugins.autograd`.
12+
- Support for gradients with respect to the `conductivity` of a `CustomMedium`.
1213

1314
### Fixed
1415
- Arrow lengths are now scaled consistently in the X and Y directions, and their lengths no longer exceed the height of the plot window.
@@ -100,8 +101,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
100101

101102
### Fixed
102103
- Fixed `reverse` property of `td.Scene.plot_structures_property()` to also reverse the colorbar.
103-
104-
### Fixed
105104
- Fixed bug in surface gradient computation where fields, instead of gradients, were being summed in frequency.
106105

107106
## [2.8.2] - 2025-04-09

tests/test_components/test_autograd.py

Lines changed: 75 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -315,30 +315,42 @@ def make_structures(params: anp.ndarray) -> dict[str, td.Structure]:
315315
eps_arr = 1.01 + 0.5 * (anp.tanh(matrix @ params).reshape(DA_SHAPE) + 1)
316316

317317
nx, ny, nz = eps_arr.shape
318+
da_coords = {
319+
"x": np.linspace(-0.5, 0.5, nx),
320+
"y": np.linspace(-0.5, 0.5, ny),
321+
"z": np.linspace(-0.5, 0.5, nz),
322+
}
318323

319324
custom_med = td.Structure(
320325
geometry=box,
321326
medium=td.CustomMedium(
322327
permittivity=td.SpatialDataArray(
323328
eps_arr,
324-
coords={
325-
"x": np.linspace(-0.5, 0.5, nx),
326-
"y": np.linspace(-0.5, 0.5, ny),
327-
"z": np.linspace(-0.5, 0.5, nz),
328-
},
329+
coords=da_coords,
330+
),
331+
),
332+
)
333+
334+
# custom medium with variable permittivity and conductivity data
335+
conductivity_arr = 0.01 * (anp.tanh(matrix @ params).reshape(DA_SHAPE) + 1)
336+
custom_med_with_conductivity = td.Structure(
337+
geometry=box,
338+
medium=td.CustomMedium(
339+
permittivity=td.SpatialDataArray(
340+
eps_arr,
341+
coords=da_coords,
342+
),
343+
conductivity=td.SpatialDataArray(
344+
conductivity_arr,
345+
coords=da_coords,
329346
),
330347
),
331348
)
332349

333350
# custom medium with vector valued permittivity data
334351
eps_ii = td.ScalarFieldDataArray(
335352
eps_arr.reshape(nx, ny, nz, 1),
336-
coords={
337-
"x": np.linspace(-0.5, 0.5, nx),
338-
"y": np.linspace(-0.5, 0.5, ny),
339-
"z": np.linspace(-0.5, 0.5, nz),
340-
"f": [td.C_0],
341-
},
353+
coords=da_coords | {"f": [td.C_0]},
342354
)
343355

344356
custom_med_vec = td.Structure(
@@ -480,6 +492,7 @@ def make_structures(params: anp.ndarray) -> dict[str, td.Structure]:
480492
"center_list": center_list,
481493
"size_element": size_element,
482494
"custom_med": custom_med,
495+
"custom_med_with_conductivity": custom_med_with_conductivity,
483496
"custom_med_vec": custom_med_vec,
484497
"polyslab": polyslab,
485498
"polyslab_dispersive": polyslab_dispersive,
@@ -577,6 +590,7 @@ def plot_sim(sim: td.Simulation, plot_eps: bool = True) -> None:
577590
"center_list",
578591
"size_element",
579592
"custom_med",
593+
"custom_med_with_conductivity",
580594
"custom_med_vec",
581595
"polyslab",
582596
"complex_polyslab",
@@ -1661,7 +1675,7 @@ def J(eps):
16611675
monkeypatch.setattr(
16621676
td.CustomPoleResidue,
16631677
"_derivative_field_cmp",
1664-
lambda self, E_der_map, eps_data, dim, freqs: dJ_deps,
1678+
lambda self, E_der_map, spatial_data, dim, freqs, component="real": dJ_deps,
16651679
)
16661680

16671681
import importlib
@@ -2351,3 +2365,52 @@ def objective(x):
23512365

23522366
with pytest.raises(ValueError):
23532367
g = ag.grad(objective)(1.0)
2368+
2369+
2370+
def test_custom_medium_conductivity_only_gradient(rng, use_emulated_run, tmp_path):
2371+
"""Test conductivity gradients for CustomMedium with constant permittivity."""
2372+
2373+
monitor, postprocess = make_monitors()["field_point"]
2374+
2375+
def objective(params):
2376+
"""Objective function testing only gonductivity gradient (constant permittivity)."""
2377+
len_arr = np.prod(DA_SHAPE)
2378+
matrix = rng.random((len_arr, N_PARAMS))
2379+
2380+
# constant permittivity
2381+
eps_arr = np.ones(DA_SHAPE) * 2.0
2382+
2383+
# variable conductivity
2384+
conductivity_arr = 0.05 * (anp.tanh(3 * matrix @ params).reshape(DA_SHAPE) + 1)
2385+
2386+
nx, ny, nz = DA_SHAPE
2387+
coords = {
2388+
"x": np.linspace(-0.5, 0.5, nx),
2389+
"y": np.linspace(-0.5, 0.5, ny),
2390+
"z": np.linspace(-0.5, 0.5, nz),
2391+
}
2392+
2393+
custom_med_struct = td.Structure(
2394+
geometry=td.Box(center=(0, 0, 0), size=(1, 1, 1)),
2395+
medium=td.CustomMedium(
2396+
permittivity=td.SpatialDataArray(eps_arr, coords=coords),
2397+
conductivity=td.SpatialDataArray(conductivity_arr, coords=coords),
2398+
),
2399+
)
2400+
2401+
sim = SIM_BASE.updated_copy(
2402+
structures=[custom_med_struct],
2403+
monitors=[monitor],
2404+
)
2405+
2406+
data = run(
2407+
sim,
2408+
path=str(tmp_path / "sim_test.hdf5"),
2409+
task_name="conductivity_only_grad_test",
2410+
verbose=False,
2411+
)
2412+
return postprocess(data, data[monitor.name])
2413+
2414+
val, grad = ag.value_and_grad(objective)(params0)
2415+
2416+
assert anp.all(grad != 0.0), "some gradients are 0 for conductivity-only test"

tidy3d/components/medium.py

Lines changed: 50 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1663,14 +1663,16 @@ def _not_loaded(field):
16631663
def _derivative_field_cmp(
16641664
self,
16651665
E_der_map: ElectromagneticFieldDataset,
1666-
eps_data: PermittivityDataset,
1666+
spatial_data: PermittivityDataset,
16671667
dim: str,
16681668
freqs: NDArray,
16691669
) -> np.ndarray:
1670-
coords_interp = {key: val for key, val in eps_data.coords.items() if len(val) > 1}
1671-
dims_sum = {dim for dim in eps_data.coords.keys() if dim not in coords_interp}
1670+
coords_interp = {key: val for key, val in spatial_data.coords.items() if len(val) > 1}
1671+
dims_sum = {dim for dim in spatial_data.coords.keys() if dim not in coords_interp}
16721672

1673-
eps_coordinate_shape = [len(eps_data.coords[dim]) for dim in eps_data.dims if dim in "xyz"]
1673+
eps_coordinate_shape = [
1674+
len(spatial_data.coords[dim]) for dim in spatial_data.dims if dim in "xyz"
1675+
]
16741676

16751677
# compute sizes along each of the interpolation dimensions
16761678
sizes_list = []
@@ -2835,14 +2837,27 @@ def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradField
28352837
vjps = {}
28362838

28372839
for field_path in derivative_info.paths:
2838-
if field_path == ("permittivity",):
2840+
if field_path[0] == "permittivity":
28392841
vjp_array = 0.0
28402842
for dim in "xyz":
28412843
vjp_array += self._derivative_field_cmp(
28422844
E_der_map=derivative_info.E_der_map,
2843-
eps_data=self.permittivity,
2845+
spatial_data=self.permittivity,
28442846
dim=dim,
28452847
freqs=np.atleast_1d(derivative_info.frequency),
2848+
component="real",
2849+
)
2850+
vjps[field_path] = vjp_array
2851+
2852+
elif field_path[0] == "conductivity":
2853+
vjp_array = 0.0
2854+
for dim in "xyz":
2855+
vjp_array += self._derivative_field_cmp(
2856+
E_der_map=derivative_info.E_der_map,
2857+
spatial_data=self.conductivity,
2858+
dim=dim,
2859+
freqs=np.atleast_1d(derivative_info.frequency),
2860+
component="imag",
28462861
)
28472862
vjps[field_path] = vjp_array
28482863

@@ -2851,11 +2866,11 @@ def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradField
28512866
dim = key[-1]
28522867
vjps[field_path] = self._derivative_field_cmp(
28532868
E_der_map=derivative_info.E_der_map,
2854-
eps_data=self.eps_dataset.field_components[key],
2869+
spatial_data=self.eps_dataset.field_components[key],
28552870
dim=dim,
28562871
freqs=np.atleast_1d(derivative_info.frequency),
2872+
component="complex",
28572873
)
2858-
28592874
else:
28602875
raise NotImplementedError(
28612876
f"No derivative defined for 'CustomMedium' field: {field_path}."
@@ -2866,15 +2881,18 @@ def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradField
28662881
def _derivative_field_cmp(
28672882
self,
28682883
E_der_map: ElectromagneticFieldDataset,
2869-
eps_data: PermittivityDataset,
2884+
spatial_data: PermittivityDataset,
28702885
dim: str,
28712886
freqs: NDArray,
2887+
component: str = "real",
28722888
) -> np.ndarray:
2873-
"""Compute derivative with respect to the ``dim`` components within the custom medium."""
2874-
coords_interp = {key: eps_data.coords[key] for key in "xyz"}
2889+
"""Compute the derivative with respect to a material property component."""
2890+
coords_interp = {key: spatial_data.coords[key] for key in "xyz"}
28752891
coords_interp = {key: val for key, val in coords_interp.items() if len(val) > 1}
28762892

2877-
eps_coordinate_shape = [len(eps_data.coords[dim]) for dim in eps_data.dims if dim in "xyz"]
2893+
eps_coordinate_shape = [
2894+
len(spatial_data.coords[dim]) for dim in spatial_data.dims if dim in "xyz"
2895+
]
28782896

28792897
E_der_dim_interp = E_der_map[f"E{dim}"].sel(f=freqs)
28802898

@@ -2912,10 +2930,25 @@ def _derivative_field_cmp(
29122930
# if sizes_list is empty, then reduce() fails
29132931
d_vol = np.array(1.0)
29142932

2915-
# TODO: probably this could be more robust. eg if the DataArray has weird edge cases
2916-
E_der_dim_interp = (
2917-
E_der_dim_interp.interp(**coords_interp, assume_sorted=True).fillna(0.0).real.sum("f")
2918-
)
2933+
E_der_dim_interp_complex = E_der_dim_interp.interp(
2934+
**coords_interp, assume_sorted=True
2935+
).fillna(0.0)
2936+
2937+
if component == "imag":
2938+
# convert from derivative w.r.t. complex permittivity to derivative w.r.t. conductivity
2939+
# dJ/d(sigma) = Im[dJ/d(eps_complex)] / (omega * epsilon_0)
2940+
omegas = 2 * np.pi * freqs
2941+
E_der_dim_interp = E_der_dim_interp_complex.imag
2942+
for i, omega in enumerate(omegas):
2943+
E_der_dim_interp.loc[{"f": freqs[i]}] = E_der_dim_interp.sel(f=freqs[i]) / (
2944+
omega * EPSILON_0
2945+
)
2946+
E_der_dim_interp = E_der_dim_interp.sum("f")
2947+
elif component == "complex":
2948+
# for complex permittivity in eps_dataset, return the full complex derivative
2949+
E_der_dim_interp = E_der_dim_interp_complex.sum("f")
2950+
else:
2951+
E_der_dim_interp = E_der_dim_interp_complex.real.sum("f")
29192952

29202953
try:
29212954
E_der_dim_interp = E_der_dim_interp * d_vol.reshape(E_der_dim_interp.shape)
@@ -3909,7 +3942,7 @@ def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradField
39093942
for dim in "xyz":
39103943
dJ_deps += self._derivative_field_cmp(
39113944
E_der_map=derivative_info.E_der_map,
3912-
eps_data=self.eps_inf,
3945+
spatial_data=self.eps_inf,
39133946
dim=dim,
39143947
freqs=np.atleast_1d(derivative_info.frequency),
39153948
)

0 commit comments

Comments
 (0)