Skip to content

Commit 9a35483

Browse files
committed
feat: Support for adjoint simulations with no sources (zero gradients)
1 parent bf1645c commit 9a35483

File tree

5 files changed

+133
-34
lines changed

5 files changed

+133
-34
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1313
- Differentiable `smooth_min`, `smooth_max`, and `least_squares` functions in `tidy3d.plugins.autograd`.
1414
- Differential operators `grad` and `value_and_grad` in `tidy3d.plugins.autograd` that behave similarly to the autograd operators but support auxiliary data via `aux_data=True` as well as differentiation w.r.t. `DataArray`.
1515
- `@scalar_objective` decorator in `tidy3d.plugins.autograd` that wraps objective functions to ensure they return a scalar value and performs additional checks to ensure compatibility of objective functions with autograd. Used by default in `tidy3d.plugins.autograd.value_and_grad` as well as `tidy3d.plugins.autograd.grad`.
16+
- Autograd support for simulations without adjoint sources in `run` as well as `run_async`, which will not attempt to run the simulation but instead return zero gradients. This can sometimes occur if the objective function gradient does not depend on some simulations, for example when using `min` or `max` in the objective.
1617

1718

1819
### Changed
@@ -25,6 +26,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2526
- `xarray` 2024.10.0 compatibility for autograd.
2627
- Some failing examples in the expressions plugin documentation.
2728
- Inaccuracy in transforming gradients from edge to `PolySlab.vertices`.
29+
- Bug in `run_async` where an adjoint simulation would sometimes be assigned to the wrong forward simulation.
30+
2831

2932
## [2.7.6] - 2024-10-30
3033

tests/test_components/test_autograd.py

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -673,6 +673,29 @@ def task_name_fn(i: int, sign: int) -> str:
673673
print(f"avg(diff(objectives)) = {diff_objectives_num:.4f}")
674674

675675

676+
def test_run_zero_grad(use_emulated_run, log_capture):
677+
"""Test warning if no adjoint sim is run (no adjoint sources).
678+
679+
This checks the case where a simulation is still part of the computational
680+
graph (i.e. the output technically depends on the simulation),
681+
but no adjoint sources are placed because their amplitudes are zero and thus
682+
no adjoint simulation is run.
683+
"""
684+
685+
# only needs to be checked for one monitor
686+
fn_dict = get_functions(args[0][0], args[0][1])
687+
make_sim = fn_dict["sim"]
688+
postprocess = fn_dict["postprocess"]
689+
690+
def objective(*args):
691+
sim = make_sim(*args)
692+
sim_data = run(sim, task_name="adjoint_test", verbose=False)
693+
return 0 * postprocess(sim_data)
694+
695+
with AssertLogLevel(log_capture, "WARNING", contains_str="no sources"):
696+
grad = ag.grad(objective)(params0)
697+
698+
676699
@pytest.mark.parametrize("structure_key, monitor_key", args)
677700
def test_autograd_objective(use_emulated_run, structure_key, monitor_key):
678701
"""Test an objective function through tidy3d autograd."""
@@ -717,8 +740,6 @@ def test_autograd_async(use_emulated_run, structure_key, monitor_key):
717740
task_names = {"1", "2", "3", "4"}
718741

719742
def objective(*args):
720-
"""Objective function."""
721-
722743
sims = {task_name: make_sim(*args) for task_name in task_names}
723744
batch_data = run_async(sims, verbose=False)
724745
value = 0.0
@@ -731,6 +752,51 @@ def objective(*args):
731752
assert anp.all(grad != 0.0), "some gradients are 0"
732753

733754

