Skip to content

Commit 7464cfc

Browse files
committed
fix: multifrequency adjoint performance by vectorizing over frequency
dimension
1 parent b58ae5f commit 7464cfc

File tree

10 files changed

+362
-187
lines changed

10 files changed

+362
-187
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1515
### Fixed
1616
- Giving opposite boundaries different names no longer causes a symmetry validator failure.
1717
- Fixed issue with parameters in `InverseDesignResult` sometimes being outside of the valid parameter range.
18+
- Fixed performance regression for multi-frequency adjoint calculations.
1819

1920
## [2.9.0rc2] - 2025-07-17
2021

tests/test_components/test_autograd.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1575,10 +1575,15 @@ def J(eps):
15751575

15761576
dJ_deps = ag.holomorphic_grad(J)(eps0)
15771577

1578+
# Wrap the scalar as a DataArray to match expected return type
1579+
import xarray as xr
1580+
1581+
dJ_deps_array = xr.DataArray([dJ_deps], dims=["f"], coords={"f": [freq]})
1582+
15781583
monkeypatch.setattr(
15791584
td.PoleResidue,
15801585
"_derivative_eps_complex_volume",
1581-
lambda self, E_der_map, bounds, freqs: dJ_deps,
1586+
lambda self, E_der_map, bounds: dJ_deps_array,
15821587
)
15831588

15841589
import importlib
@@ -1603,10 +1608,14 @@ def J(eps):
16031608
eps_data={},
16041609
eps_in=2.0,
16051610
eps_out=1.0,
1606-
frequency=freq,
1611+
frequencies=[freq],
16071612
bounds=((-1, -1, -1), (1, 1, 1)),
1608-
eps_no_structure=td.SpatialDataArray([[[1.0]]], coords={"x": [0], "y": [0], "z": [0]}),
1609-
eps_inf_structure=td.SpatialDataArray([[[2.0]]], coords={"x": [0], "y": [0], "z": [0]}),
1613+
eps_no_structure=td.ScalarFieldDataArray(
1614+
[[[[1.0]]]], coords={"x": [0], "y": [0], "z": [0], "f": [1.94e14]}
1615+
),
1616+
eps_inf_structure=td.ScalarFieldDataArray(
1617+
[[[[2.0]]]], coords={"x": [0], "y": [0], "z": [0], "f": [1.94e14]}
1618+
),
16101619
bounds_intersect=((-1, -1, -1), (1, 1, 1)),
16111620
)
16121621

@@ -1661,7 +1670,7 @@ def J(eps):
16611670
monkeypatch.setattr(
16621671
td.CustomPoleResidue,
16631672
"_derivative_field_cmp",
1664-
lambda self, E_der_map, eps_data, dim, freqs: dJ_deps / 3.0,
1673+
lambda self, E_der_map, eps_data, dim: dJ_deps / 3.0,
16651674
)
16661675

16671676
import importlib
@@ -1685,10 +1694,14 @@ def J(eps):
16851694
eps_data={},
16861695
eps_in=2.0,
16871696
eps_out=1.0,
1688-
frequency=freq,
1697+
frequencies=[freq],
16891698
bounds=((-1, -1, -1), (1, 1, 1)),
1690-
eps_no_structure=td.SpatialDataArray([[[1.0]]], coords={"x": [0], "y": [0], "z": [0]}),
1691-
eps_inf_structure=td.SpatialDataArray([[[2.0]]], coords={"x": [0], "y": [0], "z": [0]}),
1699+
eps_no_structure=td.ScalarFieldDataArray(
1700+
[[[[1.0]]]], coords={"x": [0], "y": [0], "z": [0], "f": [1.94e14]}
1701+
),
1702+
eps_inf_structure=td.ScalarFieldDataArray(
1703+
[[[[2.0]]]], coords={"x": [0], "y": [0], "z": [0], "f": [1.94e14]}
1704+
),
16921705
bounds_intersect=((-1, -1, -1), (1, 1, 1)),
16931706
)
16941707

