Skip to content

Commit 994bdc2

Browse files
committed
refc[adjoint]: Refactor PoleResidue derivative calculation
The derivative logic for both `PoleResidue` and the spatially-varying `CustomPoleResidue` has been refactored to use an analytical formula. This is achieved by: - Replacing the `autograd`-based implementations in both classes with the analytical formulas for the pole-residue model derivatives. - Creating a single, shared `staticmethod` (`_get_vjps_from_params`) in the base `PoleResidue` class to house this logic. - Updating the `_compute_derivatives` methods in both classes to be simple wrappers that call this new helper.
1 parent de0c8d5 commit 994bdc2

File tree

3 files changed

+60
-103
lines changed

3 files changed

+60
-103
lines changed

CHANGELOG.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1111
- Add support for `np.unwrap` in `tidy3d.plugins.autograd`.
1212
- Add Nunley variant to germanium material library based on Nunley et al. 2016 data.
1313

14+
### Changed
15+
- Switched to an analytical gradient calculation for spatially-varying pole-residue models (`CustomPoleResidue`).
16+
1417
### Fixed
1518
- Arrow lengths are now scaled consistently in the X and Y directions, and their lengths no longer exceed the height of the plot window.
1619
- Bug in `PlaneWave` defined with a negative `angle_theta` which would lead to wrong injection.
@@ -101,8 +104,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
101104

102105
### Fixed
103106
- Fixed `reverse` property of `td.Scene.plot_structures_property()` to also reverse the colorbar.
104-
105-
### Fixed
106107
- Fixed bug in surface gradient computation where fields, instead of gradients, were being summed in frequency.
107108

108109
## [2.8.2] - 2025-04-09

tests/test_components/test_autograd.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1650,18 +1650,18 @@ def test_custom_pole_residue(monkeypatch):
16501650
custom_med_pole_res = td.CustomPoleResidue(eps_inf=eps_inf, poles=poles)
16511651

16521652
def J(eps):
1653-
return anp.sum(abs(eps))
1653+
return anp.sum(anp.abs(eps))
16541654

16551655
freq = 3e8
16561656
pr = td.CustomPoleResidue(eps_inf=eps_inf, poles=poles)
16571657
eps0 = pr.eps_model(freq)
16581658

1659-
dJ_deps = ag.holomorphic_grad(J)(eps0)
1659+
dJ_deps = np.conj(ag.holomorphic_grad(J)(eps0))
16601660

16611661
monkeypatch.setattr(
16621662
td.CustomPoleResidue,
16631663
"_derivative_field_cmp",
1664-
lambda self, E_der_map, eps_data, dim, freqs: dJ_deps,
1664+
lambda self, E_der_map, eps_data, dim, freqs: dJ_deps / 3.0,
16651665
)
16661666

16671667
import importlib
@@ -1703,17 +1703,18 @@ def f(eps_inf, poles):
17031703
eps = td.CustomPoleResidue._eps_model(eps_inf, poles, freq)
17041704
return J(eps)
17051705

1706-
gfn = ag.holomorphic_grad(f, argnum=(0, 1))
1707-
with warnings.catch_warnings():
1708-
warnings.simplefilter("ignore")
1709-
grad_eps_inf, grad_poles = gfn(eps_inf.values, poles_complex)
1706+
gfn = ag.grad(lambda x: f(x, poles_complex))
1707+
grad_eps_inf = gfn(eps_inf.values)
17101708

17111709
assert np.allclose(grads_computed[("eps_inf",)], grad_eps_inf)
17121710

1711+
gfn = ag.holomorphic_grad(lambda x: f(eps_inf.values, x))
1712+
grad_poles = gfn(poles_complex)
1713+
17131714
for i in range(len(poles)):
17141715
for j in range(2):
17151716
field_path = ("poles", i, j)
1716-
assert np.allclose(grads_computed[field_path], grad_poles[i][j])
1717+
assert np.allclose(grads_computed[field_path], np.conj(grad_poles[i][j]))
17171718

17181719

17191720
# @pytest.mark.timeout(18.0)

tidy3d/components/medium.py

Lines changed: 48 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,10 @@
33
from __future__ import annotations
44

55
import functools
6-
import warnings
76
from abc import ABC, abstractmethod
87
from math import isclose
98
from typing import Callable, Optional, Union
109

11-
import autograd as ag
1210
import autograd.numpy as np
1311

