Skip to content

Commit abaecd9

Browse files
committed
feat(adjoint): Add conductivity gradient for CustomMedium
Introduces support for computing gradients with respect to the `conductivity` field in `CustomMedium`. To achieve this, the derivative computation logic was generalized: - The `_derivative_field_cmp` method is refactored to accept a `component` parameter ('real', 'imag', 'complex') to compute the VJP for different parts of the complex permittivity. - For the 'imag' component, the derivative is scaled by `1 / (omega * epsilon_0)` to convert from derivative w.r.t. complex permittivity to derivative w.r.t. conductivity. - The `_compute_derivatives` dispatcher now handles the new `"conductivity"` parameter path. - A new autograd test is added to validate the conductivity gradient calculation on a `CustomMedium` with constant permittivity.
1 parent 5e7ef85 commit abaecd9

File tree

3 files changed

+155
-29
lines changed

3 files changed

+155
-29
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1818
- Added `InternalAbsorber` class for placing first-order absorbing boundary conditions on planes inside the simulation domain. Internal absorbers are automatically wrapped in a PEC frame with a backing PEC plate on the non-absorbing side.
1919
- Added `absorber` field (default: `True`) to `WavePort` for automatically placing an absorber behind the port.
2020
- Added `conjugated_dot_product` field in `ModeMonitor` (default: `True`) and `WavePort` (default: `False`) to allow selecting the conjugated or non-conjugated dot product for mode decomposition.
21+
- Support for gradients with respect to the `conductivity` of a `CustomMedium`.
2122

2223
### Changed
2324
- Validate mode solver object for large number of grid points on the modal plane.
2425
- Adaptive minimum spacing for `PolySlab` integration is now wavelength relative and a minimum discretization is set for computing gradients for cylinders.
2526
- The `TerminalComponentModeler` defaults to the pseudo wave definition of scattering parameters. The new field `s_param_def` can be used to switch between either pseudo or power wave definitions.
27+
- Add support for `np.unwrap` in `tidy3d.plugins.autograd`.
2628

2729
### Fixed
2830
- Fixed missing amplitude factor and handling of negative normal direction case when making adjoint sources from `DiffractionMonitor`.

tests/test_components/test_autograd.py

Lines changed: 75 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -319,30 +319,42 @@ def make_structures(params: anp.ndarray) -> dict[str, td.Structure]:
319319
eps_arr = 1.01 + 0.5 * (anp.tanh(matrix @ params).reshape(DA_SHAPE) + 1)
320320

321321
nx, ny, nz = eps_arr.shape
322+
da_coords = {
323+
"x": np.linspace(-0.5, 0.5, nx),
324+
"y": np.linspace(-0.5, 0.5, ny),
325+
"z": np.linspace(-0.5, 0.5, nz),
326+
}
322327

323328
custom_med = td.Structure(
324329
geometry=box,
325330
medium=td.CustomMedium(
326331
permittivity=td.SpatialDataArray(
327332
eps_arr,
328-
coords={
329-
"x": np.linspace(-0.5, 0.5, nx),
330-
"y": np.linspace(-0.5, 0.5, ny),
331-
"z": np.linspace(-0.5, 0.5, nz),
332-
},
333+
coords=da_coords,
334+
),
335+
),
336+
)
337+
338+
# custom medium with variable permittivity and conductivity data
339+
conductivity_arr = 0.01 * (anp.tanh(matrix @ params).reshape(DA_SHAPE) + 1)
340+
custom_med_with_conductivity = td.Structure(
341+
geometry=box,
342+
medium=td.CustomMedium(
343+
permittivity=td.SpatialDataArray(
344+
eps_arr,
345+
coords=da_coords,
346+
),
347+
conductivity=td.SpatialDataArray(
348+
conductivity_arr,
349+
coords=da_coords,
333350
),
334351
),
335352
)
336353