tests/test_components/test_autograd_polyslab.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def __init__(
103103
self.b = coeffs["b"]
104104
self.c = coeffs["c"]
105105
self.d = coeffs["d"]
106-
self.frequency = 200e12
106+
self.frequencies = [200e12]
107107
self.eps_in = 12.0
108108
self.interpolators = None
109109

@@ -114,6 +114,8 @@ def __init__(
114114
)
115115

116116
adaptive_vjp_spacing = DerivativeInfo.adaptive_vjp_spacing
117+
wavelength_min = property(lambda self: DerivativeInfo.wavelength_min.fget(self))
118+
wavelength_max = property(lambda self: DerivativeInfo.wavelength_max.fget(self))
117119

118120
def create_interpolators(self, dtype=None):
119121
return {}

tidy3d/components/autograd/constants.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,7 @@
2222
MINIMUM_SPACING = 1e-2
2323

2424
EDGE_CLIP_TOLERANCE = 1e-9
25+
26+
# chunk size for processing multiple frequencies in adjoint gradient computation.
27+
# None = process all frequencies at once (no chunking)
28+
ADJOINT_FREQ_CHUNK_SIZE = None

tidy3d/components/autograd/derivative_utils.py

Lines changed: 97 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33
from __future__ import annotations
44

55
from dataclasses import dataclass, field, replace
6-
from typing import Callable, Optional
6+
from typing import Callable, Optional, Union
77

88
import numpy as np
99
import xarray as xr
1010

11-
from tidy3d.components.data.data_array import ScalarFieldDataArray, SpatialDataArray
12-
from tidy3d.components.types import Bound, tidycomplex
11+
from tidy3d.components.data.data_array import FreqDataArray, ScalarFieldDataArray
12+
from tidy3d.components.types import ArrayLike, Bound, tidycomplex
1313
from tidy3d.constants import C_0, LARGE_NUMBER
1414

1515
from .constants import (
@@ -23,6 +23,7 @@
2323

2424
FieldData = dict[str, ScalarFieldDataArray]
2525
PermittivityData = dict[str, ScalarFieldDataArray]
26+
EpsType = Union[tidycomplex, FreqDataArray]
2627

2728

2829
class LazyInterpolator:
@@ -90,12 +91,12 @@ class DerivativeInfo:
9091
Dataset of relative permittivity values along all three dimensions.
9192
Used for automatically computing permittivity inside or outside of a simple geometry."""
9293

93-
eps_in: tidycomplex
94+
eps_in: EpsType
9495
"""Permittivity inside the Structure.
9596
Typically computed from Structure.medium.eps_model.
9697
Used when it cannot be computed from eps_data or when eps_approx=True."""
9798

98-
eps_out: tidycomplex
99+
eps_out: EpsType
99100
"""Permittivity outside the Structure.
100101
Typically computed from Simulation.medium.eps_model.
101102
Used when it cannot be computed from eps_data or when eps_approx=True."""
@@ -109,22 +110,22 @@ class DerivativeInfo:
109110
Bounds corresponding to the minimum intersection between the structure
110111
and the simulation it is contained in."""
111112

112-
frequency: float
113-
"""Frequency of adjoint simulation at which the gradient is computed."""
113+
frequencies: ArrayLike
114+
"""Frequencies at which the adjoint gradient should be computed."""
114115

115116
# Optional fields with defaults
116-
eps_background: Optional[tidycomplex] = None
117+
eps_background: Optional[EpsType] = None
117118
"""Permittivity in background.
118119
Permittivity outside of the Structure as manually specified by
119120
Structure.background_medium."""
120121

121-
eps_no_structure: Optional[SpatialDataArray] = None
122+
eps_no_structure: Optional[ScalarFieldDataArray] = None
122123
"""Permittivity without structure.
123124
The permittivity of the original simulation without the structure that is
124125
being differentiated with respect to. Used to approximate permittivity
125126
outside of the structure for shape optimization."""
126127

127-
eps_inf_structure: Optional[SpatialDataArray] = None
128+
eps_inf_structure: Optional[ScalarFieldDataArray] = None
128129
"""Permittivity with infinite structure.
129130
The permittivity of the original simulation where the structure being
130131
differentiated with respect to is infinitely large. Used to approximate
@@ -153,19 +154,10 @@ def updated_copy(self, **kwargs):
153154
kwargs.pop("validate", None)
154155
return replace(self, **kwargs)
155156

156-
@staticmethod
157-
def _get_freq_index(arr: ScalarFieldDataArray, freq: float) -> int:
158-
"""Get the index of the frequency in the array's frequency coordinates."""
159-
if "f" not in arr.dims:
160-
return None
161-
freq_coords = arr.coords["f"].data
162-
idx = np.argmin(np.abs(freq_coords - freq))
163-
return int(idx)
164-
165157
@staticmethod
166158
def _nan_to_num_if_needed(coords: np.ndarray) -> np.ndarray:
167159
"""Convert NaN and infinite values to finite numbers, optimized for finite inputs."""
168-
# skip check for small arrays - overhead exceeds benefit
160+
# skip check for small arrays
169161
if coords.size < 1000:
170162
return np.nan_to_num(coords, posinf=LARGE_NUMBER, neginf=-LARGE_NUMBER)
171163

@@ -238,24 +230,50 @@ def _make_lazy_interpolator_group(field_data_dict, group_key, is_field_group=Tru
238230
coord_cache[arr_id] = points
239231
points = coord_cache[arr_id]
240232

241-
# defer data selection until the interpolator is called
242233
def creator_func(arr=arr, points=points):
243-
freq_idx = self._get_freq_index(arr, self.frequency)
244-
data = arr.data if freq_idx is None else arr.isel(f=freq_idx).data
245-
data = data.astype(
246-
GRADIENT_DTYPE_COMPLEX if np.iscomplexobj(data) else dtype, copy=False
234+
data = arr.data.astype(
235+
GRADIENT_DTYPE_COMPLEX if np.iscomplexobj(arr.data) else dtype, copy=False
247236
)
248-
return RegularGridInterpolator(
249-
points, data, method="linear", bounds_error=False, fill_value=None
237+
238+
# create interpolator with frequency dimension
239+
if "f" in arr.dims:
240+
freq_coords = arr.coords["f"].data.astype(dtype, copy=False)
241+
# ensure frequency dimension is last
242+
if arr.dims != ("x", "y", "z", "f"):
243+
freq_dim_idx = arr.dims.index("f")
244+
axes = list(range(data.ndim))
245+
axes.append(axes.pop(freq_dim_idx))
246+
data = np.transpose(data, axes)
247+
else:
248+
# single frequency case - add singleton dimension
249+
freq_coords = np.array([0.0], dtype=dtype)
250+
data = data[..., np.newaxis]
251+
252+
points_with_freq = (*points, freq_coords)
253+
interpolator_obj = RegularGridInterpolator(
254+
points_with_freq, data, method="linear", bounds_error=False, fill_value=None
250255
)
251256

257+
def interpolator(coords):
258+
# coords: (N, 3) spatial points
259+
n_points = coords.shape[0]
260+
n_freqs = len(freq_coords)
261+
262+
# build coordinates with frequency dimension
263+
coords_with_freq = np.empty((n_points * n_freqs, 4), dtype=coords.dtype)
264+
coords_with_freq[:, :3] = np.repeat(coords, n_freqs, axis=0)
265+
coords_with_freq[:, 3] = np.tile(freq_coords, n_points)
266+
267+
result = interpolator_obj(coords_with_freq)
268+
return result.reshape(n_points, n_freqs)
269+
270+
return interpolator
271+
252272
if is_field_group:
253273
interpolators[group_key][component_name] = LazyInterpolator(creator_func)
254274
else:
255-
# for permittivity, store directly with the key (not nested)
256275
interpolators[component_name] = LazyInterpolator(creator_func)
257276

258-
# process field interpolators (nested dictionaries)
259277
for group_key, data_dict in [
260278
("E_fwd", self.E_fwd),
261279
("E_adj", self.E_adj),
@@ -264,7 +282,6 @@ def creator_func(arr=arr, points=points):
264282
]:
265283
_make_lazy_interpolator_group(data_dict, group_key, is_field_group=True)
266284

267-
# process permittivity interpolators
268285
if self.eps_inf_structure is not None:
269286
_make_lazy_interpolator_group(
270287
{"eps_inf": self.eps_inf_structure}, None, is_field_group=False
@@ -339,31 +356,50 @@ def evaluate_gradient_at_points(
339356
E_fwd_perp2 = self._project_in_basis(E_fwd_at_coords, basis_vector=perps2)
340357
E_adj_perp2 = self._project_in_basis(E_adj_at_coords, basis_vector=perps2)
341358

342-
# compute field products
343359
D_der_norm = D_fwd_norm * D_adj_norm
344360
E_der_perp1 = E_fwd_perp1 * E_adj_perp1
345361
E_der_perp2 = E_fwd_perp2 * E_adj_perp2
346362

347-
# get permittivity jumps across interface
348363
if "eps_inf" in interpolators:
349364
eps_in = interpolators["eps_inf"](spatial_coords)
350365
else:
351-
eps_in = self.eps_in
366+
eps_in = self._prepare_epsilon(self.eps_in)
352367

353368
if "eps_no" in interpolators:
354369
eps_out = interpolators["eps_no"](spatial_coords)
355-
elif self.eps_background is not None:
356-
eps_out = self.eps_background
357370
else:
358-
eps_out = self.eps_out
371+
# use eps_background if available, otherwise use eps_out
372+
eps_to_prepare = (
373+
self.eps_background if self.eps_background is not None else self.eps_out
374+
)
375+
eps_out = self._prepare_epsilon(eps_to_prepare)
359376

360377
delta_eps_inv = 1.0 / eps_in - 1.0 / eps_out
361378
delta_eps = eps_in - eps_out
362379

363380
vjps = -delta_eps_inv * D_der_norm + E_der_perp1 * delta_eps + E_der_perp2 * delta_eps
364381

382+
# sum over frequency dimension
383+
vjps = np.sum(vjps, axis=-1)
384+
365385
return vjps
366386

387+
@staticmethod
388+
def _prepare_epsilon(eps: EpsType) -> np.ndarray:
389+
"""Prepare epsilon values for multi-frequency.
390+
391+
For FreqDataArray, extracts values and broadcasts to shape (1, n_freqs).
392+
For scalar values, broadcasts to shape (1, 1) for consistency with multi-frequency.
393+
"""
394+
if isinstance(eps, FreqDataArray):
395+
# data is already sliced, just extract values
396+
eps_values = eps.values
397+
# shape: (n_freqs,) - need to broadcast to (1, n_freqs)
398+
return eps_values[np.newaxis, :]
399+
else:
400+
# scalar value - broadcast to (1, 1)
401+
return np.array([[eps]])
402+
367403
@staticmethod
368404
def _project_in_basis(
369405
field_components: dict[str, np.ndarray],
@@ -375,17 +411,21 @@ def _project_in_basis(
375411
----------
376412
field_components : dict[str, np.ndarray]
377413
Dictionary with keys like "Ex", "Ey", "Ez" or "Dx", "Dy", "Dz" containing field values.
414+
Values have shape (N, F) where F is the number of frequencies.
378415
basis_vector : np.ndarray
379416
(N, 3) array of basis vectors, one per evaluation point.
380417
381418
Returns
382419
-------
383420
np.ndarray
384-
(N,) array of projected field values.
421+
Projected field values with shape (N, F).
385422
"""
386423
prefix = next(iter(field_components.keys()))[0]
387-
field_matrix = np.stack([field_components[f"{prefix}{dim}"] for dim in "xyz"], axis=1)
388-
return np.einsum("ij,ij->i", field_matrix, basis_vector)
424+
field_matrix = np.stack([field_components[f"{prefix}{dim}"] for dim in "xyz"], axis=0)
425+
426+
# always expect (3, N, F) shape, transpose to (N, 3, F)
427+
field_matrix = np.transpose(field_matrix, (1, 0, 2))
428+
return np.einsum("ij...,ij->i...", field_matrix, basis_vector)
389429

390430
def adaptive_vjp_spacing(
391431
self,
@@ -409,25 +449,38 @@ def adaptive_vjp_spacing(
409449
float
410450
Adaptive spacing value for gradient evaluation.
411451
"""
412-
eps_real = np.asarray(self.eps_in, dtype=np.complex128).real
452+
# handle FreqDataArray or scalar eps_in
453+
if isinstance(self.eps_in, FreqDataArray):
454+
eps_real = np.asarray(self.eps_in.values, dtype=np.complex128).real
455+
else:
456+
eps_real = np.asarray(self.eps_in, dtype=np.complex128).real
413457

414458
dx_candidates = []
459+
max_frequency = np.max(self.frequencies)
415460

416-
# wavelength-based sampling for dielectric materials
461+
# wavelength-based sampling for dielectrics
417462
if np.any(eps_real > 0):
418463
eps_max = eps_real[eps_real > 0].max()
419-
lambda_min = C_0 / (self.frequency * np.sqrt(eps_max))
464+
lambda_min = self.wavelength_min / np.sqrt(eps_max)
420465
dx_candidates.append(wl_fraction * lambda_min)
421466

422-
# skin depth-based sampling for metallic materials
467+
# skin depth sampling for metals
423468
if np.any(eps_real <= 0):
424-
omega = 2 * np.pi * self.frequency
469+
omega = 2 * np.pi * max_frequency
425470
eps_neg = eps_real[eps_real <= 0]
426471
delta_min = C_0 / (omega * np.sqrt(np.abs(eps_neg).max()))
427472
dx_candidates.append(wl_fraction * delta_min)
428473

429474
return max(min(dx_candidates), min_allowed_spacing)
430475

476+
@property
477+
def wavelength_min(self) -> float:
478+
return C_0 / np.max(self.frequencies)
479+
480+
@property
481+
def wavelength_max(self) -> float:
482+
return C_0 / np.min(self.frequencies)
483+
431484

432485
def integrate_within_bounds(arr: xr.DataArray, dims: list[str], bounds: Bound) -> xr.DataArray:
433486
"""Integrate a data array within specified spatial bounds.

tidy3d/components/data/monitor_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3448,7 +3448,7 @@ def adjoint_source_amp(self, amp: DataArray, fwidth: float) -> PlaneWave:
34483448
k0 = 2 * np.pi * freq0 / C_0
34493449
bck_eps = self.medium.eps_model(freq0)
34503450
grad_const = 0.5 * k0 / np.sqrt(bck_eps) * np.cos(angle_theta)
3451-
src_amp = grad_const * amp_complex
3451+
src_amp = 1j * grad_const * amp_complex
34523452

34533453
# construct plane wave source
34543454
adj_src = PlaneWave(

0 commit comments

Comments
 (0)