Skip to content

Commit 90948e6

Browse files
committed
add selective simulation capabilities to TerminalComponentModeler
refactored common code into AbstractComponentModeler add toggle for controlling how s matrix is calculated added validation of run_only and element_mappings for modelers
1 parent eaf1f0a commit 90948e6

File tree

7 files changed

+351
-150
lines changed

7 files changed

+351
-150
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
99

1010
### Added
1111
- Access field decay values in `SimulationData` via `sim_data.field_decay` as `TimeDataArray`.
12+
- Selective simulation capabilities to `TerminalComponentModeler` via `run_only` and `element_mappings` fields, allowing users to run fewer simulations and extract only needed scattering matrix elements.
1213

1314
### Changed
1415
- By default, batch downloads will skip files that already exist locally. To force re-downloading and replace existing files, pass the `replace_existing=True` argument to `Batch.load()`, `Batch.download()`, or `BatchData.load()`.

tests/test_plugins/smatrix/terminal_component_modeler_def.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ def make_port(center, direction, type, name) -> Union[CoaxialLumpedPort, WavePor
276276
inner_diameter=2 * Rinner,
277277
normal_axis=2,
278278
direction=direction,
279-
name=name,
279+
name="coax" + name,
280280
num_grid_cells=port_cells,
281281
impedance=reference_impedance,
282282
)
@@ -311,7 +311,7 @@ def make_port(center, direction, type, name) -> Union[CoaxialLumpedPort, WavePor
311311
center=center,
312312
size=[2 * Router, 2 * Router, 0],
313313
direction=direction,
314-
name=name,
314+
name="wave" + name,
315315
mode_spec=td.ModeSpec(num_modes=1),
316316
mode_index=0,
317317
voltage_integral=voltage_integral,
@@ -321,9 +321,9 @@ def make_port(center, direction, type, name) -> Union[CoaxialLumpedPort, WavePor
321321
return port
322322

323323
center_src1 = [0, 0, -length / 2]
324-
port_1 = make_port(center_src1, direction="+", type=port_types[0], name="coax_port_1")
324+
port_1 = make_port(center_src1, direction="+", type=port_types[0], name="_1")
325325
center_src2 = [0, 0, length / 2]
326-
port_2 = make_port(center_src2, direction="-", type=port_types[1], name="coax_port_2")
326+
port_2 = make_port(center_src2, direction="-", type=port_types[1], name="_2")
327327
ports = [port_1, port_2]
328328
freqs = np.linspace(freq_start, freq_stop, 100)
329329

tests/test_plugins/smatrix/test_component_modeler.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,35 @@ def test_mapping_exclusion(monkeypatch):
374374
_test_mappings(element_mappings, s_matrix)
375375

376376

377+
def test_mapping_with_run_only():
378+
"""Make sure that the Modeler is correctly validated when both run_only and
379+
element_mappings are provided."""
380+
ports = make_ports()
381+
382+
EXCLUDE_INDEX = ("right_bot", 0)
383+
element_mappings = []
384+
run_only = []
385+
# add a mapping to each element in the row of EXCLUDE_INDEX
386+
for port in ports:
387+
for mode_index in range(port.mode_spec.num_modes):
388+
row_index = (port.name, mode_index)
389+
run_only.append(row_index)
390+
if row_index != EXCLUDE_INDEX:
391+
mapping = ((row_index, row_index), (row_index, EXCLUDE_INDEX), +1)
392+
element_mappings.append(mapping)
393+
394+
# add the self-self coupling element to complete row
395+
mapping = ((("right_bot", 1), ("right_bot", 1)), (EXCLUDE_INDEX, EXCLUDE_INDEX), +1)
396+
element_mappings.append(mapping)
397+
398+
# Will pass, since run_only covers all source indices in element_mapping
399+
_ = make_component_modeler(element_mappings=element_mappings, run_only=run_only)
400+
401+
run_only.remove(EXCLUDE_INDEX)
402+
with pytest.raises(pydantic.ValidationError):
403+
_ = make_component_modeler(element_mappings=element_mappings, run_only=run_only)
404+
405+
377406
def test_batch_filename(tmp_path):
378407
modeler = make_component_modeler()
379408
path = modeler._batch_path

tests/test_plugins/smatrix/test_terminal_component_modeler.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -900,3 +900,44 @@ def test_get_combined_antenna_parameters_data(monkeypatch, tmp_path):
900900
assert not np.allclose(
901901
antenna_params.radiation_efficiency, single_port_params.radiation_efficiency
902902
)
903+
904+
905+
def test_run_only_and_element_mappings(monkeypatch, tmp_path):
906+
"""Checks the terminal component modeler works when running with a subset of excitations."""
907+
z_grid = td.UniformGrid(dl=1 * 1e3)
908+
xy_grid = td.UniformGrid(dl=0.1 * 1e3)
909+
grid_spec = td.GridSpec(grid_x=xy_grid, grid_y=xy_grid, grid_z=z_grid)
910+
modeler = make_coaxial_component_modeler(
911+
path_dir=str(tmp_path), port_types=(CoaxialLumpedPort, WavePort), grid_spec=grid_spec
912+
)
913+
port0_idx = modeler.network_index(modeler.ports[0])
914+
port1_idx = modeler.network_index(modeler.ports[1])
915+
modeler_run1 = modeler.updated_copy(run_only=(port0_idx,))
916+
917+
# Make sure the smatrix and impedance calculations work for reduced simulations
918+
s_matrix = run_component_modeler(monkeypatch, modeler_run1)
919+
with pytest.raises(ValueError):
920+
TerminalComponentModeler._validate_square_matrix(s_matrix, "test_method")
921+
_ = modeler_run1.port_reference_impedances
922+
923+
assert len(modeler_run1.sim_dict) == 1
924+
S11 = (port0_idx, port0_idx)
925+
S21 = (port1_idx, port0_idx)
926+
S12 = (port0_idx, port1_idx)
927+
S22 = (port1_idx, port1_idx)
928+
element_mappings = ((S11, S22, 1),)
929+
modeler_with_mappings = modeler.updated_copy(element_mappings=element_mappings)
930+
assert len(modeler_with_mappings.sim_dict) == 2
931+
932+
# Column 1 is mapped to column 2, resulting in one simulation
933+
element_mappings = ((S11, S22, 1), (S21, S12, 1))
934+
modeler_with_mappings = modeler.updated_copy(element_mappings=element_mappings)
935+
s_matrix = run_component_modeler(monkeypatch, modeler_with_mappings)
936+
assert np.all(s_matrix.values[:, 0, 0] == s_matrix.values[:, 1, 1])
937+
assert np.all(s_matrix.values[:, 0, 1] == s_matrix.values[:, 1, 0])
938+
assert len(modeler_with_mappings.sim_dict) == 1
939+
940+
# Mapping is incomplete, so two simulations are run
941+
element_mappings = ((S11, S22, 1), (S12, S21, 1))
942+
modeler_with_mappings = modeler.updated_copy(element_mappings=element_mappings)
943+
assert len(modeler_with_mappings.sim_dict) == 2

tidy3d/plugins/smatrix/component_modelers/base.py

Lines changed: 90 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import os
66
from abc import ABC, abstractmethod
7-
from typing import Optional, Union, get_args
7+
from typing import Generic, Optional, TypeVar, Union, get_args
88

99
import numpy as np
1010
import pydantic.v1 as pd
@@ -13,7 +13,8 @@
1313
from tidy3d.components.data.data_array import DataArray
1414
from tidy3d.components.data.sim_data import SimulationData
1515
from tidy3d.components.simulation import Simulation
16-
from tidy3d.components.types import FreqArray
16+
from tidy3d.components.types import Complex, FreqArray
17+
from tidy3d.components.validators import assert_unique_names
1718
from tidy3d.config import config
1819
from tidy3d.constants import HERTZ
1920
from tidy3d.exceptions import SetupError, Tidy3dKeyError
@@ -31,8 +32,12 @@
3132
LumpedPortType = Union[LumpedPort, CoaxialLumpedPort]
3233
TerminalPortType = Union[LumpedPortType, WavePort]
3334

35+
# Generic type variables for matrix indices and elements
36+
IndexType = TypeVar("IndexType")
37+
ElementType = TypeVar("ElementType")
3438

35-
class AbstractComponentModeler(ABC, Tidy3dBaseModel):
39+
40+
class AbstractComponentModeler(ABC, Generic[IndexType, ElementType], Tidy3dBaseModel):
3641
"""Tool for modeling devices and computing port parameters."""
3742

3843
simulation: Simulation = pd.Field(
@@ -108,6 +113,24 @@ class AbstractComponentModeler(ABC, Tidy3dBaseModel):
108113
"fields that were not used to create the task will cause errors.",
109114
)
110115

116+
run_only: Optional[tuple[IndexType, ...]] = pd.Field(
117+
None,
118+
title="Run Only",
119+
description="Set of matrix indices that define the simulations to run. "
120+
"If ``None``, simulations will be run for all indices in the scattering matrix. "
121+
"If a tuple is given, simulations will be run only for the given matrix indices.",
122+
)
123+
124+
element_mappings: tuple[tuple[ElementType, ElementType, Complex], ...] = pd.Field(
125+
(),
126+
title="Element Mappings",
127+
description="Tuple of S matrix element mappings, each described by a tuple of "
128+
"(input_element, output_element, coefficient), where the coefficient is the "
129+
"element_mapping coefficient describing the relationship between the input and output "
130+
"matrix element. If all elements of a given column of the scattering matrix are defined "
131+
"by ``element_mappings``, the simulation corresponding to this column is skipped automatically.",
132+
)
133+
111134
@pd.validator("simulation", always=True)
112135
def _sim_has_no_sources(cls, val):
113136
"""Make sure simulation has no sources as they interfere with tool."""
@@ -131,6 +154,30 @@ def _warn_rf_license(cls, val):
131154
)
132155
return val
133156

157+
@pd.validator("element_mappings", always=True)
158+
def _validate_element_mappings(cls, element_mappings, values):
159+
"""
160+
Validate that each source index referenced in element_mappings is included in run_only.
161+
"""
162+
run_only = values.get("run_only")
163+
if run_only is None:
164+
return element_mappings
165+
166+
valid_set = set(run_only)
167+
invalid_indices = set()
168+
for mapping in element_mappings:
169+
input_element = mapping[0]
170+
output_element = mapping[1]
171+
for source_index in [input_element[1], output_element[1]]:
172+
if source_index not in valid_set:
173+
invalid_indices.add(source_index)
174+
if invalid_indices:
175+
raise SetupError(
176+
f"'element_mappings' references source index(es) {invalid_indices} "
177+
f"that are not present in run_only: {run_only}."
178+
)
179+
return element_mappings
180+
134181
@staticmethod
135182
def _task_name(port: Port, mode_index: Optional[int] = None) -> str:
136183
"""The name of a task, determined by the port of the source and mode index, if given."""
@@ -230,6 +277,44 @@ def get_port_by_name(self, port_name: str) -> Port:
230277
raise Tidy3dKeyError(f'Port "{port_name}" not found.')
231278
return ports[0]
232279

280+
@property
281+
@abstractmethod
282+
def matrix_indices_monitor(self) -> tuple[IndexType, ...]:
283+
"""Abstract property for all matrix indices that will be used to collect data."""
284+
285+
@cached_property
286+
def matrix_indices_source(self) -> tuple[IndexType, ...]:
287+
"""Tuple of all the source matrix indices, which may be less than the total number of ports."""
288+
if self.run_only is not None:
289+
return self.run_only
290+
return self.matrix_indices_monitor
291+
292+
@cached_property
293+
def matrix_indices_run_sim(self) -> tuple[IndexType, ...]:
294+
"""Tuple of all the matrix indices that will be used to run simulations."""
295+
296+
if not self.element_mappings:
297+
return self.matrix_indices_source
298+
299+
# all the (i, j) pairs in `S_ij` that are tagged as covered by `element_mappings`
300+
elements_determined_by_map = [element_out for (_, element_out, _) in self.element_mappings]
301+
302+
# loop through rows of the full s matrix and record rows that still need running.
303+
source_indices_needed = []
304+
for col_index in self.matrix_indices_source:
305+
# loop through columns and keep track of whether each element is covered by mapping.
306+
matrix_elements_covered = []
307+
for row_index in self.matrix_indices_monitor:
308+
element = (row_index, col_index)
309+
element_covered_by_map = element in elements_determined_by_map
310+
matrix_elements_covered.append(element_covered_by_map)
311+
312+
# if any matrix elements in row still not covered by map, a source is needed for row.
313+
if not all(matrix_elements_covered):
314+
source_indices_needed.append(col_index)
315+
316+
return source_indices_needed
317+
233318
@abstractmethod
234319
def _construct_smatrix(self, batch_data: BatchData) -> DataArray:
235320
"""Post process :class:`.BatchData` to generate scattering matrix."""
@@ -308,3 +393,5 @@ def sim_data_by_task_name(self, task_name: str) -> SimulationData:
308393
sim_data = self.batch_data[task_name]
309394
config.logging_level = log_level_cache
310395
return sim_data
396+
397+
_unique_port_names = assert_unique_names("ports")

tidy3d/plugins/smatrix/component_modelers/modal.py

Lines changed: 2 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,8 @@
1515
from tidy3d.components.simulation import Simulation
1616
from tidy3d.components.source.field import ModeSource
1717
from tidy3d.components.source.time import GaussianPulse
18-
from tidy3d.components.types import Ax, Complex
18+
from tidy3d.components.types import Ax
1919
from tidy3d.components.viz import add_ax_if_none, equal_aspect
20-
from tidy3d.exceptions import SetupError
2120
from tidy3d.plugins.smatrix.ports.modal import ModalPortDataArray, Port
2221
from tidy3d.web.api.container import BatchData
2322

@@ -27,7 +26,7 @@
2726
Element = tuple[MatrixIndex, MatrixIndex] # the 'ij' in S_ij
2827

2928

30-
class ComponentModeler(AbstractComponentModeler):
29+
class ComponentModeler(AbstractComponentModeler[MatrixIndex, Element]):
3130
"""
3231
Tool for modeling devices and computing scattering matrix elements.
3332
@@ -47,52 +46,6 @@ class ComponentModeler(AbstractComponentModeler):
4746
"For each input mode, one simulation will be run with a modal source.",
4847
)
4948

50-
element_mappings: tuple[tuple[Element, Element, Complex], ...] = pd.Field(
51-
(),
52-
title="Element Mappings",
53-
description="Mapping between elements of the scattering matrix, "
54-
"as specified by pairs of ``(port name, mode index)`` matrix indices, where the "
55-
"first element of the pair is the output and the second element of the pair is the input."
56-
"Each item of ``element_mappings`` is a tuple of ``(element1, element2, c)``, where "
57-
"the scattering matrix ``Smatrix[element2]`` is set equal to ``c * Smatrix[element1]``."
58-
"If all elements of a given column of the scattering matrix are defined by "
59-
" ``element_mappings``, the simulation corresponding to this column "
60-
"is skipped automatically.",
61-
)
62-
63-
run_only: Optional[tuple[MatrixIndex, ...]] = pd.Field(
64-
None,
65-
title="Run Only",
66-
description="If given, a tuple of matrix indices, specified by (:class:`.Port`, ``int``),"
67-
" to run only, excluding the other rows from the scattering matrix. "
68-
"If this option is used, "
69-
"the data corresponding to other inputs will be missing in the resulting matrix.",
70-
)
71-
"""Finally, to exclude some rows of the scattering matrix, one can supply a ``run_only`` parameter to the
72-
:class:`ComponentModeler`. ``run_only`` contains the scattering matrix indices that the user wants to run as a
73-
source. If any indices are excluded, they will not be run."""
74-
75-
verbose: bool = pd.Field(
76-
False,
77-
title="Verbosity",
78-
description="Whether the :class:`.ComponentModeler` should print status and progressbars.",
79-
)
80-
81-
callback_url: str = pd.Field(
82-
None,
83-
title="Callback URL",
84-
description="Http PUT url to receive simulation finish event. "
85-
"The body content is a json file with fields "
86-
"``{'id', 'status', 'name', 'workUnit', 'solverVersion'}``.",
87-
)
88-
89-
@pd.validator("simulation", always=True)
90-
def _sim_has_no_sources(cls, val):
91-
"""Make sure simulation has no sources as they interfere with tool."""
92-
if len(val.sources) > 0:
93-
raise SetupError("'ComponentModeler.simulation' must not have any sources.")
94-
return val
95-
9649
@cached_property
9750
def sim_dict(self) -> dict[str, Simulation]:
9851
"""Generate all the :class:`.Simulation` objects for the S matrix calculation."""
@@ -121,39 +74,6 @@ def matrix_indices_monitor(self) -> tuple[MatrixIndex, ...]:
12174
matrix_indices.append((port.name, mode_index))
12275
return tuple(matrix_indices)
12376

124-
@cached_property
125-
def matrix_indices_source(self) -> tuple[MatrixIndex, ...]:
126-
"""Tuple of all the source matrix indices (port, mode_index) in the Component Modeler."""
127-
if self.run_only is not None:
128-
return self.run_only
129-
return self.matrix_indices_monitor
130-
131-
@cached_property
132-
def matrix_indices_run_sim(self) -> tuple[MatrixIndex, ...]:
133-
"""Tuple of all the source matrix indices (port, mode_index) in the Component Modeler."""
134-
135-
if self.element_mappings is None or self.element_mappings == {}:
136-
return self.matrix_indices_source
137-
138-
# all the (i, j) pairs in `S_ij` that are tagged as covered by `element_mappings`
139-
elements_determined_by_map = [element_out for (_, element_out, _) in self.element_mappings]
140-
141-
# loop through rows of the full s matrix and record rows that still need running.
142-
source_indices_needed = []
143-
for col_index in self.matrix_indices_source:
144-
# loop through columns and keep track of whether each element is covered by mapping.
145-
matrix_elements_covered = []
146-
for row_index in self.matrix_indices_monitor:
147-
element = (row_index, col_index)
148-
element_covered_by_map = element in elements_determined_by_map
149-
matrix_elements_covered.append(element_covered_by_map)
150-
151-
# if any matrix elements in row still not covered by map, a source is needed for row.
152-
if not all(matrix_elements_covered):
153-
source_indices_needed.append(col_index)
154-
155-
return source_indices_needed
156-
15777
@cached_property
15878
def port_names(self) -> tuple[list[str], list[str]]:
15979
"""List of port names for inputs and outputs, respectively."""

0 commit comments

Comments
 (0)