Skip to content

Commit b1aaf05

Browse files
Gregory Roberts authored and yaugenst-flex committed
feature[adjoint]: adjust adjoint source fwidth to decay before zero frequency when possible
1 parent 5b7c11c commit b1aaf05

File tree

3 files changed

+206
-5
lines changed

3 files changed

+206
-5
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2323
- Internal adjoint helper methods are now prefixed with an underscore to separate them from the public API.
2424
- Drop the dependency on `gdspy`, which has been unmaintained for over two years. Interfaces previously relying on `gdspy` now use its maintained successor, `gdstk`, with equivalent functionality.
2525
- Small (around 1e-4) numerical precision improvements in EME solver.
26+
- The adjoint source frequency width is now adjusted, when possible, so that the pulse decays sufficiently before zero frequency. This improves the accuracy of simulation normalization when using custom current sources.
2627

2728
## [2.8.4] - 2025-05-15
2829

tests/test_components/test_autograd.py

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1225,6 +1225,147 @@ def objective(args):
12251225
ag.grad(objective)(params0)
12261226

12271227

1228+
def test_adjoint_src_width():
    """Test the adjoint source width for single sources decays by f=0.

    Builds point dipoles that share one center frequency but use increasingly
    wide pulses, passes them through ``_adjoint_src_width_single`` and checks
    that the center frequency is preserved while the returned ``fwidth`` leaves
    at least ``NUM_ADJOINT_FWIDTH_TO_ZERO`` widths between ``freq0`` and 0 Hz.
    """

    f0 = td.C_0 / 1.55

    # sweep from a narrow pulse up to one as wide as the center frequency itself
    fwidths = f0 * np.linspace(0.1, 1.0, 5)

    adj_srcs = [
        td.PointDipole(
            center=(0, 0, 0),
            source_time=td.GaussianPulse(freq0=f0, fwidth=fwidth),
            polarization="Ex",
        )
        for fwidth in fwidths
    ]

    adj_srcs_fwidth = td.SimulationData._adjoint_src_width_single(adj_srcs)

    for src in adj_srcs_fwidth:
        assert np.isclose(
            (src.source_time.freq0 - f0) / f0, 0.0
        ), "f0 of adjoint source should be centered on original f0"

        # normalized margin between 0 Hz and the frequency that sits
        # NUM_ADJOINT_FWIDTH_TO_ZERO widths below the center frequency
        check_fwidth = (
            src.source_time.freq0
            - td.components.data.sim_data.NUM_ADJOINT_FWIDTH_TO_ZERO * src.source_time.fwidth
        ) / src.source_time.freq0

        # a margin of (numerically) zero is acceptable: the width is exactly at the limit
        assert np.isclose(check_fwidth, 0.0) or (
            check_fwidth > 0.0
        ), "fwidth of adjoint source should decay sufficiently before f=0"
1260+
1261+
1262+
def test_broadband_adjoint_src_width():
    """Check the center frequency and width chosen for a broadband adjoint source.

    Three scenarios are exercised against ``_adjoint_src_width_broadband``:
    custom current sources whose required width overlaps 0 Hz (a warning is
    expected), dipoles spread over a wide band, and dipoles packed in a narrow band.
    """

    sim_data_module = td.components.data.sim_data

    def band_center(freqs):
        # midpoint between the lowest and highest adjoint frequency
        return 0.5 * (np.max(freqs) + np.min(freqs))

    def dipoles_at(freqs, width):
        # one Ex point dipole per center frequency, all sharing the same pulse width
        return [
            td.PointDipole(
                center=(0, 0, 0),
                source_time=td.GaussianPulse(freq0=freq, fwidth=width),
                polarization="Ex",
            )
            for freq in freqs
        ]

    def check_broadband(f0_expected, fwidth_expected, f0_actual, fwidth_actual):
        # relative comparison of expected and returned pulse parameters
        assert np.isclose(
            (f0_expected - f0_actual) / f0_expected, 0.0
        ), "Expected freq0 not matching for broadband source"
        assert np.isclose(
            (fwidth_expected - fwidth_actual) / fwidth_expected, 0.0
        ), "Expected fwidth not matching for broadband source"

    top_freq = td.C_0 / 1.55
    bottom_freq = 0.1 * top_freq

    # Scenario 1: custom current sources whose covering width reaches zero frequency.
    # The user should be warned about the adjoint accuracy of this setup.
    sparse_freqs = [bottom_freq, top_freq]
    pulse_width = 0.1 * top_freq

    x_pts = np.array([0.0])
    y_pts = np.array([0.0])
    z_pts = np.array([0.0])

    current_srcs = []
    for freq in sparse_freqs:
        field_coords = dict(x=x_pts, y=y_pts, z=z_pts, f=np.array([freq]))
        field_data = td.FieldDataset(
            Ex=td.ScalarFieldDataArray(np.ones((1, 1, 1, 1)), coords=field_coords)
        )
        current_srcs.append(
            td.CustomCurrentSource(
                center=(0, 0, 0),
                size=(0, 0, 0),
                source_time=td.GaussianPulse(freq0=freq, fwidth=pulse_width),
                current_dataset=field_data,
            )
        )

    EXPECTED_WARNING_MSG_PIECE = (
        "Adjoint source generated with a frequency spectrum that extends to or overlaps with 0 Hz"
    )
    with AssertLogLevel("WARNING", contains_str=EXPECTED_WARNING_MSG_PIECE):
        got_f0, got_fwidth = td.SimulationData._adjoint_src_width_broadband(current_srcs)

    want_f0 = band_center(sparse_freqs)
    want_fwidth = (want_f0 - np.min(sparse_freqs)) / sim_data_module.NUM_ADJOINT_FWIDTH_TO_FMIN
    check_broadband(want_f0, want_fwidth, got_f0, got_fwidth)

    # Scenario 2: covering every adjoint frequency needs a wider pulse than would be
    # picked for any individual adjoint source.
    wide_band = np.linspace(bottom_freq, top_freq, 10)
    wide_band_srcs = dipoles_at(wide_band, 0.1 * np.mean(wide_band))

    got_f0, got_fwidth = td.SimulationData._adjoint_src_width_broadband(wide_band_srcs)

    want_f0 = band_center(wide_band)
    want_fwidth = (want_f0 - np.min(wide_band)) / sim_data_module.NUM_ADJOINT_FWIDTH_TO_FMIN
    check_broadband(want_f0, want_fwidth, got_f0, got_fwidth)

    # Scenario 3: the adjoint frequencies are tightly clustered, so a pulse wider than
    # what coverage requires can be chosen; a wider spectrum gives a shorter time pulse.
    narrow_band = np.linspace(0.95 * top_freq, 1.05 * top_freq, 10)
    narrow_band_srcs = dipoles_at(narrow_band, 0.1 * np.mean(narrow_band))

    got_f0, got_fwidth = td.SimulationData._adjoint_src_width_broadband(narrow_band_srcs)

    want_f0 = band_center(narrow_band)
    want_fwidth = want_f0 / sim_data_module.NUM_ADJOINT_FWIDTH_TO_ZERO
    check_broadband(want_f0, want_fwidth, got_f0, got_fwidth)
