Skip to content

Commit 76681a0

Browse files
Gregory Robertsgroberts-flex
authored andcommitted
fix[adjoint]: select proper shape for CustomMedium derivatives when eps_data has a frequency dimension with multiple entries
1 parent 7b1cc3a commit 76681a0

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1818

1919
### Fixed
2020
- Fixed bug in broadband adjoint source creation when forward simulation had a pulse amplitude greater than 1 or a nonzero pulse phase.
21+
- Fixed shaping of `CustomMedium` gradients when permittivity data includes a frequency dimension with multiple entries.
2122

2223
### Changed
2324
- Relaxed bounds checking of path integrals during `WavePort` validation.
2425
- Internal adjoint helper methods are now prefixed with an underscore to separate them from the public API.
2526
- Drop the dependency on `gdspy`, which has been unmaintained for over two years. Interfaces previously relying on `gdspy` now use its maintained successor, `gdstk`, with equivalent functionality.
2627
- Small (around 1e-4) numerical precision improvements in EME solver.
2728
- Adjoint source frequency width is adjusted to decay sufficiently before zero frequency when possible to improve accuracy of simulation normalization when using custom current sources.
29+
- Change `VisualizationSpec` validator for checking validity of user specified colors to only issue a warning if matplotlib is not installed instead of an error.
2830

2931
## [2.8.4] - 2025-05-15
3032

tidy3d/components/medium.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1672,6 +1672,8 @@ def _derivative_field_cmp(
16721672
coords_interp = {key: val for key, val in eps_data.coords.items() if len(val) > 1}
16731673
dims_sum = {dim for dim in eps_data.coords.keys() if dim not in coords_interp}
16741674

1675+
eps_coordinate_shape = [len(eps_data.coords[dim]) for dim in eps_data.dims if dim in "xyz"]
1676+
16751677
# compute sizes along each of the interpolation dimensions
16761678
sizes_list = []
16771679
for _, coords in coords_interp.items():
@@ -1703,7 +1705,7 @@ def _derivative_field_cmp(
17031705
E_der_dim.interp(**coords_interp, assume_sorted=True).fillna(0.0).sum(dims_sum).sum("f")
17041706
)
17051707
vjp_array = np.array(E_der_dim_interp.values).astype(complex)
1706-
vjp_array = vjp_array.reshape(eps_data.shape)
1708+
vjp_array = vjp_array.reshape(eps_coordinate_shape)
17071709

17081710
# multiply by volume elements (if possible, being defensive here..)
17091711
try:
@@ -2871,10 +2873,11 @@ def _derivative_field_cmp(
28712873
freqs: NDArray,
28722874
) -> np.ndarray:
28732875
"""Compute derivative with respect to the ``dim`` components within the custom medium."""
2874-
28752876
coords_interp = {key: eps_data.coords[key] for key in "xyz"}
28762877
coords_interp = {key: val for key, val in coords_interp.items() if len(val) > 1}
28772878

2879+
eps_coordinate_shape = [len(eps_data.coords[dim]) for dim in eps_data.dims if dim in "xyz"]
2880+
28782881
E_der_dim_interp = E_der_map[f"E{dim}"].sel(f=freqs)
28792882

28802883
for dim_ in "xyz":
@@ -2928,7 +2931,7 @@ def _derivative_field_cmp(
29282931
"message and some information about your simulation setup and we will investigate. "
29292932
)
29302933
vjp_array = E_der_dim_interp.values
2931-
vjp_array = vjp_array.reshape(eps_data.shape)
2934+
vjp_array = vjp_array.reshape(eps_coordinate_shape)
29322935

29332936
return vjp_array
29342937

0 commit comments

Comments
 (0)