Skip to content

Commit 257c060

Browse files
maybe need damian's help
1 parent 6451cc1 commit 257c060

File tree

2 files changed

+42
-35
lines changed

2 files changed

+42
-35
lines changed

tidy3d/plugins/smatrix/component_modelers/terminal.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,7 @@ def _construct_smatrix(self) -> TerminalPortDataArray:
438438

439439
return _construct_smatrix(self)
440440

441-
def _internal_construct_smatrix(self, batch_data: typing.Any) -> TerminalPortDataArray:
441+
def _internal_construct_smatrix(self, batch_data: typing.Any = None) -> TerminalPortDataArray:
442442
from tidy3d.plugins.smatrix.run import _internal_construct_smatrix
443443

444-
return _internal_construct_smatrix(self)
444+
return _internal_construct_smatrix(self, batch_data=batch_data)

tidy3d/plugins/smatrix/run.py

Lines changed: 40 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from __future__ import annotations
44

5+
import typing
6+
57
import numpy as np
68

79
from tidy3d.components.data.sim_data import SimulationData
@@ -47,39 +49,6 @@ def _port_reference_impedances(simulation) -> PortDataArray:
4749
return port_impedances
4850

4951

50-
def _internal_construct_smatrix(simulation) -> TerminalPortDataArray:
51-
"""Post process :class:`.BatchData` to generate scattering matrix, for internal use only."""
52-
from tidy3d.plugins.smatrix.utils import ab_to_s
53-
54-
port_names = [port.name for port in simulation.ports]
55-
56-
values = np.zeros(
57-
(len(simulation.freqs), len(port_names), len(port_names)),
58-
dtype=complex,
59-
)
60-
coords = {
61-
"f": np.array(simulation.freqs),
62-
"port_out": port_names,
63-
"port_in": port_names,
64-
}
65-
a_matrix = TerminalPortDataArray(values, coords=coords)
66-
b_matrix = a_matrix.copy(deep=True)
67-
68-
# Tabulate the reference impedances at each port and frequency
69-
port_impedances = _port_reference_impedances(simulation=simulation)
70-
71-
# loop through source ports
72-
for port_in in simulation.ports:
73-
sim_data = simulation.batch_data[simulation._task_name(port=port_in)]
74-
a, b = simulation.compute_power_wave_amplitudes_at_each_port(port_impedances, sim_data)
75-
indexer = {"f": a.f, "port_in": port_in.name, "port_out": a.port}
76-
a_matrix.loc[indexer] = a
77-
b_matrix.loc[indexer] = b
78-
79-
s_matrix = ab_to_s(a_matrix, b_matrix)
80-
return s_matrix
81-
82-
8352
def compute_power_wave_amplitudes_at_each_port(
8453
simulation, port_reference_impedances: PortDataArray, sim_data: SimulationData
8554
) -> tuple[PortDataArray, PortDataArray]:
@@ -138,3 +107,41 @@ def compute_power_wave_amplitudes_at_each_port(
138107
b.values = F_numpy * (V_numpy - np.conj(Z_numpy) * I_numpy)
139108

140109
return a, b
110+
111+
112+
def _internal_construct_smatrix(simulation, batch_data: typing.Any = None) -> TerminalPortDataArray:
113+
"""Post process :class:`.BatchData` to generate scattering matrix, for internal use only."""
114+
if batch_data:
115+
pass
116+
else:
117+
batch_data = simulation.batch_data
118+
119+
from tidy3d.plugins.smatrix.utils import ab_to_s
120+
121+
port_names = [port.name for port in simulation.ports]
122+
123+
values = np.zeros(
124+
(len(simulation.freqs), len(port_names), len(port_names)),
125+
dtype=complex,
126+
)
127+
coords = {
128+
"f": np.array(simulation.freqs),
129+
"port_out": port_names,
130+
"port_in": port_names,
131+
}
132+
a_matrix = TerminalPortDataArray(values, coords=coords)
133+
b_matrix = a_matrix.copy(deep=True)
134+
135+
# Tabulate the reference impedances at each port and frequency
136+
port_impedances = _port_reference_impedances(simulation=simulation)
137+
138+
# loop through source ports
139+
for port_in in simulation.ports:
140+
sim_data = batch_data[simulation._task_name(port=port_in)]
141+
a, b = compute_power_wave_amplitudes_at_each_port(simulation, port_impedances, sim_data)
142+
indexer = {"f": a.f, "port_in": port_in.name, "port_out": a.port}
143+
a_matrix.loc[indexer] = a
144+
b_matrix.loc[indexer] = b
145+
146+
s_matrix = ab_to_s(a_matrix, b_matrix)
147+
return s_matrix

0 commit comments

Comments
 (0)