
Commit 710ebbe

fix[adjoint]: fix tuple handling in autograd gradient calculations

1 parent 34f9a54

3 files changed: +120 -27 lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions

@@ -19,6 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Bug in `LayerRefinementSpec` that refines grids outside the layer region when one in-plane dimension is of size infinity.
 - Querying tasks was sometimes erroring unexpectedly.
 - Fixed automatic creation of missing output directories.
+- Bug in handling of tuple-type gradients that could lead to empty tuples or failing gradient calculations when differentiating w.r.t. (for instance) `td.Box.center`.
 
 ## [2.8.0] - 2025-03-04
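
Why this bug occurs: Python multiplies sequences by repetition, so a zero-gradient fallback of the form `0 * v` turns a tuple-valued field such as `td.Box.center` into an empty tuple rather than an element-wise zero. A minimal standalone sketch of the failure mode and of the element-wise pattern this commit adopts:

    # Python sequence semantics: multiplying a tuple by 0 repeats it zero times
    center = (0.0, 0.0, 0.0)
    assert 0 * center == ()  # empty tuple, not a zero gradient

    # element-wise zeroing preserves length and container type
    zeros = type(center)(0 * x for x in center)
    assert zeros == (0.0, 0.0, 0.0)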

tests/test_components/test_autograd.py

Lines changed: 77 additions & 0 deletions

@@ -813,6 +813,83 @@ def objective(*args):
     assert anp.all(grad != 0.0), "some gradients are 0"
 
 
+class TestTupleGrads:
+    center0 = (0.0, 0.0, 0.0)
+    size0 = (0.5, 1.0, 1.5)
+
+    @staticmethod
+    def make_simulation(center: tuple, size: tuple) -> td.Simulation:
+        wavelength = 1.0
+        freq0 = td.C_0 / wavelength
+
+        src = td.PointDipole(
+            center=(-1.4, 0, 0),
+            source_time=td.GaussianPulse(freq0=freq0, fwidth=freq0 / 10),
+            polarization="Ex",
+        )
+
+        mnt = td.FieldMonitor(
+            size=(0, 0, 1),
+            center=(1.4, 0, 0),
+            freqs=[freq0, freq0 + freq0 / 50],
+            name="fields",
+        )
+
+        scatterer = td.Structure(
+            geometry=td.Box(center=center, size=size),
+            medium=td.Medium(permittivity=3.0),
+        )
+
+        return td.Simulation(
+            size=(3, 3, 3),
+            run_time=2e-13,
+            structures=[scatterer],
+            sources=[src],
+            monitors=[mnt],
+            boundary_spec=td.BoundarySpec.all_sides(td.PML()),
+            grid_spec=td.GridSpec.auto(min_steps_per_wvl=30),
+        )
+
+    @pytest.mark.parametrize("run_async", [False, True])
+    @pytest.mark.parametrize("zero", [False, True])
+    @pytest.mark.parametrize("local_gradient", [False, True])
+    def test_zero_grad_tuple(self, use_emulated_run, run_async, zero, local_gradient, tmp_path):
+        """Checks that tuple gradients don't return empty tuples"""
+
+        def obj(center: tuple, size: tuple) -> float:
+            sim = self.make_simulation(center=center, size=size)
+            if run_async:
+                batch_data = web.run_async(
+                    {"lossy_test_async": sim},
+                    path_dir=tmp_path,
+                    local_gradient=local_gradient,
+                )
+                sim_data = list(batch_data.values())[0]
+            else:
+                sim_data = web.run(
+                    sim,
+                    task_name="lossy_test",
+                    local_gradient=local_gradient,
+                )
+            objval = anp.mean(sim_data["fields"].intensity.data).item()
+            if zero:
+                objval *= 0
+            return objval
+
+        d_power = ag.value_and_grad(obj, argnum=(0, 1))
+        val, (dp_dcenter, dp_dsize) = d_power(self.center0, self.size0)
+
+        assert len(dp_dcenter) == 3
+        assert len(dp_dsize) == 3
+
+        if zero:
+            assert np.allclose(dp_dcenter, 0)
+            assert np.allclose(dp_dsize, 0)
+        else:
+            assert not np.allclose(dp_dcenter, 0)
+            assert not np.allclose(dp_dsize, 0)
+
+
 @pytest.mark.parametrize("structure_key, monitor_key", args)
 def test_autograd_async_some_zero_grad(use_emulated_run, structure_key, monitor_key):
     """Test objective where only some simulations in batch have adjoint sources."""

tidy3d/web/api/autograd/autograd.py

Lines changed: 42 additions & 27 deletions

@@ -648,7 +648,10 @@ def vjp(data_fields_vjp: AutogradFieldMap) -> AutogradFieldMap:
             "simulation's output. If this is unexpected, please review your "
             "setup or contact customer support for assistance."
         )
-        return {k: 0 * v for k, v in sim_fields_original.items()}
+        return {
+            k: (type(v)(0 * x for x in v) if isinstance(v, (list, tuple)) else 0 * v)
+            for k, v in sim_fields_original.items()
+        }
 
     # Run adjoint simulations in batch
     task_names_adj = [f"{task_name}_adjoint_{i}" for i in range(len(sims_adj))]
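
The same guarded expression reappears in the batched code path further down. Factored into a hypothetical helper (`zeros_like_field` is illustrative, not part of this commit), the intent is:

    def zeros_like_field(v):
        # element-wise zero that preserves the container type;
        # plain `0 * v` would collapse a tuple or list to an empty container
        if isinstance(v, (list, tuple)):
            return type(v)(0 * x for x in v)
        return 0 * v
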
@@ -670,17 +673,16 @@ def vjp(data_fields_vjp: AutogradFieldMap) -> AutogradFieldMap:
         )
         td.log.info("Completed local batch adjoint simulations")
 
-        # sum partial derivatives from each adjoint simulation
+        # Process results from local gradient computation
+        vjp_fields_dict = {}
         for task_name_adj, sim_data_adj in batch_data_adj.items():
             td.log.info(f"Processing VJP contribution from {task_name_adj}")
-            vjp_fields = postprocess_adj(
+            vjp_fields_dict[task_name_adj] = postprocess_adj(
                 sim_data_adj=sim_data_adj,
                 sim_data_orig=sim_data_orig,
                 sim_data_fwd=sim_data_fwd,
                 sim_fields_keys=sim_fields_keys,
             )
-            for k, v in vjp_fields.items():
-                vjp_traced_fields[k] = vjp_traced_fields.get(k, 0) + v
     else:
         td.log.info("Starting server-side batch of adjoint simulations ...")
 
@@ -699,15 +701,24 @@ def vjp(data_fields_vjp: AutogradFieldMap) -> AutogradFieldMap:
             tname_adj: sim.updated_copy(simulation_type="autograd_bwd", deep=False)
             for tname_adj, sim in sims_adj_dict.items()
         }
-        vjp_traced_fields_dict = _run_async_tidy3d_bwd(
+        vjp_fields_dict = _run_async_tidy3d_bwd(
             simulations=sims_adj_dict,
             **run_kwargs,
         )
         td.log.info("Completed server-side batch of adjoint simulations.")
 
-    for fields in vjp_traced_fields_dict.values():
-        for k, v in fields.items():
-            vjp_traced_fields[k] = vjp_traced_fields.get(k, 0) + v
+    # Accumulate gradients from all adjoint simulations
+    for task_name_adj, vjp_fields in vjp_fields_dict.items():
+        td.log.info(f"Processing VJP contribution from {task_name_adj}")
+        for k, v in vjp_fields.items():
+            if k in vjp_traced_fields:
+                val = vjp_traced_fields[k]
+                if isinstance(val, (list, tuple)) and isinstance(v, (list, tuple)):
+                    vjp_traced_fields[k] = type(val)(x + y for x, y in zip(val, v))
+                else:
+                    vjp_traced_fields[k] += v
+            else:
+                vjp_traced_fields[k] = v
 
     td.log.debug(f"Computed gradients for {len(vjp_traced_fields)} fields")
     return vjp_traced_fields
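
Accumulation needs the same container awareness as zeroing. The old code summed contributions with `vjp_traced_fields.get(k, 0) + v`, which raises a TypeError for tuple-valued `v` (`0 + tuple` is unsupported), and `tuple + tuple` concatenates rather than adds element-wise. A standalone sketch of the rule introduced above, factored into a hypothetical helper (`accumulate` is not part of this commit):

    def accumulate(total, update):
        # sum gradient contributions element-wise for tuple/list containers;
        # note that (1.0, 2.0) + (3.0, 4.0) == (1.0, 2.0, 3.0, 4.0) (concatenation)
        if isinstance(total, (list, tuple)) and isinstance(update, (list, tuple)):
            return type(total)(x + y for x, y in zip(total, update))
        return total + update
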
@@ -765,7 +776,8 @@ def vjp(data_fields_dict_vjp: dict[str, AutogradFieldMap]) -> dict[str, AutogradFieldMap]:
         if not sims_adj:
             td.log.debug(f"Adjoint simulation for task '{task_name}' contains no sources.")
             sim_fields_vjp_dict[task_name] = {
-                k: 0 * v for k, v in sim_fields_original_dict[task_name].items()
+                k: (type(v)(0 * x for x in v) if isinstance(v, (list, tuple)) else 0 * v)
+                for k, v in sim_fields_original_dict[task_name].items()
             }
             continue
 
@@ -781,6 +793,9 @@ def vjp(data_fields_dict_vjp: dict[str, AutogradFieldMap]) -> dict[str, AutogradFieldMap]:
         )
         return sim_fields_vjp_dict
 
+    # Dictionary to store VJP results from all adjoint simulations
+    vjp_results = {}
+
     if local_gradient:
         # Run all adjoint simulations in a single batch
         path_dir = Path(run_async_kwargs.pop("path_dir"))
@@ -791,28 +806,20 @@ def vjp(data_fields_dict_vjp: dict[str, AutogradFieldMap]) -> dict[str, AutogradFieldMap]:
             all_sims_adj, path_dir=str(path_dir_adj), **run_async_kwargs
         )
 
-        # Process results for each original task
+        # Process results for each adjoint task
        for adj_task_name, sim_data_adj in batch_data_adj.items():
             task_name = task_name_mapping[adj_task_name]
             sim_data_orig = sim_data_orig_dict[task_name]
             sim_data_fwd = sim_data_fwd_dict[task_name]
             sim_fields_keys = sim_fields_keys_dict[task_name]
 
             # Compute VJP contribution
-            sim_fields_vjp = postprocess_adj(
+            vjp_results[adj_task_name] = postprocess_adj(
                 sim_data_adj=sim_data_adj,
                 sim_data_orig=sim_data_orig,
                 sim_data_fwd=sim_data_fwd,
                 sim_fields_keys=sim_fields_keys,
             )
-
-            # Sum contributions for each original task
-            if task_name in sim_fields_vjp_dict:
-                for k, v in sim_fields_vjp.items():
-                    sim_fields_vjp_dict[task_name][k] += v
-            else:
-                sim_fields_vjp_dict[task_name] = sim_fields_vjp
-
     else:
         # Set up parent tasks mapping for all adjoint simulations
         parent_tasks = {}
@@ -830,19 +837,27 @@ def vjp(data_fields_dict_vjp: dict[str, AutogradFieldMap]) -> dict[str, AutogradFieldMap]:
         }
 
         # Run all adjoint simulations in a single batch
-        sim_fields_vjp_dict_adj = _run_async_tidy3d_bwd(
+        vjp_results = _run_async_tidy3d_bwd(
             simulations=all_sims_adj,
             **run_async_kwargs,
         )
 
-    # Combine results for each original task
-    for adj_task_name, fields in sim_fields_vjp_dict_adj.items():
-        task_name = task_name_mapping[adj_task_name]
-        if task_name in sim_fields_vjp_dict:
-            for k, v in fields.items():
-                sim_fields_vjp_dict[task_name][k] += v
-        else:
-            sim_fields_vjp_dict[task_name] = fields
+    # Accumulate gradients from all adjoint simulations
+    for adj_task_name, vjp_fields in vjp_results.items():
+        task_name = task_name_mapping[adj_task_name]
+
+        if task_name not in sim_fields_vjp_dict:
+            sim_fields_vjp_dict[task_name] = {}
+
+        for k, v in vjp_fields.items():
+            if k in sim_fields_vjp_dict[task_name]:
+                val = sim_fields_vjp_dict[task_name][k]
+                if isinstance(val, (list, tuple)) and isinstance(v, (list, tuple)):
+                    sim_fields_vjp_dict[task_name][k] = type(val)(x + y for x, y in zip(val, v))
+                else:
+                    sim_fields_vjp_dict[task_name][k] += v
+            else:
+                sim_fields_vjp_dict[task_name][k] = v
 
     return sim_fields_vjp_dict