755+
@pytest.mark.parametrize("structure_key, monitor_key", args)
756+
def test_autograd_async_some_zero_grad(use_emulated_run, log_capture, structure_key, monitor_key):
757+
"""Test objective where only some simulations in batch have adjoint sources."""
758+
759+
fn_dict = get_functions(structure_key, monitor_key)
760+
make_sim = fn_dict["sim"]
761+
postprocess = fn_dict["postprocess"]
762+
763+
task_names = {"1", "2", "3", "4"}
764+
765+
def objective(*args):
766+
sims = {task_name: make_sim(*args) for task_name in task_names}
767+
batch_data = run_async(sims, verbose=False)
768+
values = []
769+
for _, sim_data in batch_data.items():
770+
values.append(postprocess(sim_data))
771+
return min(values)
772+
773+
# with AssertLogLevel(log_capture, "DEBUG", contains_str="no sources"):
774+
val, grad = ag.value_and_grad(objective)(params0)
775+
776+
assert anp.all(grad != 0.0), "some gradients are 0"
777+
778+
779+
def test_autograd_async_all_zero_grad(use_emulated_run, log_capture):
780+
"""Test objective where no simulation in batch has adjoint sources."""
781+
782+
fn_dict = get_functions(args[0][0], args[0][1])
783+
make_sim = fn_dict["sim"]
784+
postprocess = fn_dict["postprocess"]
785+
786+
task_names = {"1", "2", "3", "4"}
787+
788+
def objective(*args):
789+
sims = {task_name: make_sim(*args) for task_name in task_names}
790+
batch_data = run_async(sims, verbose=False)
791+
values = []
792+
for _, sim_data in batch_data.items():
793+
values.append(postprocess(sim_data))
794+
return 0 * sum(values)
795+
796+
with AssertLogLevel(log_capture, "WARNING", contains_str="contains adjoint sources"):
797+
grad = ag.grad(objective)(params0)
798+
799+
734800
def test_autograd_speed_num_structures(use_emulated_run):
735801
"""Test an objective function through tidy3d autograd."""
736802

tidy3d/components/data/sim_data.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1033,8 +1033,10 @@ def split_original_fwd(self, num_mnts_original: int) -> Tuple[SimulationData, Si
10331033
return sim_data_original, sim_data_fwd
10341034

10351035
def make_adjoint_sim(
1036-
self, data_vjp_paths: set[tuple], adjoint_monitors: list[Monitor]
1037-
) -> Simulation:
1036+
self,
1037+
data_vjp_paths: set[tuple],
1038+
adjoint_monitors: list[Monitor],
1039+
) -> Simulation | None:
10381040
"""Make the adjoint simulation from the original simulation and the VJP-containing data."""
10391041