1367+
1368+
12281369
@pytest.mark.parametrize("colocate", [True, False])
12291370
@pytest.mark.parametrize("objtype", ["flux", "intensity"])
12301371
def test_interp_objectives(use_emulated_run, colocate, objtype):

tidy3d/components/data/sim_data.py

Lines changed: 64 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from ..file_util import replace_values
2323
from ..monitor import Monitor
2424
from ..simulation import Simulation
25+
from ..source.current import CustomCurrentSource
2526
from ..source.time import GaussianPulse
2627
from ..source.utils import SourceType
2728
from ..structure import Structure
@@ -43,6 +44,12 @@
4344
# residuals below this are considered good fits for broadband adjoint source creation
4445
RESIDUAL_CUTOFF_ADJOINT = 1e-6
4546

47+
# for adjoint source, the minimum number of FWIDTH between the center frequency and zero
48+
NUM_ADJOINT_FWIDTH_TO_ZERO = 3
49+
# for broadband adjoint source, the minimum number of FWIDTH to reach the lowest frequency
50+
# that is covered by the broadband pulse
51+
NUM_ADJOINT_FWIDTH_TO_FMIN = 0.5
52+
4653

4754
class AdjointSourceInfo(Tidy3dBaseModel):
4855
"""Stores information about the adjoint sources to pass to autograd pipeline."""
@@ -1123,31 +1130,49 @@ def _fwidth_adj(self) -> float:
11231130
normalize_index_fwd = self.simulation.normalize_index or 0
11241131
return self.simulation.sources[normalize_index_fwd].source_time.fwidth
11251132

1133+
@staticmethod
def _adjoint_src_width_single(adj_srcs: list[SourceType]) -> list[SourceType]:
    """Return copies of the adjoint sources whose pulses decay before zero frequency.

    Every source keeps its ``freq0``. Its ``fwidth`` is capped at
    ``freq0 / NUM_ADJOINT_FWIDTH_TO_ZERO``; a width already below that cap keeps
    its value. The input sources are not modified, copies are returned instead.
    """

    def _with_capped_fwidth(src):
        pulse = src.source_time
        # narrower of the source's own width and the widest one that still decays by 0 Hz
        capped = np.minimum(pulse.freq0 / NUM_ADJOINT_FWIDTH_TO_ZERO, pulse.fwidth)
        return src.updated_copy(source_time=pulse.updated_copy(fwidth=capped))

    return [_with_capped_fwidth(src) for src in adj_srcs]
1148+
11261149
def _process_adjoint_sources(self, adj_srcs: list[SourceType]) -> list[AdjointSourceInfo]:
11271150
"""Compute list of final sources along with a post run normalization for adj fields."""
11281151
# dictionary mapping hash of sources with same freq dependence to list of time-dependencies
11291152
hashes_to_sources = defaultdict(None)
11301153
hashes_to_src_times = defaultdict(list)
11311154

