Skip to content

Commit 9a9f1f5

Browse files
committed
differentiable s-matrix calculation
1 parent 3162a00 commit 9a9f1f5

File tree

8 files changed

+137
-25
lines changed

8 files changed

+137
-25
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1111
- Selective simulation capabilities to `TerminalComponentModeler` via `run_only` and `element_mappings` fields, allowing users to run fewer simulations and extract only needed scattering matrix elements.
1212
- Added KLayout plugin, with DRC functionality for running design rule checks in `plugins.klayout.drc`. Supports running DRC on GDS files as well as `Geometry`, `Structure`, and `Simulation` objects.
1313
- Added "mil" and "in" (inch) units to `plot_length_units`.
14+
- Objective functions that involve running `tidy3d.plugins.smatrix.ComponentModeler` can be differentiated with autograd.
1415

1516
### Changed
1617
- Validate mode solver object for large number of grid points on the modal plane.

tests/test_data/test_data_arrays.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44

55
from typing import Optional
66

7-
import numpy as np
7+
import autograd as ag
8+
import autograd.numpy as np
89
import pytest
910
import xarray.testing as xrt
11+
from autograd.test_util import check_grads
1012

1113
import tidy3d as td
1214
from tidy3d.exceptions import DataError
@@ -468,3 +470,49 @@ def test_interp(method, scalar_index):
468470
xr_interp = data.interp(f=f)
469471
ag_interp = data._ag_interp(f=f)
470472
xrt.assert_allclose(xr_interp, ag_interp)
473+
474+
475+
def test_with_updated_data_grad():
476+
"""Check the ``DataArray.with_updated_data()`` method."""
477+
478+
arr = td.SpatialDataArray(
479+
np.ones((2, 3, 4, 5), dtype=np.complex128),
480+
coords={"x": [0, 1], "y": [1, 2, 3], "z": [2, 3, 4, 5], "w": [0, 1, 2, 3, 4]},
481+
)
482+
483+
data = np.zeros((1, 1, 1, 5))
484+
485+
coords = {"x": 0, "y": 2, "z": 3}
486+
487+
arr2 = arr._with_updated_data(data=data, coords=coords)
488+
489+
data_expected = np.ones(arr.shape) + 0j
490+
data_expected[0, 1, 1, :] = 0.0 + 0j
491+
assert np.all(arr2.data == data_expected), "DataArray.with_updated_copy() failed"
492+
493+
def f(x):
494+
arr2 = arr._with_updated_data(data=x, coords=coords)
495+
return np.abs(np.sum(arr2.data))
496+
497+
# grad should just be all 1s because of sum, so check that this is true
498+
g = ag.grad(f)(data)
499+
assert np.all(g == np.ones_like(data))
500+
501+
check_grads(f, order=1, modes=["rev"])(data)
502+
503+
504+
def test_with_updated_data_shape():
505+
"""Check the ``DataArray.with_updated_data()`` method."""
506+
507+
arr = td.SpatialDataArray(
508+
np.ones((2, 3, 4, 5), dtype=np.complex128),
509+
coords={"x": [0, 1], "y": [1, 2, 3], "z": [2, 3, 4, 5], "w": [0, 1, 2, 3, 4]},
510+
)
511+
512+
# wrong shape
513+
data = np.zeros((1, 1, 1, 3))
514+
515+
coords = {"x": 0, "y": 2, "z": 3}
516+
517+
with pytest.raises(ValueError):
518+
arr2 = arr._with_updated_data(data=data, coords=coords)

tidy3d/components/data/data_array.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,40 @@ def _ag_interp_func(var, indexes_coords, method, **kwargs):
489489
result = result.transpose(*out_dims)
490490
return result
491491

492+
def _with_updated_data(self, data: np.ndarray, coords: dict[str, Any]) -> DataArray:
493+
"""Make copy of ``DataArray`` with ``data`` at specified ``coords``, autograd compatible
494+
495+
Constraints / Edge cases:
496+
- `coords` must map to a specific value eg {x: '1'}, does not broadcast to arrays
497+
- `data` will be reshaped to try to match `self.shape` except where `coords` present
498+
"""
499+
500+
# make mask
501+
mask = xr.zeros_like(self, dtype=bool)
502+
mask.loc[coords] = True
503+
504+
# reshape `data` to line up with `self.dims`, with shape of 1 along the selected axis
505+
old_data = self.data
506+
new_shape = list(old_data.shape)
507+
for i, dim in enumerate(self.dims):
508+
if dim in coords:
509+
new_shape[i] = 1
510+
try:
511+
new_data = data.reshape(new_shape)
512+
except ValueError as e:
513+
raise ValueError(
514+
"Couldn't reshape the supplied 'data' to update 'DataArray'. The provided data was "
515+
f"of shape {data.shape} and tried to reshape to {new_shape}. If you encounter this "
516+
"error please raise an issue on the tidy3d github repository with the context."
517+
) from e
518+
519+
# broadcast data to repeat data along the selected dimensions to match mask
520+
new_data = new_data + np.zeros_like(old_data)
521+
522+
new_data = np.where(mask, new_data, old_data)
523+
524+
return self.copy(deep=True, data=new_data)
525+
492526