10401042
sim_original = self.simulation
@@ -1045,6 +1047,9 @@ def make_adjoint_sim(
10451047
for src_list in sources_adj_dict.values():
10461048
adj_srcs += list(src_list)
10471049

1050+
if not any(adj_srcs):
1051+
return None
1052+
10481053
adjoint_source_info = self.process_adjoint_sources(adj_srcs=adj_srcs)
10491054

10501055
# grab boundary conditions with flipped Bloch vectors (for adjoint)
@@ -1087,14 +1092,6 @@ def make_adjoint_sources(self, data_vjp_paths: set[tuple]) -> dict[str, SourceTy
10871092
)
10881093
sources_adj_all[mnt_data.monitor.name] = sources_adj
10891094

1090-
if not any(src for _, src in sources_adj_all.items()):
1091-
raise ValueError(
1092-
"No adjoint sources created for this simulation. "
1093-
"This could indicate a bug in your setup, for example the objective function "
1094-
"output depending on a monitor that is not supported. If you encounter this error, "
1095-
"please examine your set up or contact customer support if you need more help."
1096-
)
1097-
10981095
return sources_adj_all
10991096

11001097
@property

tidy3d/components/source.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def frequency_range(self, num_fwidth: float = 4.0) -> FreqBound:
101101
"""Frequency range within plus/minus ``num_fwidth * fwidth`` of the central frequency."""
102102

103103
@abstractmethod
104-
def end_time(self) -> float | None:
104+
def end_time(self) -> Optional[float]:
105105
"""Time after which the source is effectively turned off / close to zero amplitude."""
106106

107107

@@ -192,7 +192,7 @@ def amp_time(self, time: float) -> complex:
192192

193193
return pulse_amp
194194

195-
def end_time(self) -> float | None:
195+
def end_time(self) -> Optional[float]:
196196
"""Time after which the source is effectively turned off / close to zero amplitude."""
197197

198198
# TODO: decide if we should continue to return an end_time if the DC component remains
@@ -251,7 +251,7 @@ def amp_time(self, time: float) -> complex:
251251

252252
return const * offset * oscillation * amp
253253

254-
def end_time(self) -> float | None:
254+
def end_time(self) -> Optional[float]:
255255
"""Time after which the source is effectively turned off / close to zero amplitude."""
256256
return None
257257

@@ -420,7 +420,7 @@ def amp_time(self, time: float) -> complex:
420420

421421
return offset * oscillation * amp * envelope
422422

423-
def end_time(self) -> float | None:
423+
def end_time(self) -> Optional[float]:
424424
"""Time after which the source is effectively turned off / close to zero amplitude."""
425425

426426
if self.source_time_dataset is None:

tidy3d/web/api/autograd/autograd.py

Lines changed: 51 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import tidy3d as td
1414
from tidy3d.components.autograd import AutogradFieldMap, get_static
1515
from tidy3d.components.autograd.derivative_utils import DerivativeInfo
16-
from tidy3d.components.data.sim_data import AdjointSourceInfo
1716

1817
from ...core.s3utils import download_file, upload_file
1918
from ..asynchronous import DEFAULT_DATA_DIR
@@ -289,7 +288,10 @@ def run_async(
289288

290289

291290
def _run(
292-
simulation: td.Simulation, task_name: str, local_gradient: bool = LOCAL_GRADIENT, **run_kwargs
291+
simulation: td.Simulation,
292+
task_name: str,
293+
local_gradient: bool = LOCAL_GRADIENT,
294+
**run_kwargs,
293295
) -> td.SimulationData:
294296
"""User-facing ``web.run`` function, compatible with ``autograd`` differentiation."""
295297

@@ -323,7 +325,9 @@ def _run(
323325

324326

325327
def _run_async(
326-
simulations: dict[str, td.Simulation], local_gradient: bool = LOCAL_GRADIENT, **run_async_kwargs
328+
simulations: dict[str, td.Simulation],
329+
local_gradient: bool = LOCAL_GRADIENT,
330+
**run_async_kwargs,
327331
) -> dict[str, td.SimulationData]:
328332
"""User-facing ``web.run_async`` function, compatible with ``autograd`` differentiation."""
329333

@@ -596,6 +600,15 @@ def vjp(data_fields_vjp: AutogradFieldMap) -> AutogradFieldMap:
596600
sim_fields_keys=sim_fields_keys,
597601
)
598602

603+
if sim_adj is None:
604+
td.log.warning(
605+
f"Adjoint simulation for task '{task_name}' contains no sources. "
606+
"This can occur if the objective function does not depend on the "
607+
"simulation's output. If this is unexpected, please review your "
608+
"setup or contact customer support for assistance."
609+
)
610+
return {k: 0 * v for k, v in sim_fields_original.items()}
611+
599612
# run adjoint simulation
600613
task_name_adj = str(task_name) + "_adjoint"
601614

@@ -656,9 +669,10 @@ def _run_async_bwd(
656669
def vjp(data_fields_dict_vjp: dict[str, AutogradFieldMap]) -> dict[str, AutogradFieldMap]:
657670
"""dJ/d{sim.traced_fields()} as a function of Function of dJ/d{data.traced_fields()}"""
658671

659-
task_names_adj = {task_name + "_adjoint" for task_name in task_names}
672+
task_names_adj = [task_name + "_adjoint" for task_name in task_names]
660673

661674
sims_adj = {}
675+
sim_fields_vjp_dict = {}
662676
for task_name, task_name_adj in zip(task_names, task_names_adj):
663677
data_fields_vjp = data_fields_dict_vjp[task_name]
664678
sim_data_orig = sim_data_orig_dict[task_name]
@@ -669,48 +683,64 @@ def vjp(data_fields_dict_vjp: dict[str, AutogradFieldMap]) -> dict[str, Autograd
669683
sim_data_orig=sim_data_orig,
670684
sim_fields_keys=sim_fields_keys,
671685
)
686+
687+
if sim_adj is None:
688+
td.log.debug(f"Adjoint simulation for task '{task_name}' contains no sources. ")
689+
sim_fields_vjp_dict[task_name] = {
690+
k: 0 * v for k, v in sim_fields_original_dict[task_name].items()
691+
}
672692
sims_adj[task_name_adj] = sim_adj
673-
# TODO: handle case where no adjoint sources?
693+
694+
sims_to_run = {k: v for k, v in sims_adj.items() if v is not None}
695+
696+
if not sims_to_run:
697+
td.log.warning(
698+
"No simulation in batch contains adjoint sources and thus all gradients are zero. "
699+
"This likely indicates an issue with your setup, consider double-checking or contact support."
700+
)
701+
return sim_fields_vjp_dict
702+
703+
task_names_adj = list(sims_to_run.keys())
704+
task_names_fwd = [name.rstrip("_adjoint") for name in task_names_adj]
674705

675706
if local_gradient:
676707
# run adjoint simulation
677-
batch_data_adj, _ = _run_async_tidy3d(sims_adj, **run_async_kwargs)
708+
batch_data_adj, _ = _run_async_tidy3d(sims_to_run, **run_async_kwargs)
678709

679-
sim_fields_vjp_dict = {}
680-
for task_name, task_name_adj in zip(task_names, task_names_adj):
681-
sim_data_adj = batch_data_adj[task_name_adj]
710+
for task_name, task_name_adj in zip(task_names_fwd, task_names_adj):
682711
sim_data_orig = sim_data_orig_dict[task_name]
683712
sim_data_fwd = sim_data_fwd_dict[task_name]
684713
sim_fields_keys = sim_fields_keys_dict[task_name]
685714

715+
sim_data_adj = batch_data_adj.get(task_name_adj)
716+
686717
sim_fields_vjp = postprocess_adj(
687718
sim_data_adj=sim_data_adj,
688719
sim_data_orig=sim_data_orig,
689720
sim_data_fwd=sim_data_fwd,
690721
sim_fields_keys=sim_fields_keys,
691722
)
692-
sim_fields_vjp_dict[task_name] = sim_fields_vjp
693723

724+
sim_fields_vjp_dict[task_name] = sim_fields_vjp
694725
else:
695726
parent_tasks = {}
696-
for task_name_fwd, task_name_adj in zip(task_names, task_names_adj):
727+
for task_name_fwd, task_name_adj in zip(task_names_fwd, task_names_adj):
697728
task_id_fwd = aux_data_dict[task_name_fwd][AUX_KEY_FWD_TASK_ID]
698729
parent_tasks[task_name_adj] = [task_id_fwd]
699730

700731
run_async_kwargs["parent_tasks"] = parent_tasks
701732
run_async_kwargs["simulation_type"] = "autograd_bwd"
702-
sims_adj = {
733+
simulations = {
703734
task_name: sim.updated_copy(simulation_type="autograd_bwd", deep=False)
704-
for task_name, sim in sims_adj.items()
735+
for task_name, sim in sims_to_run.items()
705736
}
706737
sim_fields_vjp_dict_adj_keys = _run_async_tidy3d_bwd(
707-
simulations=sims_adj,
738+
simulations=simulations,
708739
**run_async_kwargs,
709740
)
710741

711742
# swap adjoint task_names for original task_names
712-
sim_fields_vjp_dict = {}
713-
for task_name_fwd, task_name_adj in zip(task_names, task_names_adj):
743+
for task_name_fwd, task_name_adj in zip(task_names_fwd, task_names_adj):
714744
sim_fields_vjp_dict[task_name_fwd] = sim_fields_vjp_dict_adj_keys[task_name_adj]
715745

716746
return sim_fields_vjp_dict
@@ -722,7 +752,7 @@ def setup_adj(
722752
data_fields_vjp: AutogradFieldMap,
723753
sim_data_orig: td.SimulationData,
724754
sim_fields_keys: list[tuple],
725-
) -> tuple[td.Simulation, AdjointSourceInfo]:
755+
) -> typing.Optional[td.Simulation]:
726756
"""Construct an adjoint simulation from a set of data_fields for the VJP."""
727757

728758
td.log.info("Running custom vjp (adjoint) pipeline.")
@@ -742,8 +772,11 @@ def setup_adj(
742772
]
743773

744774
sim_adj = sim_data_vjp.make_adjoint_sim(
745-
data_vjp_paths=data_vjp_paths, adjoint_monitors=adjoint_monitors
775+
data_vjp_paths=data_vjp_paths,
776+
adjoint_monitors=adjoint_monitors,
746777
)
778+
if sim_adj is None:
779+
return sim_adj
747780

748781
if _INSPECT_ADJOINT_FIELDS:
749782
adj_fld_mnt = td.FieldMonitor(

0 commit comments

Comments
 (0)