1412
# TODO: it's hard to figure out which functions need this, for now all get it
@@ -3425,50 +3423,52 @@ def loss_upper_bound(self) -> float:
34253423
ep = ep[~np.isnan(ep)]
34263424
return max(ep.imag)
34273425

3426+
@staticmethod
3427+
def _get_vjps_from_params(
3428+
dJ_deps_complex: Union[complex, np.ndarray],
3429+
poles_vals: list[tuple[Union[complex, np.ndarray], Union[complex, np.ndarray]]],
3430+
omega: float,
3431+
requested_paths: list[tuple],
3432+
) -> AutogradFieldMap:
3433+
"""
3434+
Static helper to compute VJPs from parameters using the analytical chain rule.
3435+
"""
3436+
jw = 1j * omega
3437+
vjps = {}
3438+
3439+
if ("eps_inf",) in requested_paths:
3440+
vjps[("eps_inf",)] = np.real(dJ_deps_complex)
3441+
3442+
for i, (a_val, c_val) in enumerate(poles_vals):
3443+
if any(path[1] == i for path in requested_paths if path[0] == "poles"):
3444+
if ("poles", i, 0) in requested_paths:
3445+
deps_da = c_val / (jw + a_val) ** 2
3446+
dJ_da = dJ_deps_complex * deps_da
3447+
vjps[("poles", i, 0)] = dJ_da
3448+
if ("poles", i, 1) in requested_paths:
3449+
deps_dc = -1 / (jw + a_val)
3450+
dJ_dc = dJ_deps_complex * deps_dc
3451+
vjps[("poles", i, 1)] = dJ_dc
3452+
3453+
return vjps
3454+
34283455
def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap:
3429-
"""Compute adjoint derivatives for each of the ``fields`` given the multiplied E and D."""
3456+
"""Compute adjoint derivatives by preparing scalar data and calling the static helper."""
34303457

