Skip to content

Commit cd5d2df

Browse files
Gregory Robertsdaquinteroflex
authored andcommitted
fix(adjoint): fix frequency selection in adjoint postprocessing frequency batching to make sure all data aligns in frequency dimension before passing to geometry or medium for derivative computation
1 parent e87005e commit cd5d2df

File tree

3 files changed

+51
-28
lines changed

3 files changed

+51
-28
lines changed

CHANGELOG.md

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10-
### Added
11-
12-
### Changed
13-
14-
### Fixed
15-
- Giving opposite boundaries different names no longer causes a symmetry validator failure.
16-
- Fixed issue with parameters in `InverseDesignResult` sometimes being outside of the valid parameter range.
17-
- Disallow `EMEFieldMonitor` in EME simulations with `EMELengthSweep`.
18-
19-
2010
## [2.9.0] - 2025-08-04
2111

2212
### Added
@@ -94,6 +84,10 @@ with fewer layers than recommended.
9484
- Giving opposite boundaries different names no longer causes a symmetry validator failure.
9585
- Fixed issue with parameters in `InverseDesignResult` sometimes being outside of the valid parameter range.
9686
- Fixed performance regression for multi-frequency adjoint calculations.
87+
- Disallow `EMEFieldMonitor` in EME simulations with `EMELengthSweep`.
88+
- Fixed bug in adjoint postprocessing frequency batching that was causing gradients to be zero or incorrect. The error was surfacing when selecting a subset of the monitor frequencies in the objective function.
89+
90+
9791

9892
## [2.8.5] - 2025-07-07
9993

tests/test_components/test_autograd_mode_polyslab_numerical.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -359,14 +359,34 @@ def test_finite_difference_mode_data_polyslab(
359359
# due to a multifrequency objective function.
360360
monitor_top_weights = rng.random(NUM_MODE_MONITOR_FREQUENCIES)
361361
monitor_bottom_weights = rng.random(NUM_MODE_MONITOR_FREQUENCIES)
362+
frequency_selection_mask = np.arange(0, NUM_MODE_MONITOR_FREQUENCIES)
363+
364+
# sometimes, test what happens when we only use one of the frequencies from the mode monitors
365+
# to catch handling of different frequencies being present in the forward and adjoint monitors
366+
if rng.random() > 0.5:
367+
frequency_selection_mask = rng.integers(1, NUM_MODE_MONITOR_FREQUENCIES)
368+
monitor_top_weights = monitor_top_weights[frequency_selection_mask]
369+
monitor_bottom_weights = monitor_bottom_weights[frequency_selection_mask]
362370

363371
def eval_fn(sim_data):
364372
return np.sum(
365373
monitor_top_weights
366-
* np.abs(sim_data["monitor_mode_top"].amps.sel(direction="+").values) ** 2
374+
* np.abs(
375+
sim_data["monitor_mode_top"]
376+
.amps.sel(direction="+")
377+
.isel(f=frequency_selection_mask)
378+
.data
379+
)
380+
** 2
367381
) + np.sum(
368382
monitor_bottom_weights
369-
* np.abs(sim_data["monitor_mode_bottom"].amps.sel(direction="+").values) ** 2
383+
* np.abs(
384+
sim_data["monitor_mode_bottom"]
385+
.amps.sel(direction="+")
386+
.isel(f=frequency_selection_mask)
387+
.data
388+
)
389+
** 2
370390
)
371391

372392
polyslab_height_um = POLYSLAB_HEIGHT_WVL * adj_wvl_um

tidy3d/web/api/autograd/autograd.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -989,22 +989,25 @@ def _compute_eps_array(medium, frequencies):
989989
return DataArray(data=np.array(eps_data), dims=("f",), coords={"f": frequencies})
990990

991991

992-
def _slice_field_data(field_data: dict, freq_slice: slice) -> dict:
992+
def _slice_field_data(
993+
field_data: dict,
994+
freqs: np.ndarray,
995+
) -> dict:
993996
"""Slice field data dictionary along frequency dimension.
994997
995998
Parameters
996999
----------
9971000
field_data : dict
9981001
Dictionary of field components.
999-
freq_slice : slice
1000-
Frequency slice to apply.
1002+
freqs : np.ndarray
1003+
Frequencies to select.
10011004
10021005
Returns
10031006
-------
10041007
dict
10051008
Sliced field data dictionary.
10061009
"""
1007-
return {k: v.isel(f=freq_slice) for k, v in field_data.items()}
1010+
return {k: v.sel(f=freqs) for k, v in field_data.items()}
10081011

10091012

10101013
def postprocess_adj(
@@ -1119,26 +1122,32 @@ def postprocess_adj(
11191122
chunk_end = min(chunk_start + freq_chunk_size, n_freqs)
11201123
freq_slice = slice(chunk_start, chunk_end)
11211124

1125+
select_adjoint_freqs = adjoint_frequencies[freq_slice]
1126+
11221127
# slice field data for current chunk
1123-
E_der_map_chunk = _slice_field_data(E_der_map.field_components, freq_slice)
1124-
D_der_map_chunk = _slice_field_data(D_der_map.field_components, freq_slice)
1125-
E_fwd_chunk = _slice_field_data(E_fwd.field_components, freq_slice)
1126-
E_adj_chunk = _slice_field_data(E_adj.field_components, freq_slice)
1127-
D_fwd_chunk = _slice_field_data(D_fwd.field_components, freq_slice)
1128-
D_adj_chunk = _slice_field_data(D_adj.field_components, freq_slice)
1129-
eps_data_chunk = _slice_field_data(eps_fwd.field_components, freq_slice)
1128+
E_der_map_chunk = _slice_field_data(E_der_map.field_components, select_adjoint_freqs)
1129+
D_der_map_chunk = _slice_field_data(D_der_map.field_components, select_adjoint_freqs)
1130+
E_fwd_chunk = _slice_field_data(E_fwd.field_components, select_adjoint_freqs)
1131+
E_adj_chunk = _slice_field_data(E_adj.field_components, select_adjoint_freqs)
1132+
D_fwd_chunk = _slice_field_data(D_fwd.field_components, select_adjoint_freqs)
1133+
D_adj_chunk = _slice_field_data(D_adj.field_components, select_adjoint_freqs)
1134+
eps_data_chunk = _slice_field_data(eps_fwd.field_components, select_adjoint_freqs)
11301135

11311136
# slice epsilon arrays
1132-
eps_in_chunk = eps_in.isel(f=freq_slice)
1133-
eps_out_chunk = eps_out.isel(f=freq_slice)
1137+
eps_in_chunk = eps_in.sel(f=select_adjoint_freqs)
1138+
eps_out_chunk = eps_out.sel(f=select_adjoint_freqs)
11341139
eps_background_chunk = (
1135-
eps_background.isel(f=freq_slice) if eps_background is not None else None
1140+
eps_background.sel(f=select_adjoint_freqs) if eps_background is not None else None
11361141
)
11371142
eps_no_structure_chunk = (
1138-
eps_no_structure.isel(f=freq_slice) if eps_no_structure is not None else None
1143+
eps_no_structure.sel(f=select_adjoint_freqs)
1144+
if eps_no_structure is not None
1145+
else None
11391146
)
11401147
eps_inf_structure_chunk = (
1141-
eps_inf_structure.isel(f=freq_slice) if eps_inf_structure is not None else None
1148+
eps_inf_structure.sel(f=select_adjoint_freqs)
1149+
if eps_inf_structure is not None
1150+
else None
11421151
)
11431152

11441153
# create derivative info with sliced data

0 commit comments

Comments
 (0)