|
3 | 3 | from __future__ import annotations
|
4 | 4 |
|
5 | 5 | import functools
|
6 |
| -import warnings |
7 | 6 | from abc import ABC, abstractmethod
|
8 | 7 | from math import isclose
|
9 | 8 | from typing import Callable, Optional, Union
|
10 | 9 |
|
11 |
| -import autograd as ag |
12 | 10 | import autograd.numpy as np
|
13 | 11 |
|
14 | 12 | # TODO: it's hard to figure out which functions need this, for now all get it
|
@@ -3425,50 +3423,52 @@ def loss_upper_bound(self) -> float:
|
3425 | 3423 | ep = ep[~np.isnan(ep)]
|
3426 | 3424 | return max(ep.imag)
|
3427 | 3425 |
|
| 3426 | + @staticmethod |
| 3427 | + def _get_vjps_from_params( |
| 3428 | + dJ_deps_complex: Union[complex, np.ndarray], |
| 3429 | + poles_vals: list[tuple[Union[complex, np.ndarray], Union[complex, np.ndarray]]], |
| 3430 | + omega: float, |
| 3431 | + requested_paths: list[tuple], |
| 3432 | + ) -> AutogradFieldMap: |
| 3433 | + """ |
| 3434 | + Static helper to compute VJPs from parameters using the analytical chain rule. |
| 3435 | + """ |
| 3436 | + jw = 1j * omega |
| 3437 | + vjps = {} |
| 3438 | + |
| 3439 | + if ("eps_inf",) in requested_paths: |
| 3440 | + vjps[("eps_inf",)] = np.real(dJ_deps_complex) |
| 3441 | + |
| 3442 | + for i, (a_val, c_val) in enumerate(poles_vals): |
| 3443 | + if any(path[1] == i for path in requested_paths if path[0] == "poles"): |
| 3444 | + if ("poles", i, 0) in requested_paths: |
| 3445 | + deps_da = c_val / (jw + a_val) ** 2 |
| 3446 | + dJ_da = dJ_deps_complex * deps_da |
| 3447 | + vjps[("poles", i, 0)] = dJ_da |
| 3448 | + if ("poles", i, 1) in requested_paths: |
| 3449 | + deps_dc = -1 / (jw + a_val) |
| 3450 | + dJ_dc = dJ_deps_complex * deps_dc |
| 3451 | + vjps[("poles", i, 1)] = dJ_dc |
| 3452 | + |
| 3453 | + return vjps |
| 3454 | + |
3428 | 3455 | def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap:
|
3429 |
| - """Compute adjoint derivatives for each of the ``fields`` given the multiplied E and D.""" |
| 3456 | + """Compute adjoint derivatives by preparing scalar data and calling the static helper.""" |
3430 | 3457 |
|
3431 |
| - # compute all derivatives beforehand |
3432 |
| - dJ_deps = self._derivative_eps_complex_volume( |
| 3458 | + dJ_deps_complex = self._derivative_eps_complex_volume( |
3433 | 3459 | E_der_map=derivative_info.E_der_map,
|
3434 | 3460 | bounds=derivative_info.bounds,
|
3435 | 3461 | freqs=np.atleast_1d(derivative_info.frequency),
|
3436 | 3462 | )
|
3437 | 3463 |
|
3438 |
| - dJ_deps = complex(dJ_deps) |
3439 |
| - |
3440 |
| - # TODO: fix for multi-frequency |
3441 |
| - frequency = derivative_info.frequency |
3442 |
| - poles_complex = [(complex(a), complex(c)) for a, c in self.poles] |
3443 |
| - poles_complex = np.stack(poles_complex, axis=0) |
3444 |
| - |
3445 |
| - # compute gradients of eps_model with respect to eps_inf and poles |
3446 |
| - grad_eps_model = ag.holomorphic_grad(self._eps_model, argnum=(0, 1)) |
3447 |
| - with warnings.catch_warnings(): |
3448 |
| - # ignore warnings about holmorphic grad being passed a non-complex input (poles) |
3449 |
| - warnings.simplefilter("ignore") |
3450 |
| - deps_deps_inf, deps_dpoles = grad_eps_model( |
3451 |
| - complex(self.eps_inf), poles_complex, complex(frequency) |
3452 |
| - ) |
3453 |
| - |
3454 |
| - # multiply with partial dJ/deps to give full gradients |
3455 |
| - |
3456 |
| - dJ_deps_inf = dJ_deps * deps_deps_inf |
3457 |
| - dJ_dpoles = [(dJ_deps * a, dJ_deps * c) for a, c in deps_dpoles] |
3458 |
| - |
3459 |
| - # get vjps w.r.t. permittivity and conductivity of the bulk |
3460 |
| - derivative_map = {} |
3461 |
| - for field_path in derivative_info.paths: |
3462 |
| - field_name, *rest = field_path |
| 3464 | + poles_vals = [(complex(a), complex(c)) for a, c in self.poles] |
3463 | 3465 |
|
3464 |
| - if field_name == "eps_inf": |
3465 |
| - derivative_map[field_path] = float(np.real(dJ_deps_inf)) |
3466 |
| - |
3467 |
| - elif field_name == "poles": |
3468 |
| - pole_index, a_or_c = rest |
3469 |
| - derivative_map[field_path] = complex(dJ_dpoles[pole_index][a_or_c]) |
3470 |
| - |
3471 |
| - return derivative_map |
| 3466 | + return self._get_vjps_from_params( |
| 3467 | + dJ_deps_complex=complex(dJ_deps_complex), |
| 3468 | + poles_vals=poles_vals, |
| 3469 | + omega=2 * np.pi * derivative_info.frequency, |
| 3470 | + requested_paths=derivative_info.paths, |
| 3471 | + ) |
3472 | 3472 |
|
3473 | 3473 | @classmethod
|
3474 | 3474 | def _real_partial_fraction_decomposition(
|
@@ -3903,73 +3903,28 @@ def _sel_custom_data_inside(self, bounds: Bound):
|
3903 | 3903 | return self.updated_copy(eps_inf=eps_inf_reduced, poles=poles_reduced)
|
3904 | 3904 |
|
3905 | 3905 | def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap:
|
3906 |
| - """Compute adjoint derivatives for each of the ``fields`` given the multiplied E and D.""" |
| 3906 | + """Compute adjoint derivatives by preparing array data and calling the static helper.""" |
3907 | 3907 |
|
3908 |
| - dJ_deps = 0.0 |
| 3908 | + dJ_deps_complex = 0.0 |
3909 | 3909 | for dim in "xyz":
|
3910 |
| - dJ_deps += self._derivative_field_cmp( |
| 3910 | + dJ_deps_complex += self._derivative_field_cmp( |
3911 | 3911 | E_der_map=derivative_info.E_der_map,
|
3912 | 3912 | eps_data=self.eps_inf,
|
3913 | 3913 | dim=dim,
|
3914 | 3914 | freqs=np.atleast_1d(derivative_info.frequency),
|
3915 | 3915 | )
|
3916 | 3916 |
|
3917 |
| - # TODO: fix for multi-frequency |
3918 |
| - frequency = derivative_info.frequency |
3919 |
| - |
3920 |
| - poles_complex = [ |
| 3917 | + poles_vals = [ |
3921 | 3918 | (np.array(a.values, dtype=complex), np.array(c.values, dtype=complex))
|
3922 | 3919 | for a, c in self.poles
|
3923 | 3920 | ]
|
3924 |
| - poles_complex = np.stack(poles_complex, axis=0) |
3925 |
| - |
3926 |
| - def eps_model_r( |
3927 |
| - eps_inf: complex, poles: list[tuple[complex, complex]], frequency: float |
3928 |
| - ) -> float: |
3929 |
| - """Real part of ``eps_model`` evaluated on ``self`` fields.""" |
3930 |
| - return np.real(self._eps_model(eps_inf, poles, frequency)) |
3931 |
| - |
3932 |
| - def eps_model_i( |
3933 |
| - eps_inf: complex, poles: list[tuple[complex, complex]], frequency: float |
3934 |
| - ) -> float: |
3935 |
| - """Real part of ``eps_model`` evaluated on ``self`` fields.""" |
3936 |
| - return np.imag(self._eps_model(eps_inf, poles, frequency)) |
3937 |
| - |
3938 |
| - # compute the gradients w.r.t. each real and imaginary parts for eps_inf and poles |
3939 |
| - grad_eps_model_r = ag.elementwise_grad(eps_model_r, argnum=(0, 1)) |
3940 |
| - grad_eps_model_i = ag.elementwise_grad(eps_model_i, argnum=(0, 1)) |
3941 |
| - deps_deps_inf_r, deps_dpoles_r = grad_eps_model_r( |
3942 |
| - self.eps_inf.values, poles_complex, frequency |
3943 |
| - ) |
3944 |
| - deps_deps_inf_i, deps_dpoles_i = grad_eps_model_i( |
3945 |
| - self.eps_inf.values, poles_complex, frequency |
3946 |
| - ) |
3947 |
| - |
3948 |
| - # multiply with dJ_deps partial derivative to give full gradients |
3949 |
| - |
3950 |
| - deps_deps_inf = deps_deps_inf_r + 1j * deps_deps_inf_i |
3951 |
| - dJ_deps_inf = dJ_deps * deps_deps_inf / 3.0 # mysterious 3 |
3952 | 3921 |
|
3953 |
| - dJ_dpoles = [] |
3954 |
| - for (da_r, dc_r), (da_i, dc_i) in zip(deps_dpoles_r, deps_dpoles_i): |
3955 |
| - da = da_r + 1j * da_i |
3956 |
| - dc = dc_r + 1j * dc_i |
3957 |
| - dJ_da = dJ_deps * da / 2.0 # mysterious 2 |
3958 |
| - dJ_dc = dJ_deps * dc / 2.0 # mysterious 2 |
3959 |
| - dJ_dpoles.append((dJ_da, dJ_dc)) |
3960 |
| - |
3961 |
| - derivative_map = {} |
3962 |
| - for field_path in derivative_info.paths: |
3963 |
| - field_name, *rest = field_path |
3964 |
| - |
3965 |
| - if field_name == "eps_inf": |
3966 |
| - derivative_map[field_path] = np.real(dJ_deps_inf) |
3967 |
| - |
3968 |
| - elif field_name == "poles": |
3969 |
| - pole_index, a_or_c = rest |
3970 |
| - derivative_map[field_path] = dJ_dpoles[pole_index][a_or_c] |
3971 |
| - |
3972 |
| - return derivative_map |
| 3922 | + return PoleResidue._get_vjps_from_params( |
| 3923 | + dJ_deps_complex=dJ_deps_complex, |
| 3924 | + poles_vals=poles_vals, |
| 3925 | + omega=2 * np.pi * derivative_info.frequency, |
| 3926 | + requested_paths=derivative_info.paths, |
| 3927 | + ) |
3973 | 3928 |
|
3974 | 3929 |
|
3975 | 3930 | class Sellmeier(DispersiveMedium):
|
|
0 commit comments