493527
class FreqDataArray(DataArray):
494528
"""Frequency-domain array.

tidy3d/plugins/autograd/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,7 @@ We also support the following high-level features:
217217
- We automatically determine the number of adjoint simulations to run from a given forward simulation to maintain gradient accuracy.
218218
Adjoint sources are automatically grouped by either frequency or spatial port (whichever yields fewer adjoint simulations), and all adjoint simulations are run in a single batch (applies to both `run` and `run_async`).
219219
The parameter `max_num_adjoint_per_fwd` (default `10`) prevents launching unexpectedly large numbers of adjoint simulations automatically.
220+
- Differentiation of objective functions involving the scattering matrix produced by `tidy3d.plugins.smatrix.ComponentModeler.run()` and `tidy3d.plugins.smatrix.TerminalComponentModeler.run()`.
220221

221222
We currently have the following restrictions:
222223

tidy3d/plugins/smatrix/component_modelers/base.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from abc import ABC, abstractmethod
77
from typing import Generic, Optional, TypeVar, Union, get_args
88

9-
import numpy as np
9+
import autograd.numpy as np
1010
import pydantic.v1 as pd
1111

1212
from tidy3d.components.base import Tidy3dBaseModel, cached_property
@@ -23,12 +23,16 @@
2323
from tidy3d.plugins.smatrix.ports.modal import Port
2424
from tidy3d.plugins.smatrix.ports.rectangular_lumped import LumpedPort
2525
from tidy3d.plugins.smatrix.ports.wave import WavePort
26+
from tidy3d.web import run_async
2627
from tidy3d.web.api.container import Batch, BatchData
2728

2829
# fwidth of gaussian pulse in units of central frequency
2930
FWIDTH_FRAC = 1.0 / 10
3031
DEFAULT_DATA_DIR = "."
3132

33+
# whether to run gradient calculation for component modeler locally
34+
LOCAL_GRADIENT = False
35+
3236
LumpedPortType = Union[LumpedPort, CoaxialLumpedPort]
3337
TerminalPortType = Union[LumpedPortType, WavePort]
3438

@@ -251,7 +255,26 @@ def batch_path(self) -> str:
251255
@cached_property
252256
def batch_data(self) -> BatchData:
253257
"""The :class:`.BatchData` associated with the simulations run for this component modeler."""
254-
return self.batch.run(path_dir=self.path_dir)
258+
259+
# NOTE: uses run_async because Batch is not differentiable.
260+
batch = self.batch
261+
run_async_kwargs = batch.dict(
262+
exclude={
263+
"type",
264+
"path_dir",
265+
"attrs",
266+
"solver_version",
267+
"jobs_cached",
268+
"num_workers",
269+
"simulations",
270+
}
271+
)
272+
return run_async(
273+
batch.simulations,
274+
**run_async_kwargs,
275+
local_gradient=LOCAL_GRADIENT,
276+
path_dir=self.path_dir,
277+
)
255278

256279
def get_path_dir(self, path_dir: str) -> None:
257280
"""Check whether the supplied 'path_dir' matches the internal field value."""

tidy3d/plugins/smatrix/component_modelers/modal.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from typing import Optional
88

9-
import numpy as np
9+
import autograd.numpy as np
1010
import pydantic.v1 as pd
1111

1212
from tidy3d.components.base import cached_property
@@ -237,14 +237,15 @@ def _internal_construct_smatrix(self, batch_data: BatchData) -> ModalPortDataArr
237237
)
238238
source_norm = self._normalization_factor(port_in, sim_data)
239239
s_matrix_elements = np.array(amp.data) / np.array(source_norm)
240-
s_matrix.loc[
241-
{
242-
"port_in": port_name_in,
243-
"mode_index_in": mode_index_in,
244-
"port_out": port_name_out,
245-
"mode_index_out": mode_index_out,
246-
}
247-
] = s_matrix_elements
240+
241+
coords_set = {
242+
"port_in": port_name_in,
243+
"mode_index_in": mode_index_in,
244+
"port_out": port_name_out,
245+
"mode_index_out": mode_index_out,
246+
}
247+
248+
s_matrix = s_matrix._with_updated_data(data=s_matrix_elements, coords=coords_set)
248249

249250
# element can be determined by user-defined mapping
250251
for (row_in, col_in), (row_out, col_out), mult_by in self.element_mappings:
@@ -259,12 +260,14 @@ def _internal_construct_smatrix(self, batch_data: BatchData) -> ModalPortDataArr
259260

260261
port_out_to, mode_index_out_to = row_out
261262
port_in_to, mode_index_in_to = col_out
263+
264+
elements_from = mult_by * s_matrix.loc[coords_from].values
262265
coords_to = {
263266
"port_in": port_in_to,
264267
"mode_index_in": mode_index_in_to,
265268
"port_out": port_out_to,
266269
"mode_index_out": mode_index_out_to,
267270
}
268-
s_matrix.loc[coords_to] = mult_by * s_matrix.loc[coords_from].values
271+
s_matrix = s_matrix._with_updated_data(data=elements_from, coords=coords_to)
269272

270273
return s_matrix

tidy3d/plugins/smatrix/component_modelers/terminal.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -291,9 +291,10 @@ def _internal_construct_smatrix(
291291
port, mode_index = self.network_dict[source_index]
292292
sim_data = batch_data[self._task_name(port=port, mode_index=mode_index)]
293293
a, b = self.compute_power_wave_amplitudes_at_each_port(port_impedances, sim_data)
294-
indexer = {"f": a.f, "port_out": a.port, "port_in": source_index}
295-
a_matrix.loc[indexer] = a
296-
b_matrix.loc[indexer] = b
294+
295+
indexer = {"port_in": source_index}
296+
a_matrix = a_matrix._with_updated_data(data=a.data, coords=indexer)
297+
b_matrix = b_matrix._with_updated_data(data=b.data, coords=indexer)
297298

298299
# If excitation is assumed ideal, a_matrix is assumed to be diagonal
299300
# and the explicit inverse can be avoided. When only a subset of excitations
@@ -315,7 +316,8 @@ def _internal_construct_smatrix(
315316
"port_in": col_out,
316317
"port_out": row_out,
317318
}
318-
s_matrix.loc[coords_to] = mult_by * s_matrix.loc[coords_from].values
319+
data = mult_by * s_matrix.loc[coords_from].data
320+
s_matrix = s_matrix._with_updated_data(data=data, coords=coords_to)
319321

320322
return s_matrix
321323

@@ -408,8 +410,8 @@ def compute_power_wave_amplitudes_at_each_port(
408410
port, mode_index = self.network_dict[network_index]
409411
V_out, I_out = self.compute_port_VI(port, sim_data)
410412
indexer = {"port": network_index}
411-
V_matrix.loc[indexer] = V_out
412-
I_matrix.loc[indexer] = I_out
413+
V_matrix = V_matrix._with_updated_data(data=V_out.data, coords=indexer)
414+
I_matrix = I_matrix._with_updated_data(data=I_out.data, coords=indexer)
413415

414416
V_numpy = V_matrix.values
415417
I_numpy = I_matrix.values
@@ -577,16 +579,16 @@ def _port_reference_impedances(self, batch_data: BatchData) -> PortDataArray:
577579
)
578580
for network_index in self.matrix_indices_monitor:
579581
port, mode_index = self.network_dict[network_index]
582+
indexer = {"port": network_index}
580583
if isinstance(port, WavePort):
581584
# WavePorts have a port impedance calculated from its associated modal field distribution
582585
# and is frequency dependent.
583-
impedances = port.compute_port_impedance(sim_data).values
584-
port_impedances.loc[{"port": network_index}] = impedances.squeeze()
586+
data = port.compute_port_impedance(sim_data).data
587+
port_impedances = port_impedances._with_updated_data(data=data, coords=indexer)
585588
else:
586589
# LumpedPorts have a constant reference impedance
587-
port_impedances.loc[{"port": network_index}] = np.full(
588-
len(self.freqs), port.impedance
589-
)
590+
data = np.full(len(self.freqs), port.impedance)
591+
port_impedances = port_impedances._with_updated_data(data=data, coords=indexer)
590592

591593
port_impedances = TerminalComponentModeler._set_port_data_array_attributes(port_impedances)
592594
return port_impedances

0 commit comments

Comments
 (0)