3
3
from __future__ import annotations
4
4
5
5
from dataclasses import dataclass , field , replace
6
- from typing import Callable , Optional
6
+ from typing import Callable , Optional , Union
7
7
8
8
import numpy as np
9
9
import xarray as xr
10
10
11
- from tidy3d .components .data .data_array import ScalarFieldDataArray , SpatialDataArray
12
- from tidy3d .components .types import Bound , tidycomplex
11
+ from tidy3d .components .data .data_array import FreqDataArray , ScalarFieldDataArray
12
+ from tidy3d .components .types import ArrayLike , Bound , tidycomplex
13
13
from tidy3d .constants import C_0 , LARGE_NUMBER
14
14
15
15
from .constants import (
23
23
24
24
FieldData = dict [str , ScalarFieldDataArray ]
25
25
PermittivityData = dict [str , ScalarFieldDataArray ]
26
+ EpsType = Union [tidycomplex , FreqDataArray ]
26
27
27
28
28
29
class LazyInterpolator :
@@ -90,12 +91,12 @@ class DerivativeInfo:
90
91
Dataset of relative permittivity values along all three dimensions.
91
92
Used for automatically computing permittivity inside or outside of a simple geometry."""
92
93
93
- eps_in : tidycomplex
94
+ eps_in : EpsType
94
95
"""Permittivity inside the Structure.
95
96
Typically computed from Structure.medium.eps_model.
96
97
Used when it cannot be computed from eps_data or when eps_approx=True."""
97
98
98
- eps_out : tidycomplex
99
+ eps_out : EpsType
99
100
"""Permittivity outside the Structure.
100
101
Typically computed from Simulation.medium.eps_model.
101
102
Used when it cannot be computed from eps_data or when eps_approx=True."""
@@ -109,22 +110,22 @@ class DerivativeInfo:
109
110
Bounds corresponding to the minimum intersection between the structure
110
111
and the simulation it is contained in."""
111
112
112
- frequency : float
113
- """Frequency of adjoint simulation at which the gradient is computed."""
113
+ frequencies : ArrayLike
114
+ """Frequencies at which the adjoint gradient should be computed."""
114
115
115
116
# Optional fields with defaults
116
- eps_background : Optional [tidycomplex ] = None
117
+ eps_background : Optional [EpsType ] = None
117
118
"""Permittivity in background.
118
119
Permittivity outside of the Structure as manually specified by
119
120
Structure.background_medium."""
120
121
121
- eps_no_structure : Optional [SpatialDataArray ] = None
122
+ eps_no_structure : Optional [ScalarFieldDataArray ] = None
122
123
"""Permittivity without structure.
123
124
The permittivity of the original simulation without the structure that is
124
125
being differentiated with respect to. Used to approximate permittivity
125
126
outside of the structure for shape optimization."""
126
127
127
- eps_inf_structure : Optional [SpatialDataArray ] = None
128
+ eps_inf_structure : Optional [ScalarFieldDataArray ] = None
128
129
"""Permittivity with infinite structure.
129
130
The permittivity of the original simulation where the structure being
130
131
differentiated with respect to is infinitely large. Used to approximate
@@ -153,19 +154,10 @@ def updated_copy(self, **kwargs):
153
154
kwargs .pop ("validate" , None )
154
155
return replace (self , ** kwargs )
155
156
156
- @staticmethod
157
- def _get_freq_index (arr : ScalarFieldDataArray , freq : float ) -> int :
158
- """Get the index of the frequency in the array's frequency coordinates."""
159
- if "f" not in arr .dims :
160
- return None
161
- freq_coords = arr .coords ["f" ].data
162
- idx = np .argmin (np .abs (freq_coords - freq ))
163
- return int (idx )
164
-
165
157
@staticmethod
166
158
def _nan_to_num_if_needed (coords : np .ndarray ) -> np .ndarray :
167
159
"""Convert NaN and infinite values to finite numbers, optimized for finite inputs."""
168
- # skip check for small arrays - overhead exceeds benefit
160
+ # skip check for small arrays
169
161
if coords .size < 1000 :
170
162
return np .nan_to_num (coords , posinf = LARGE_NUMBER , neginf = - LARGE_NUMBER )
171
163
@@ -238,24 +230,50 @@ def _make_lazy_interpolator_group(field_data_dict, group_key, is_field_group=Tru
238
230
coord_cache [arr_id ] = points
239
231
points = coord_cache [arr_id ]
240
232
241
- # defer data selection until the interpolator is called
242
233
def creator_func (arr = arr , points = points ):
243
- freq_idx = self ._get_freq_index (arr , self .frequency )
244
- data = arr .data if freq_idx is None else arr .isel (f = freq_idx ).data
245
- data = data .astype (
246
- GRADIENT_DTYPE_COMPLEX if np .iscomplexobj (data ) else dtype , copy = False
234
+ data = arr .data .astype (
235
+ GRADIENT_DTYPE_COMPLEX if np .iscomplexobj (arr .data ) else dtype , copy = False
247
236
)
248
- return RegularGridInterpolator (
249
- points , data , method = "linear" , bounds_error = False , fill_value = None
237
+
238
+ # create interpolator with frequency dimension
239
+ if "f" in arr .dims :
240
+ freq_coords = arr .coords ["f" ].data .astype (dtype , copy = False )
241
+ # ensure frequency dimension is last
242
+ if arr .dims != ("x" , "y" , "z" , "f" ):
243
+ freq_dim_idx = arr .dims .index ("f" )
244
+ axes = list (range (data .ndim ))
245
+ axes .append (axes .pop (freq_dim_idx ))
246
+ data = np .transpose (data , axes )
247
+ else :
248
+ # single frequency case - add singleton dimension
249
+ freq_coords = np .array ([0.0 ], dtype = dtype )
250
+ data = data [..., np .newaxis ]
251
+
252
+ points_with_freq = (* points , freq_coords )
253
+ interpolator_obj = RegularGridInterpolator (
254
+ points_with_freq , data , method = "linear" , bounds_error = False , fill_value = None
250
255
)
251
256
257
+ def interpolator (coords ):
258
+ # coords: (N, 3) spatial points
259
+ n_points = coords .shape [0 ]
260
+ n_freqs = len (freq_coords )
261
+
262
+ # build coordinates with frequency dimension
263
+ coords_with_freq = np .empty ((n_points * n_freqs , 4 ), dtype = coords .dtype )
264
+ coords_with_freq [:, :3 ] = np .repeat (coords , n_freqs , axis = 0 )
265
+ coords_with_freq [:, 3 ] = np .tile (freq_coords , n_points )
266
+
267
+ result = interpolator_obj (coords_with_freq )
268
+ return result .reshape (n_points , n_freqs )
269
+
270
+ return interpolator
271
+
252
272
if is_field_group :
253
273
interpolators [group_key ][component_name ] = LazyInterpolator (creator_func )
254
274
else :
255
- # for permittivity, store directly with the key (not nested)
256
275
interpolators [component_name ] = LazyInterpolator (creator_func )
257
276
258
- # process field interpolators (nested dictionaries)
259
277
for group_key , data_dict in [
260
278
("E_fwd" , self .E_fwd ),
261
279
("E_adj" , self .E_adj ),
@@ -264,7 +282,6 @@ def creator_func(arr=arr, points=points):
264
282
]:
265
283
_make_lazy_interpolator_group (data_dict , group_key , is_field_group = True )
266
284
267
- # process permittivity interpolators
268
285
if self .eps_inf_structure is not None :
269
286
_make_lazy_interpolator_group (
270
287
{"eps_inf" : self .eps_inf_structure }, None , is_field_group = False
@@ -339,31 +356,50 @@ def evaluate_gradient_at_points(
339
356
E_fwd_perp2 = self ._project_in_basis (E_fwd_at_coords , basis_vector = perps2 )
340
357
E_adj_perp2 = self ._project_in_basis (E_adj_at_coords , basis_vector = perps2 )
341
358
342
- # compute field products
343
359
D_der_norm = D_fwd_norm * D_adj_norm
344
360
E_der_perp1 = E_fwd_perp1 * E_adj_perp1
345
361
E_der_perp2 = E_fwd_perp2 * E_adj_perp2
346
362
347
- # get permittivity jumps across interface
348
363
if "eps_inf" in interpolators :
349
364
eps_in = interpolators ["eps_inf" ](spatial_coords )
350
365
else :
351
- eps_in = self .eps_in
366
+ eps_in = self ._prepare_epsilon ( self . eps_in )
352
367
353
368
if "eps_no" in interpolators :
354
369
eps_out = interpolators ["eps_no" ](spatial_coords )
355
- elif self .eps_background is not None :
356
- eps_out = self .eps_background
357
370
else :
358
- eps_out = self .eps_out
371
+ # use eps_background if available, otherwise use eps_out
372
+ eps_to_prepare = (
373
+ self .eps_background if self .eps_background is not None else self .eps_out
374
+ )
375
+ eps_out = self ._prepare_epsilon (eps_to_prepare )
359
376
360
377
delta_eps_inv = 1.0 / eps_in - 1.0 / eps_out
361
378
delta_eps = eps_in - eps_out
362
379
363
380
vjps = - delta_eps_inv * D_der_norm + E_der_perp1 * delta_eps + E_der_perp2 * delta_eps
364
381
382
+ # sum over frequency dimension
383
+ vjps = np .sum (vjps , axis = - 1 )
384
+
365
385
return vjps
366
386
387
+ @staticmethod
388
+ def _prepare_epsilon (eps : EpsType ) -> np .ndarray :
389
+ """Prepare epsilon values for multi-frequency.
390
+
391
+ For FreqDataArray, extracts values and broadcasts to shape (1, n_freqs).
392
+ For scalar values, broadcasts to shape (1, 1) for consistency with multi-frequency.
393
+ """
394
+ if isinstance (eps , FreqDataArray ):
395
+ # data is already sliced, just extract values
396
+ eps_values = eps .values
397
+ # shape: (n_freqs,) - need to broadcast to (1, n_freqs)
398
+ return eps_values [np .newaxis , :]
399
+ else :
400
+ # scalar value - broadcast to (1, 1)
401
+ return np .array ([[eps ]])
402
+
367
403
@staticmethod
368
404
def _project_in_basis (
369
405
field_components : dict [str , np .ndarray ],
@@ -375,17 +411,21 @@ def _project_in_basis(
375
411
----------
376
412
field_components : dict[str, np.ndarray]
377
413
Dictionary with keys like "Ex", "Ey", "Ez" or "Dx", "Dy", "Dz" containing field values.
414
+ Values have shape (N, F) where F is the number of frequencies.
378
415
basis_vector : np.ndarray
379
416
(N, 3) array of basis vectors, one per evaluation point.
380
417
381
418
Returns
382
419
-------
383
420
np.ndarray
384
- (N,) array of projected field values .
421
+ Projected field values with shape (N, F) .
385
422
"""
386
423
prefix = next (iter (field_components .keys ()))[0 ]
387
- field_matrix = np .stack ([field_components [f"{ prefix } { dim } " ] for dim in "xyz" ], axis = 1 )
388
- return np .einsum ("ij,ij->i" , field_matrix , basis_vector )
424
+ field_matrix = np .stack ([field_components [f"{ prefix } { dim } " ] for dim in "xyz" ], axis = 0 )
425
+
426
+ # always expect (3, N, F) shape, transpose to (N, 3, F)
427
+ field_matrix = np .transpose (field_matrix , (1 , 0 , 2 ))
428
+ return np .einsum ("ij...,ij->i..." , field_matrix , basis_vector )
389
429
390
430
def adaptive_vjp_spacing (
391
431
self ,
@@ -409,25 +449,38 @@ def adaptive_vjp_spacing(
409
449
float
410
450
Adaptive spacing value for gradient evaluation.
411
451
"""
412
- eps_real = np .asarray (self .eps_in , dtype = np .complex128 ).real
452
+ # handle FreqDataArray or scalar eps_in
453
+ if isinstance (self .eps_in , FreqDataArray ):
454
+ eps_real = np .asarray (self .eps_in .values , dtype = np .complex128 ).real
455
+ else :
456
+ eps_real = np .asarray (self .eps_in , dtype = np .complex128 ).real
413
457
414
458
dx_candidates = []
459
+ max_frequency = np .max (self .frequencies )
415
460
416
- # wavelength-based sampling for dielectric materials
461
+ # wavelength-based sampling for dielectrics
417
462
if np .any (eps_real > 0 ):
418
463
eps_max = eps_real [eps_real > 0 ].max ()
419
- lambda_min = C_0 / ( self .frequency * np .sqrt (eps_max ) )
464
+ lambda_min = self .wavelength_min / np .sqrt (eps_max )
420
465
dx_candidates .append (wl_fraction * lambda_min )
421
466
422
- # skin depth-based sampling for metallic materials
467
+ # skin depth sampling for metals
423
468
if np .any (eps_real <= 0 ):
424
- omega = 2 * np .pi * self . frequency
469
+ omega = 2 * np .pi * max_frequency
425
470
eps_neg = eps_real [eps_real <= 0 ]
426
471
delta_min = C_0 / (omega * np .sqrt (np .abs (eps_neg ).max ()))
427
472
dx_candidates .append (wl_fraction * delta_min )
428
473
429
474
return max (min (dx_candidates ), min_allowed_spacing )
430
475
476
+ @property
477
+ def wavelength_min (self ) -> float :
478
+ return C_0 / np .max (self .frequencies )
479
+
480
+ @property
481
+ def wavelength_max (self ) -> float :
482
+ return C_0 / np .min (self .frequencies )
483
+
431
484
432
485
def integrate_within_bounds (arr : xr .DataArray , dims : list [str ], bounds : Bound ) -> xr .DataArray :
433
486
"""Integrate a data array within specified spatial bounds.
0 commit comments