Skip to content

Commit a760d02

Browse files
committed
refc: Make adjoint helpers private
1 parent 34f2a48 commit a760d02

File tree

13 files changed

+108
-108
lines changed

13 files changed

+108
-108
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1313

1414
### Changed
1515
- Relaxed bounds checking of path integrals during `WavePort` validation.
16+
- Internal adjoint helper methods are now prefixed with an underscore to separate them from the public API.
1617

1718
## [2.8.4] - 2025-05-15
1819

tests/test_components/test_autograd.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ def emulated_run_fwd(simulation, task_name, **run_kwargs) -> td.SimulationData:
175175
sim_original = simulation
176176
sim_fields_keys = run_kwargs["sim_fields_keys"]
177177
# add gradient monitors and make combined simulation
178-
sim_combined = sim_original.with_adjoint_monitors(sim_fields_keys)
178+
sim_combined = sim_original._with_adjoint_monitors(sim_fields_keys)
179179
sim_data_combined = run_emulated(sim_combined, task_name=task_name)
180180

181181
# store both original and fwd data aux_data
@@ -1099,12 +1099,12 @@ def objective(*params):
10991099

11001100
sim_full_static = sim_full_traced.to_static()
11011101

1102-
sim_fields = sim_full_traced.strip_traced_fields()
1102+
sim_fields = sim_full_traced._strip_traced_fields()
11031103

11041104
# note: there is one traced structure in SIM_FULL already with 6 fields + 1 = 7
11051105
assert len(sim_fields) == 10
11061106

1107-
sim_traced = sim_full_static.insert_traced_fields(sim_fields)
1107+
sim_traced = sim_full_static._insert_traced_fields(sim_fields)
11081108

11091109
assert sim_traced == sim_full_traced
11101110

@@ -1135,7 +1135,7 @@ def test_sim_fields_io(structure_key, tmp_path):
11351135
s = make_structures(params0)[structure_key]
11361136
s = s.updated_copy(geometry=s.geometry.updated_copy(center=(2, 2, 2), size=(0, 0, 0)))
11371137
sim_full_traced = SIM_FULL.updated_copy(structures=list(SIM_FULL.structures) + [s])
1138-
sim_fields = sim_full_traced.strip_traced_fields()
1138+
sim_fields = sim_full_traced._strip_traced_fields()
11391139

11401140
field_map = FieldMap.from_autograd_field_map(sim_fields)
11411141
field_map_file = join(tmp_path, "test_sim_fields.hdf5.gz")
@@ -1434,7 +1434,7 @@ def J(eps):
14341434

14351435
monkeypatch.setattr(
14361436
td.PoleResidue,
1437-
"derivative_eps_complex_volume",
1437+
"_derivative_eps_complex_volume",
14381438
lambda self, E_der_map, bounds, freqs: dJ_deps,
14391439
)
14401440

@@ -1467,7 +1467,7 @@ def J(eps):
14671467
bounds_intersect=((-1, -1, -1), (1, 1, 1)),
14681468
)
14691469

1470-
grads_computed = pr.compute_derivatives(derivative_info=info)
1470+
grads_computed = pr._compute_derivatives(derivative_info=info)
14711471

14721472
def f(eps_inf, poles):
14731473
eps = td.PoleResidue._eps_model(eps_inf, poles, freq)
@@ -1549,7 +1549,7 @@ def J(eps):
15491549
bounds_intersect=((-1, -1, -1), (1, 1, 1)),
15501550
)
15511551

1552-
grads_computed = pr.compute_derivatives(derivative_info=info)
1552+
grads_computed = pr._compute_derivatives(derivative_info=info)
15531553

