|
22 | 22 | from ..file_util import replace_values
|
23 | 23 | from ..monitor import Monitor
|
24 | 24 | from ..simulation import Simulation
|
| 25 | +from ..source.current import CustomCurrentSource |
25 | 26 | from ..source.time import GaussianPulse
|
26 | 27 | from ..source.utils import SourceType
|
27 | 28 | from ..structure import Structure
|
|
43 | 44 | # residuals below this are considered good fits for broadband adjoint source creation
|
44 | 45 | RESIDUAL_CUTOFF_ADJOINT = 1e-6
|
45 | 46 |
|
| 47 | +# for adjoint source, the minimum number of FWIDTH between the center frequency and zero |
| 48 | +NUM_ADJOINT_FWIDTH_TO_ZERO = 3 |
| 49 | +# for broadband adjoint source, the minimum number of FWIDTH to reach the lowest frequency |
| 50 | +# that is covered by the broadband pulse |
| 51 | +NUM_ADJOINT_FWIDTH_TO_FMIN = 0.5 |
| 52 | + |
46 | 53 |
|
47 | 54 | class AdjointSourceInfo(Tidy3dBaseModel):
|
48 | 55 | """Stores information about the adjoint sources to pass to autograd pipeline."""
|
@@ -1123,31 +1130,49 @@ def _fwidth_adj(self) -> float:
|
1123 | 1130 | normalize_index_fwd = self.simulation.normalize_index or 0
|
1124 | 1131 | return self.simulation.sources[normalize_index_fwd].source_time.fwidth
|
1125 | 1132 |
|
| 1133 | + @staticmethod |
| 1134 | + def _adjoint_src_width_single(adj_srcs: list[SourceType]) -> list[SourceType]: |
| 1135 | + """Ensure the adjoint source sufficiently decays before zero frequency.""" |
| 1136 | + adj_srcs_process_fwidth = [] |
| 1137 | + for adj_src in adj_srcs: |
| 1138 | + source_time = adj_src.source_time |
| 1139 | + freq0 = source_time.freq0 |
| 1140 | + |
| 1141 | + fwidth = np.minimum(freq0 / NUM_ADJOINT_FWIDTH_TO_ZERO, source_time.fwidth) |
| 1142 | + |
| 1143 | + adj_srcs_process_fwidth.append( |
| 1144 | + adj_src.updated_copy(source_time=source_time.updated_copy(fwidth=fwidth)) |
| 1145 | + ) |
| 1146 | + |
| 1147 | + return adj_srcs_process_fwidth |
| 1148 | + |
1126 | 1149 | def _process_adjoint_sources(self, adj_srcs: list[SourceType]) -> list[AdjointSourceInfo]:
|
1127 | 1150 | """Compute list of final sources along with a post run normalization for adj fields."""
|
1128 | 1151 | # dictionary mapping hash of sources with same freq dependence to list of time-dependencies
|
1129 | 1152 | hashes_to_sources = defaultdict(None)
|
1130 | 1153 | hashes_to_src_times = defaultdict(list)
|
1131 | 1154 |
|
| 1155 | + adj_srcs_process_fwidth = self._adjoint_src_width_single(adj_srcs) |
| 1156 | + |
1132 | 1157 | tmp_src_time = GaussianPulse(freq0=C_0, fwidth=inf)
|
1133 |
| - for src in adj_srcs: |
| 1158 | + for src in adj_srcs_process_fwidth: |
1134 | 1159 | tmp_src = src.updated_copy(source_time=tmp_src_time)
|
1135 | 1160 | tmp_src_hash = tmp_src._hash_self()
|
1136 | 1161 | hashes_to_sources[tmp_src_hash] = src
|
1137 | 1162 | hashes_to_src_times[tmp_src_hash].append(src.source_time)
|
1138 | 1163 |
|
1139 | 1164 | # Group sources by frequency or port, whichever gives fewer groups
|
1140 | 1165 | num_ports = len(hashes_to_src_times)
|
1141 |
| - num_unique_freqs = len({src.source_time.freq0 for src in adj_srcs}) |
| 1166 | + num_unique_freqs = len({src.source_time.freq0 for src in adj_srcs_process_fwidth}) |
1142 | 1167 |
|
1143 | 1168 | log.info(f"Found {num_ports} spatial ports and {num_unique_freqs} unique frequencies.")
|
1144 | 1169 |
|
1145 | 1170 | adjoint_infos = []
|
1146 | 1171 | if num_unique_freqs <= num_ports:
|
1147 | 1172 | log.info("Grouping adjoint sources by frequency.")
|
1148 |
| - unique_freqs = {src.source_time.freq0 for src in adj_srcs} |
| 1173 | + unique_freqs = {src.source_time.freq0 for src in adj_srcs_process_fwidth} |
1149 | 1174 | for freq0 in unique_freqs:
|
1150 |
| - group = [src for src in adj_srcs if src.source_time.freq0 == freq0] |
| 1175 | + group = [src for src in adj_srcs_process_fwidth if src.source_time.freq0 == freq0] |
1151 | 1176 | post_norm = xr.DataArray(data=np.array([1 + 0j]), coords={"f": [freq0]})
|
1152 | 1177 | adjoint_infos.append(
|
1153 | 1178 | AdjointSourceInfo(sources=group, post_norm=post_norm, normalize_sim=True)
|
@@ -1184,14 +1209,48 @@ def _process_adjoint_sources_broadband(
|
1184 | 1209 |
|
1185 | 1210 | return [src_broadband], post_norm_amps
|
1186 | 1211 |
|
| 1212 | + @staticmethod |
| 1213 | + def _adjoint_src_width_broadband(adj_srcs: list[SourceType]) -> float: |
| 1214 | + """Find the adjoint source fwidth that sufficiently covers all adjoint frequencies.""" |
| 1215 | + |
| 1216 | + adj_srcs_f0 = [adj_src.source_time.freq0 for adj_src in adj_srcs] |
| 1217 | + middle_f0 = 0.5 * (np.max(adj_srcs_f0) + np.min(adj_srcs_f0)) |
| 1218 | + min_f0 = np.min(adj_srcs_f0) |
| 1219 | + |
| 1220 | + # width of source to sufficiently decay by zero frequency |
| 1221 | + decay_by_f0_fwidth = middle_f0 / NUM_ADJOINT_FWIDTH_TO_ZERO |
| 1222 | + # width of source to sufficiently cover all adjoint frequencies |
| 1223 | + fwidth_to_min_f0 = (middle_f0 - min_f0) / NUM_ADJOINT_FWIDTH_TO_FMIN |
| 1224 | + |
| 1225 | + # log warning if the adjoint pulse width is not sufficiently decayed by zero frequency |
| 1226 | + # which may cause some issues in the adjoint accuracy when using field sources |
| 1227 | + if (fwidth_to_min_f0 > decay_by_f0_fwidth) and isinstance(adj_srcs[0], CustomCurrentSource): |
| 1228 | + log.warning( |
| 1229 | + "Adjoint source generated with a frequency spectrum that extends to or overlaps with 0 Hz. " |
| 1230 | + "This can introduce errors into the gradient computation." |
| 1231 | + ) |
| 1232 | + |
| 1233 | + print(f"source widths: {decay_by_f0_fwidth}, {fwidth_to_min_f0}") |
| 1234 | + |
| 1235 | + # Choose a wider pulse width in frequency especially when the min/max frequencies |
| 1236 | + # for the broadband pulse might be very close together |
| 1237 | + adj_src_fwidth = np.maximum(decay_by_f0_fwidth, fwidth_to_min_f0) |
| 1238 | + |
| 1239 | + return middle_f0, adj_src_fwidth |
| 1240 | + |
1187 | 1241 | def _make_broadband_source(self, adj_srcs: list[SourceType]) -> SourceType:
|
1188 | 1242 | """Make a broadband source for a set of adjoint sources."""
|
1189 | 1243 |
|
| 1244 | + adj_src_f0, adj_src_fwidth = self._adjoint_src_width_broadband(adj_srcs) |
| 1245 | + |
1190 | 1246 | source_index = self.simulation.normalize_index or 0
|
| 1247 | + |
1191 | 1248 | src_time_base = self.simulation.sources[source_index].source_time.updated_copy(
|
1192 | 1249 | amplitude=1.0, phase=0.0
|
1193 | 1250 | )
|
1194 |
| - src_broadband = adj_srcs[0].updated_copy(source_time=src_time_base) |
| 1251 | + src_broadband = adj_srcs[0].updated_copy( |
| 1252 | + source_time=src_time_base.updated_copy(freq0=adj_src_f0, fwidth=adj_src_fwidth) |
| 1253 | + ) |
1195 | 1254 |
|
1196 | 1255 | return src_broadband
|
1197 | 1256 |
|
|
0 commit comments