3431-
# compute all derivatives beforehand
3432-
dJ_deps = self._derivative_eps_complex_volume(
3458+
dJ_deps_complex = self._derivative_eps_complex_volume(
34333459
E_der_map=derivative_info.E_der_map,
34343460
bounds=derivative_info.bounds,
34353461
freqs=np.atleast_1d(derivative_info.frequency),
34363462
)
34373463

3438-
dJ_deps = complex(dJ_deps)
3439-
3440-
# TODO: fix for multi-frequency
3441-
frequency = derivative_info.frequency
3442-
poles_complex = [(complex(a), complex(c)) for a, c in self.poles]
3443-
poles_complex = np.stack(poles_complex, axis=0)
3444-
3445-
# compute gradients of eps_model with respect to eps_inf and poles
3446-
grad_eps_model = ag.holomorphic_grad(self._eps_model, argnum=(0, 1))
3447-
with warnings.catch_warnings():
3448-
# ignore warnings about holmorphic grad being passed a non-complex input (poles)
3449-
warnings.simplefilter("ignore")
3450-
deps_deps_inf, deps_dpoles = grad_eps_model(
3451-
complex(self.eps_inf), poles_complex, complex(frequency)
3452-
)
3453-
3454-
# multiply with partial dJ/deps to give full gradients
3455-
3456-
dJ_deps_inf = dJ_deps * deps_deps_inf
3457-
dJ_dpoles = [(dJ_deps * a, dJ_deps * c) for a, c in deps_dpoles]
3458-
3459-
# get vjps w.r.t. permittivity and conductivity of the bulk
3460-
derivative_map = {}
3461-
for field_path in derivative_info.paths:
3462-
field_name, *rest = field_path
3464+
poles_vals = [(complex(a), complex(c)) for a, c in self.poles]
34633465

3464-
if field_name == "eps_inf":
3465-
derivative_map[field_path] = float(np.real(dJ_deps_inf))
3466-
3467-
elif field_name == "poles":
3468-
pole_index, a_or_c = rest
3469-
derivative_map[field_path] = complex(dJ_dpoles[pole_index][a_or_c])
3470-
3471-
return derivative_map
3466+
return self._get_vjps_from_params(
3467+
dJ_deps_complex=complex(dJ_deps_complex),
3468+
poles_vals=poles_vals,
3469+
omega=2 * np.pi * derivative_info.frequency,
3470+
requested_paths=derivative_info.paths,
3471+
)
34723472

34733473
@classmethod
34743474
def _real_partial_fraction_decomposition(
@@ -3903,73 +3903,28 @@ def _sel_custom_data_inside(self, bounds: Bound):
39033903
return self.updated_copy(eps_inf=eps_inf_reduced, poles=poles_reduced)
39043904

39053905
def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap:
3906-
"""Compute adjoint derivatives for each of the ``fields`` given the multiplied E and D."""
3906+
"""Compute adjoint derivatives by preparing array data and calling the static helper."""
39073907

3908-
dJ_deps = 0.0
3908+
dJ_deps_complex = 0.0
39093909
for dim in "xyz":
3910-
dJ_deps += self._derivative_field_cmp(
3910+
dJ_deps_complex += self._derivative_field_cmp(
39113911
E_der_map=derivative_info.E_der_map,
39123912
eps_data=self.eps_inf,
39133913
dim=dim,
39143914
freqs=np.atleast_1d(derivative_info.frequency),
39153915
)
39163916

3917-
# TODO: fix for multi-frequency
3918-
frequency = derivative_info.frequency
3919-
3920-
poles_complex = [
3917+
poles_vals = [
39213918
(np.array(a.values, dtype=complex), np.array(c.values, dtype=complex))
39223919
for a, c in self.poles
39233920
]
3924-
poles_complex = np.stack(poles_complex, axis=0)
3925-
3926-
def eps_model_r(
3927-
eps_inf: complex, poles: list[tuple[complex, complex]], frequency: float
3928-
) -> float:
3929-
"""Real part of ``eps_model`` evaluated on ``self`` fields."""
3930-
return np.real(self._eps_model(eps_inf, poles, frequency))
3931-
3932-
def eps_model_i(
3933-
eps_inf: complex, poles: list[tuple[complex, complex]], frequency: float
3934-
) -> float:
3935-
"""Real part of ``eps_model`` evaluated on ``self`` fields."""
3936-
return np.imag(self._eps_model(eps_inf, poles, frequency))
3937-
3938-
# compute the gradients w.r.t. each real and imaginary parts for eps_inf and poles
3939-
grad_eps_model_r = ag.elementwise_grad(eps_model_r, argnum=(0, 1))
3940-
grad_eps_model_i = ag.elementwise_grad(eps_model_i, argnum=(0, 1))
3941-
deps_deps_inf_r, deps_dpoles_r = grad_eps_model_r(
3942-
self.eps_inf.values, poles_complex, frequency
3943-
)
3944-
deps_deps_inf_i, deps_dpoles_i = grad_eps_model_i(
3945-
self.eps_inf.values, poles_complex, frequency
3946-
)
3947-
3948-
# multiply with dJ_deps partial derivative to give full gradients
3949-
3950-
deps_deps_inf = deps_deps_inf_r + 1j * deps_deps_inf_i
3951-
dJ_deps_inf = dJ_deps * deps_deps_inf / 3.0 # mysterious 3
39523921

3953-
dJ_dpoles = []
3954-
for (da_r, dc_r), (da_i, dc_i) in zip(deps_dpoles_r, deps_dpoles_i):
3955-
da = da_r + 1j * da_i
3956-
dc = dc_r + 1j * dc_i
3957-
dJ_da = dJ_deps * da / 2.0 # mysterious 2
3958-
dJ_dc = dJ_deps * dc / 2.0 # mysterious 2
3959-
dJ_dpoles.append((dJ_da, dJ_dc))
3960-
3961-
derivative_map = {}
3962-
for field_path in derivative_info.paths:
3963-
field_name, *rest = field_path
3964-
3965-
if field_name == "eps_inf":
3966-
derivative_map[field_path] = np.real(dJ_deps_inf)
3967-
3968-
elif field_name == "poles":
3969-
pole_index, a_or_c = rest
3970-
derivative_map[field_path] = dJ_dpoles[pole_index][a_or_c]
3971-
3972-
return derivative_map
3922+
return PoleResidue._get_vjps_from_params(
3923+
dJ_deps_complex=dJ_deps_complex,
3924+
poles_vals=poles_vals,
3925+
omega=2 * np.pi * derivative_info.frequency,
3926+
requested_paths=derivative_info.paths,
3927+
)
39733928

39743929

39753930
class Sellmeier(DispersiveMedium):

0 commit comments

Comments
 (0)