15541554
poles_complex = [
15551555
(np.array(a.values, dtype=complex), np.array(c.values, dtype=complex)) for a, c in poles
@@ -2141,7 +2141,7 @@ def objective(params):
21412141
structure_traced = make_structures(params)["medium"]
21422142
sim = SIM_BASE.updated_copy(structures=[structure_traced], monitors=monitors)
21432143
data = run(sim, task_name="adjoint_freq_test")
2144-
assert data.simulation.freqs_adjoint == [FREQ0]
2144+
assert data.simulation._freqs_adjoint == [FREQ0]
21452145
return anp.sum(data["field"].flux.values)
21462146

21472147
return objective

tidy3d/components/base.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -948,7 +948,7 @@ def make_json_compatible(json_string: str) -> str:
948948
json_string = make_json_compatible(json_string)
949949
return json_string
950950

951-
def strip_traced_fields(
951+
def _strip_traced_fields(
952952
self, starting_path: tuple[str] = (), include_untraced_data_arrays: bool = False
953953
) -> AutogradFieldMap:
954954
"""Extract a dictionary mapping paths in the model to the data traced by ``autograd``.
@@ -1004,7 +1004,7 @@ def handle_value(x: Any, path: tuple[str, ...]) -> None:
10041004
# convert the resulting field_mapping to an autograd-traced dictionary
10051005
return dict_ag(field_mapping)
10061006

1007-
def insert_traced_fields(self, field_mapping: AutogradFieldMap) -> Tidy3dBaseModel:
1007+
def _insert_traced_fields(self, field_mapping: AutogradFieldMap) -> Tidy3dBaseModel:
10081008
"""Recursively insert a map of paths to autograd-traced fields into a copy of this obj."""
10091009

10101010
self_dict = self.dict()
@@ -1037,7 +1037,7 @@ def to_static(self) -> Tidy3dBaseModel:
10371037
"""Version of object with all autograd-traced fields removed."""
10381038

10391039
# get dictionary of all traced fields
1040-
field_mapping = self.strip_traced_fields()
1040+
field_mapping = self._strip_traced_fields()
10411041

10421042
# shortcut to just return self if no tracers found, for performance
10431043
if not field_mapping:
@@ -1047,7 +1047,7 @@ def to_static(self) -> Tidy3dBaseModel:
10471047
field_mapping_static = {key: get_static(val) for key, val in field_mapping.items()}
10481048

10491049
# insert the static values into a copy of self
1050-
return self.insert_traced_fields(field_mapping_static)
1050+
return self._insert_traced_fields(field_mapping_static)
10511051

10521052
@classmethod
10531053
def add_type_field(cls) -> None:

tidy3d/components/data/monitor_data.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def _updated(self, update: Dict) -> MonitorData:
166166
data_dict.update(update)
167167
return type(self).parse_obj(data_dict)
168168

169-
def make_adjoint_sources(self, dataset_names: list[str], fwidth: float) -> list[Source]:
169+
def _make_adjoint_sources(self, dataset_names: list[str], fwidth: float) -> list[Source]:
170170
"""Generate adjoint sources for this ``MonitorData`` instance."""
171171

172172
# TODO: if there's data in the MonitorData, but no adjoint source, then
@@ -1235,7 +1235,7 @@ def to_zbf(
12351235
e_flat = e.values.flatten(order="C")
12361236
# Interweave real and imaginary parts
12371237
e_values = np.ravel(np.column_stack((e_flat.real, e_flat.imag)))
1238-
fout.write(struct.pack(f"<{2 * n_x*n_y}d", *e_values))
1238+
fout.write(struct.pack(f"<{2 * n_x * n_y}d", *e_values))
12391239

12401240
return e_x, e_y
12411241

@@ -1333,7 +1333,7 @@ def to_source(
13331333
field_dataset=dataset, source_time=source_time, center=center, size=size, **kwargs
13341334
)
13351335

1336-
def make_adjoint_sources(
1336+
def _make_adjoint_sources(
13371337
self, dataset_names: list[str], fwidth: float
13381338
) -> List[CustomCurrentSource]:
13391339
"""Converts a :class:`.FieldData` to a list of adjoint current or point sources."""
@@ -2068,14 +2068,14 @@ def _check_fields_stored(self, components: list[EMField]):
20682068
"include the mode field profiles in the corresponding 'ModeData'."
20692069
)
20702070

2071-
def make_adjoint_sources(self, dataset_names: list[str], fwidth: float) -> list[ModeSource]:
2071+
def _make_adjoint_sources(self, dataset_names: list[str], fwidth: float) -> list[ModeSource]:
20722072
"""Get all adjoint sources for the ``ModeMonitorData``."""
20732073

20742074
adjoint_sources = []
20752075

20762076
for name in dataset_names:
20772077
if name == "amps":
2078-
adjoint_sources += self.make_adjoint_sources_amps(fwidth=fwidth)
2078+
adjoint_sources += self._make_adjoint_sources_amps(fwidth=fwidth)
20792079
elif not np.all(self.n_complex.values == 0.0):
20802080
log.warning(
20812081
f"Can't create adjoint source for 'ModeData.{type(self)}.{name}'. "
@@ -2086,7 +2086,7 @@ def make_adjoint_sources(self, dataset_names: list[str], fwidth: float) -> list[
20862086

20872087
return adjoint_sources
20882088

2089-
def make_adjoint_sources_amps(self, fwidth: float) -> list[ModeSource]:
2089+
def _make_adjoint_sources_amps(self, fwidth: float) -> list[ModeSource]:
20902090
"""Generate adjoint sources for ``ModeMonitorData.amps``."""
20912091

20922092
coords = self.amps.coords
@@ -2102,12 +2102,12 @@ def make_adjoint_sources_amps(self, fwidth: float) -> list[ModeSource]:
21022102
if self.get_amplitude(amp_single) == 0.0:
21032103
continue
21042104

2105-
adjoint_source = self.adjoint_source_amp(amp=amp_single, fwidth=fwidth)
2105+
adjoint_source = self._adjoint_source_amp(amp=amp_single, fwidth=fwidth)
21062106
adjoint_sources.append(adjoint_source)
21072107

21082108
return adjoint_sources
21092109

2110-
def adjoint_source_amp(self, amp: DataArray, fwidth: float) -> ModeSource:
2110+
def _adjoint_source_amp(self, amp: DataArray, fwidth: float) -> ModeSource:
21112111
"""Generate an adjoint ``ModeSource`` for a single amplitude."""
21122112

21132113
monitor = self.monitor
@@ -2266,7 +2266,7 @@ class FluxData(MonitorData):
22662266
..., title="Flux", description="Flux values in the frequency-domain."
22672267
)
22682268

2269-
def make_adjoint_sources(
2269+
def _make_adjoint_sources(
22702270
self, dataset_names: list[str], fwidth: float
22712271
) -> List[Union[CustomCurrentSource, PointDipole]]:
22722272
"""Converts a :class:`.FieldData` to a list of adjoint current or point sources."""
@@ -2606,7 +2606,7 @@ def radar_cross_section(self) -> DataArray:
26062606

26072607
return self.make_data_array(data=rcs_data)
26082608

2609-
def make_adjoint_sources(
2609+
def _make_adjoint_sources(
26102610
self, dataset_names: list[str], fwidth: float
26112611
) -> List[Union[CustomCurrentSource, PointDipole]]:
26122612
"""Error if server-side field projection is used for autograd"""
@@ -2764,7 +2764,7 @@ def _check_integration_suitability(self):
27642764
"There are not enough sampling points along `theta` or `phi` for accurate integration. "
27652765
f"Currently, {len(self.theta)} samples for `theta` and {len(self.phi)} samples for `phi`. "
27662766
f"Consider using, at the very least, {MIN_ANGULAR_SAMPLES_SPHERE} samples for `theta` and "
2767-
f"{2*MIN_ANGULAR_SAMPLES_SPHERE} samples for `phi`."
2767+
f"{2 * MIN_ANGULAR_SAMPLES_SPHERE} samples for `phi`."
27682768
)
27692769
self._check_coords_sorted(self.theta, "theta")
27702770
self._check_coords_sorted(self.phi, "phi")
@@ -3359,13 +3359,13 @@ def _make_dataset(self, fields: Tuple[np.ndarray, ...], keys: Tuple[str, ...]) -
33593359

33603360
""" Autograd code """
33613361

3362-
def make_adjoint_sources(self, dataset_names: list[str], fwidth: float) -> list[PlaneWave]:
3362+
def _make_adjoint_sources(self, dataset_names: list[str], fwidth: float) -> list[PlaneWave]:
33633363
"""Get all adjoint sources for the ``DiffractionMonitor.amps``."""
33643364

33653365
# NOTE: everything just goes through `.amps`, any post-processing is encoded in E-fields
3366-
return self.make_adjoint_sources_amps(fwidth=fwidth)
3366+
return self._make_adjoint_sources_amps(fwidth=fwidth)
33673367

3368-
def make_adjoint_sources_amps(self, fwidth: float) -> list[PlaneWave]:
3368+
def _make_adjoint_sources_amps(self, fwidth: float) -> list[PlaneWave]:
33693369
"""Make adjoint sources for outputs that depend on DiffractionData.`amps`."""
33703370

33713371
amps = self.amps

tidy3d/components/data/sim_data.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -994,7 +994,7 @@ def source_spectrum_fn(freqs):
994994

995995
return self.copy(update=dict(simulation=simulation, data=data_normalized))
996996

997-
def split_adjoint_data(self: SimulationData, num_mnts_original: int) -> tuple[list, list]:
997+
def _split_adjoint_data(self: SimulationData, num_mnts_original: int) -> tuple[list, list]:
998998
"""Split data list into original, adjoint field, and adjoint permittivity."""
999999

10001000
data_all = list(self.data)
@@ -1008,11 +1008,11 @@ def split_adjoint_data(self: SimulationData, num_mnts_original: int) -> tuple[li
10081008

10091009
return data_original, data_adjoint
10101010

1011-
def split_original_fwd(self, num_mnts_original: int) -> Tuple[SimulationData, SimulationData]:
1011+
def _split_original_fwd(self, num_mnts_original: int) -> Tuple[SimulationData, SimulationData]:
10121012
"""Split this simulation data into original and fwd data from number of original mnts."""
10131013

10141014
# split the data and monitors into the original ones & adjoint gradient ones (for 'fwd')
1015-
data_original, data_fwd = self.split_adjoint_data(num_mnts_original=num_mnts_original)
1015+
data_original, data_fwd = self._split_adjoint_data(num_mnts_original=num_mnts_original)
10161016
monitors_orig, monitors_fwd = split_list(self.simulation.monitors, index=num_mnts_original)
10171017

10181018
# reconstruct the simulation data for the user, using original sim, and data for original mnts
@@ -1033,7 +1033,7 @@ def split_original_fwd(self, num_mnts_original: int) -> Tuple[SimulationData, Si
10331033

10341034
return sim_data_original, sim_data_fwd
10351035

1036-
def make_adjoint_sims(
1036+
def _make_adjoint_sims(
10371037
self,
10381038
data_vjp_paths: set[tuple],
10391039
adjoint_monitors: list[Monitor],
@@ -1046,15 +1046,15 @@ def make_adjoint_sims(
10461046
sim_original = self.simulation
10471047

10481048
# generate the adjoint sources {mnt_name : list[Source]}
1049-
sources_adj_dict = self.make_adjoint_sources(data_vjp_paths=data_vjp_paths)
1049+
sources_adj_dict = self._make_adjoint_sources(data_vjp_paths=data_vjp_paths)
10501050
if not sources_adj_dict:
10511051
return []
10521052

10531053
adj_srcs = []
10541054
for src_list in sources_adj_dict.values():
10551055
adj_srcs += list(src_list)
10561056

1057-
adjoint_source_infos = self.process_adjoint_sources(adj_srcs=adj_srcs)
1057+
adjoint_source_infos = self._process_adjoint_sources(adj_srcs=adj_srcs)
10581058

10591059
if not adjoint_source_infos:
10601060
return []
@@ -1095,7 +1095,7 @@ def make_adjoint_sims(
10951095

10961096
return adj_sims
10971097

1098-
def make_adjoint_sources(self, data_vjp_paths: set[tuple]) -> dict[str, SourceType]:
1098+
def _make_adjoint_sources(self, data_vjp_paths: set[tuple]) -> dict[str, SourceType]:
10991099
"""Generate all of the non-zero sources for the adjoint simulation given the VJP data."""
11001100

11011101
# map of index into 'self.data' to the list of datasets we need adjoint sources for
@@ -1107,8 +1107,8 @@ def make_adjoint_sources(self, data_vjp_paths: set[tuple]) -> dict[str, SourceTy
11071107
sources_adj_all = defaultdict(list)
11081108
for data_index, dataset_names in adj_src_map.items():
11091109
mnt_data = self.data[data_index]
1110-
sources_adj = mnt_data.make_adjoint_sources(
1111-
dataset_names=dataset_names, fwidth=self.fwidth_adj
1110+
sources_adj = mnt_data._make_adjoint_sources(
1111+
dataset_names=dataset_names, fwidth=self._fwidth_adj
11121112
)
11131113
sources_adj_all[mnt_data.monitor.name] = sources_adj
11141114
log.info(
@@ -1118,12 +1118,12 @@ def make_adjoint_sources(self, data_vjp_paths: set[tuple]) -> dict[str, SourceTy
11181118
return sources_adj_all
11191119

11201120
@property
1121-
def fwidth_adj(self) -> float:
1121+
def _fwidth_adj(self) -> float:
11221122
# fwidth of forward pass, try as default for adjoint
11231123
normalize_index_fwd = self.simulation.normalize_index or 0
11241124
return self.simulation.sources[normalize_index_fwd].source_time.fwidth
11251125

1126-
def process_adjoint_sources(self, adj_srcs: list[SourceType]) -> list[AdjointSourceInfo]:
1126+
def _process_adjoint_sources(self, adj_srcs: list[SourceType]) -> list[AdjointSourceInfo]:
11271127
"""Compute list of final sources along with a post run normalization for adj fields."""
11281128
# dictionary mapping hash of sources with same freq dependence to list of time-dependencies
11291129
hashes_to_sources = defaultdict(None)
@@ -1157,7 +1157,7 @@ def process_adjoint_sources(self, adj_srcs: list[SourceType]) -> list[AdjointSou
11571157
for src_hash, src_times in hashes_to_src_times.items():
11581158
base_src = hashes_to_sources[src_hash]
11591159
group = [base_src.updated_copy(source_time=src_time) for src_time in src_times]
1160-
processed_srcs, post_norm = self.process_adjoint_sources_broadband(group)
1160+
processed_srcs, post_norm = self._process_adjoint_sources_broadband(group)
11611161
adjoint_infos.append(
11621162
AdjointSourceInfo(
11631163
sources=processed_srcs, post_norm=post_norm, normalize_sim=True
@@ -1167,7 +1167,7 @@ def process_adjoint_sources(self, adj_srcs: list[SourceType]) -> list[AdjointSou
11671167
log.info(f"Created {len(adjoint_infos)} adjoint source groups.")
11681168
return adjoint_infos
11691169

1170-
def process_adjoint_sources_broadband(
1170+
def _process_adjoint_sources_broadband(
11711171
self, adj_srcs: list[SourceType]
11721172
) -> tuple[list[SourceType], xr.DataArray]:
11731173
"""Process adjoint sources for the case of several sources at the same freq."""
@@ -1211,10 +1211,10 @@ def _make_post_norm_amps(adj_srcs: list[SourceType]) -> xr.DataArray:
12111211
amps_complex = np.array(amps_complex)
12121212
return xr.DataArray(amps_complex, coords=coords)
12131213

1214-
def get_adjoint_data(self, structure_index: int, data_type: str) -> MonitorDataType:
1214+
def _get_adjoint_data(self, structure_index: int, data_type: str) -> MonitorDataType:
12151215
"""Grab the field or permittivity data for a given structure index."""
12161216

1217-
monitor_name = Structure.get_monitor_name(index=structure_index, data_type=data_type)
1217+
monitor_name = Structure._get_monitor_name(index=structure_index, data_type=data_type)
12181218
return self[monitor_name]
12191219

12201220
def to_mat_file(self, fname: str, **kwargs):

0 commit comments

Comments
 (0)