1155+
adj_srcs_process_fwidth = self._adjoint_src_width_single(adj_srcs)
1156+
11321157
tmp_src_time = GaussianPulse(freq0=C_0, fwidth=inf)
1133-
for src in adj_srcs:
1158+
for src in adj_srcs_process_fwidth:
11341159
tmp_src = src.updated_copy(source_time=tmp_src_time)
11351160
tmp_src_hash = tmp_src._hash_self()
11361161
hashes_to_sources[tmp_src_hash] = src
11371162
hashes_to_src_times[tmp_src_hash].append(src.source_time)
11381163

11391164
# Group sources by frequency or port, whichever gives fewer groups
11401165
num_ports = len(hashes_to_src_times)
1141-
num_unique_freqs = len({src.source_time.freq0 for src in adj_srcs})
1166+
num_unique_freqs = len({src.source_time.freq0 for src in adj_srcs_process_fwidth})
11421167

11431168
log.info(f"Found {num_ports} spatial ports and {num_unique_freqs} unique frequencies.")
11441169

11451170
adjoint_infos = []
11461171
if num_unique_freqs <= num_ports:
11471172
log.info("Grouping adjoint sources by frequency.")
1148-
unique_freqs = {src.source_time.freq0 for src in adj_srcs}
1173+
unique_freqs = {src.source_time.freq0 for src in adj_srcs_process_fwidth}
11491174
for freq0 in unique_freqs:
1150-
group = [src for src in adj_srcs if src.source_time.freq0 == freq0]
1175+
group = [src for src in adj_srcs_process_fwidth if src.source_time.freq0 == freq0]
11511176
post_norm = xr.DataArray(data=np.array([1 + 0j]), coords={"f": [freq0]})
11521177
adjoint_infos.append(
11531178
AdjointSourceInfo(sources=group, post_norm=post_norm, normalize_sim=True)
@@ -1184,14 +1209,48 @@ def _process_adjoint_sources_broadband(
11841209

11851210
return [src_broadband], post_norm_amps
11861211

1212+
@staticmethod
def _adjoint_src_width_broadband(adj_srcs: list[SourceType]) -> tuple[float, float]:
    """Find the adjoint source freq0 and fwidth that sufficiently cover all adjoint frequencies.

    Parameters
    ----------
    adj_srcs : list[SourceType]
        Adjoint sources to be represented by one broadband pulse. Only the ``freq0`` of
        each ``source_time`` and the type of the first source are inspected.

    Returns
    -------
    tuple[float, float]
        ``(freq0, fwidth)`` of the broadband pulse: the midpoint between the lowest and
        highest adjoint center frequencies, and the larger of the two candidate widths
        computed below.
    """

    adj_srcs_f0 = [adj_src.source_time.freq0 for adj_src in adj_srcs]
    min_f0 = np.min(adj_srcs_f0)
    middle_f0 = 0.5 * (np.max(adj_srcs_f0) + min_f0)

    # width of source to sufficiently decay by zero frequency
    decay_by_f0_fwidth = middle_f0 / NUM_ADJOINT_FWIDTH_TO_ZERO
    # width of source to sufficiently cover all adjoint frequencies
    fwidth_to_min_f0 = (middle_f0 - min_f0) / NUM_ADJOINT_FWIDTH_TO_FMIN

    # log warning if the adjoint pulse width is not sufficiently decayed by zero frequency
    # which may cause some issues in the adjoint accuracy when using field sources
    if (fwidth_to_min_f0 > decay_by_f0_fwidth) and isinstance(adj_srcs[0], CustomCurrentSource):
        log.warning(
            "Adjoint source generated with a frequency spectrum that extends to or overlaps with 0 Hz. "
            "This can introduce errors into the gradient computation."
        )

    # Choose a wider pulse width in frequency especially when the min/max frequencies
    # for the broadband pulse might be very close together
    adj_src_fwidth = np.maximum(decay_by_f0_fwidth, fwidth_to_min_f0)

    return middle_f0, adj_src_fwidth
1240+
11871241
def _make_broadband_source(self, adj_srcs: list[SourceType]) -> SourceType:
    """Make a broadband source for a set of adjoint sources.

    The result is a copy of the first adjoint source. Its time dependence is taken
    from the simulation's normalization source (index ``normalize_index``, or 0 when
    that is unset) with amplitude 1 and phase 0, and then re-centered and re-sized
    using the ``freq0`` and ``fwidth`` returned by ``_adjoint_src_width_broadband``.
    """

    adj_src_f0, adj_src_fwidth = self._adjoint_src_width_broadband(adj_srcs)

    source_index = self.simulation.normalize_index or 0
    src_time_base = self.simulation.sources[source_index].source_time.updated_copy(
        amplitude=1.0, phase=0.0
    )

    # only the frequency-adjusted source is built; no intermediate copy with the
    # unadjusted time dependence is created
    broadband_time = src_time_base.updated_copy(freq0=adj_src_f0, fwidth=adj_src_fwidth)
    return adj_srcs[0].updated_copy(source_time=broadband_time)
11971256

0 commit comments

Comments
 (0)