337354
# custom medium with vector valued permittivity data
338355
eps_ii = td.ScalarFieldDataArray(
339356
eps_arr.reshape(nx, ny, nz, 1),
340-
coords={
341-
"x": np.linspace(-0.5, 0.5, nx),
342-
"y": np.linspace(-0.5, 0.5, ny),
343-
"z": np.linspace(-0.5, 0.5, nz),
344-
"f": [td.C_0],
345-
},
357+
coords=da_coords | {"f": [td.C_0]},
346358
)
347359

348360
custom_med_vec = td.Structure(
@@ -484,6 +496,7 @@ def make_structures(params: anp.ndarray) -> dict[str, td.Structure]:
484496
"center_list": center_list,
485497
"size_element": size_element,
486498
"custom_med": custom_med,
499+
"custom_med_with_conductivity": custom_med_with_conductivity,
487500
"custom_med_vec": custom_med_vec,
488501
"polyslab": polyslab,
489502
"polyslab_dispersive": polyslab_dispersive,
@@ -581,6 +594,7 @@ def plot_sim(sim: td.Simulation, plot_eps: bool = True) -> None:
581594
"center_list",
582595
"size_element",
583596
"custom_med",
597+
"custom_med_with_conductivity",
584598
"custom_med_vec",
585599
"polyslab",
586600
"complex_polyslab",
@@ -1738,7 +1752,7 @@ def J(eps):
17381752
monkeypatch.setattr(
17391753
td.CustomPoleResidue,
17401754
"_derivative_field_cmp",
1741-
lambda self, E_der_map, eps_data, dim: dJ_deps / 3.0,
1755+
lambda self, E_der_map, spatial_data, dim, freqs, component="real": dJ_deps / 3.0,
17421756
)
17431757

17441758
import importlib
@@ -2434,3 +2448,52 @@ def objective(x):
24342448

24352449
with pytest.raises(ValueError):
24362450
g = ag.grad(objective)(1.0)
2451+
2452+
2453+
def test_custom_medium_conductivity_only_gradient(rng, use_emulated_run, tmp_path):
2454+
"""Test conductivity gradients for CustomMedium with constant permittivity."""
2455+
2456+
monitor, postprocess = make_monitors()["field_point"]
2457+
2458+
def objective(params):
2459+
"""Objective function testing only conductivity gradient (constant permittivity)."""
2460+
len_arr = np.prod(DA_SHAPE)
2461+
matrix = rng.random((len_arr, N_PARAMS))
2462+
2463+
# constant permittivity
2464+
eps_arr = np.ones(DA_SHAPE) * 2.0
2465+
2466+
# variable conductivity
2467+
conductivity_arr = 0.05 * (anp.tanh(3 * matrix @ params).reshape(DA_SHAPE) + 1)
2468+
2469+
nx, ny, nz = DA_SHAPE
2470+
coords = {
2471+
"x": np.linspace(-0.5, 0.5, nx),
2472+
"y": np.linspace(-0.5, 0.5, ny),
2473+
"z": np.linspace(-0.5, 0.5, nz),
2474+
}
2475+
2476+
custom_med_struct = td.Structure(
2477+
geometry=td.Box(center=(0, 0, 0), size=(1, 1, 1)),
2478+
medium=td.CustomMedium(
2479+
permittivity=td.SpatialDataArray(eps_arr, coords=coords),
2480+
conductivity=td.SpatialDataArray(conductivity_arr, coords=coords),
2481+
),
2482+
)
2483+
2484+
sim = SIM_BASE.updated_copy(
2485+
structures=[custom_med_struct],
2486+
monitors=[monitor],
2487+
)
2488+
2489+
data = run(
2490+
sim,
2491+
path=str(tmp_path / "sim_test.hdf5"),
2492+
task_name="conductivity_only_grad_test",
2493+
verbose=False,
2494+
)
2495+
return postprocess(data, data[monitor.name])
2496+
2497+
val, grad = ag.value_and_grad(objective)(params0)
2498+
2499+
assert anp.all(grad != 0.0), "some gradients are 0 for conductivity-only test"

tidy3d/components/medium.py

Lines changed: 78 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1682,13 +1682,15 @@ def _not_loaded(field):
16821682
def _derivative_field_cmp(
16831683
self,
16841684
E_der_map: ElectromagneticFieldDataset,
1685-
eps_data: PermittivityDataset,
1685+
spatial_data: PermittivityDataset,
16861686
dim: str,
16871687
) -> np.ndarray:
1688-
coords_interp = {key: val for key, val in eps_data.coords.items() if len(val) > 1}
1689-
dims_sum = {dim for dim in eps_data.coords.keys() if dim not in coords_interp}
1688+
coords_interp = {key: val for key, val in spatial_data.coords.items() if len(val) > 1}
1689+
dims_sum = {dim for dim in spatial_data.coords.keys() if dim not in coords_interp}
16901690

1691-
eps_coordinate_shape = [len(eps_data.coords[dim]) for dim in eps_data.dims if dim in "xyz"]
1691+
eps_coordinate_shape = [
1692+
len(spatial_data.coords[dim]) for dim in spatial_data.dims if dim in "xyz"
1693+
]
16921694

16931695
# compute sizes along each of the interpolation dimensions
16941696
sizes_list = []
@@ -2898,13 +2900,27 @@ def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradField
28982900
vjps = {}
28992901

29002902
for field_path in derivative_info.paths:
2901-
if field_path == ("permittivity",):
2903+
if field_path[0] == "permittivity":
29022904
vjp_array = 0.0
29032905
for dim in "xyz":
29042906
vjp_array += self._derivative_field_cmp(
29052907
E_der_map=derivative_info.E_der_map,
2906-
eps_data=self.permittivity,
2908+
spatial_data=self.permittivity,
29072909
dim=dim,
2910+
freqs=np.atleast_1d(derivative_info.frequencies),
2911+
component="real",
2912+
)
2913+
vjps[field_path] = vjp_array
2914+
2915+
elif field_path[0] == "conductivity":
2916+
vjp_array = 0.0
2917+
for dim in "xyz":
2918+
vjp_array += self._derivative_field_cmp(
2919+
E_der_map=derivative_info.E_der_map,
2920+
spatial_data=self.conductivity,
2921+
dim=dim,
2922+
freqs=np.atleast_1d(derivative_info.frequencies),
2923+
component="imag",
29082924
)
29092925
vjps[field_path] = vjp_array
29102926

@@ -2913,10 +2929,11 @@ def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradField
29132929
dim = key[-1]
29142930
vjps[field_path] = self._derivative_field_cmp(
29152931
E_der_map=derivative_info.E_der_map,
2916-
eps_data=self.eps_dataset.field_components[key],
2932+
spatial_data=self.eps_dataset.field_components[key],
29172933
dim=dim,
2934+
freqs=np.atleast_1d(derivative_info.frequencies),
2935+
component="complex",
29182936
)
2919-
29202937
else:
29212938
raise NotImplementedError(
29222939
f"No derivative defined for 'CustomMedium' field: {field_path}."
@@ -2927,14 +2944,18 @@ def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradField
29272944
def _derivative_field_cmp(
29282945
self,
29292946
E_der_map: ElectromagneticFieldDataset,
2930-
eps_data: PermittivityDataset,
2947+
spatial_data: CustomSpatialDataTypeAnnotated,
29312948
dim: str,
2949+
freqs: np.ndarray,
2950+
component: str = "real",
29322951
) -> np.ndarray:
2933-
"""Compute derivative with respect to the ``dim`` components within the custom medium."""
2934-
coords_interp = {key: eps_data.coords[key] for key in "xyz"}
2952+
"""Compute the derivative with respect to a material property component."""
2953+
coords_interp = {key: spatial_data.coords[key] for key in "xyz"}
29352954
coords_interp = {key: val for key, val in coords_interp.items() if len(val) > 1}
29362955

2937-
eps_coordinate_shape = [len(eps_data.coords[dim]) for dim in eps_data.dims if dim in "xyz"]
2956+
eps_coordinate_shape = [
2957+
len(spatial_data.coords[dim]) for dim in spatial_data.dims if dim in "xyz"
2958+
]
29382959

29392960
E_der_dim_interp = E_der_map[f"E{dim}"]
29402961

@@ -2972,10 +2993,26 @@ def _derivative_field_cmp(
29722993
# if sizes_list is empty, then reduce() fails
29732994
d_vol = np.array(1.0)
29742995

2975-
# TODO: probably this could be more robust. eg if the DataArray has weird edge cases
2976-
E_der_dim_interp = (
2977-
E_der_dim_interp.interp(**coords_interp, assume_sorted=True).fillna(0.0).real.sum("f")
2978-
)
2996+
E_der_dim_interp_complex = E_der_dim_interp.interp(
2997+
**coords_interp, assume_sorted=True
2998+
).fillna(0.0)
2999+
3000+
if component == "imag":
3001+
# convert from derivative w.r.t. complex permittivity to derivative w.r.t. conductivity
3002+
E_der_dim_interp = E_der_dim_interp_complex.imag
3003+
# frequency-dependent scaling must be applied before summing over frequencies
3004+
for freq in freqs:
3005+
vjp_imag = Medium.eps_sigma_to_eps_complex(
3006+
eps_real=0, sigma=E_der_dim_interp.sel(f=freq), freq=freq
3007+
).imag
3008+
E_der_dim_interp.loc[{"f": freq}] = -vjp_imag
3009+
elif component == "complex":
3010+
# for complex permittivity in eps_dataset, return the full complex derivative
3011+
E_der_dim_interp = E_der_dim_interp_complex
3012+
else:
3013+
E_der_dim_interp = E_der_dim_interp_complex.real
3014+
3015+
E_der_dim_interp = E_der_dim_interp.sum("f")
29793016

29803017
try:
29813018
E_der_dim_interp = E_der_dim_interp * d_vol.reshape(E_der_dim_interp.shape)
@@ -3975,15 +4012,39 @@ def _sel_custom_data_inside(self, bounds: Bound):
39754012

39764013
return self.updated_copy(eps_inf=eps_inf_reduced, poles=poles_reduced)
39774014

4015+
def _derivative_field_cmp(
4016+
self,
4017+
E_der_map: ElectromagneticFieldDataset,
4018+
spatial_data: CustomSpatialDataTypeAnnotated,
4019+
dim: str,
4020+
freqs=None,
4021+
component: str = "complex",
4022+
) -> np.ndarray:
4023+
"""Compatibility wrapper for derivative computation.
4024+
4025+
Accepts the extended signature used by other custom media (
4026+
e.g., `CustomMedium._derivative_field_cmp`) while delegating the actual
4027+
computation to the base implementation that only depends on
4028+
`E_der_map`, `spatial_data`, and `dim`.
4029+
4030+
Parameters `freqs` and `component` are ignored for this model since the
4031+
derivative is taken with respect to the complex permittivity directly.
4032+
"""
4033+
return super()._derivative_field_cmp(
4034+
E_der_map=E_der_map, spatial_data=spatial_data, dim=dim
4035+
)
4036+
39784037
def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap:
39794038
"""Compute adjoint derivatives by preparing array data and calling the static helper."""
39804039

39814040
dJ_deps_complex = 0.0
39824041
for dim in "xyz":
39834042
dJ_deps_complex += self._derivative_field_cmp(
39844043
E_der_map=derivative_info.E_der_map,
3985-
eps_data=self.eps_inf,
4044+
spatial_data=self.eps_inf,
39864045
dim=dim,
4046+
freqs=np.atleast_1d(derivative_info.frequencies),
4047+
component="complex",
39874048
)
39884049

39894050
poles_vals = [

0 commit comments

Comments
 (0)