Skip to content

Commit a5a99e0

Browse files
committed
immutable data arrays for rf
fix setting data attributes
1 parent 5274776 commit a5a99e0

File tree

6 files changed

+301
-85
lines changed

6 files changed

+301
-85
lines changed

tidy3d/components/data/data_array.py

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,14 @@
2323
from tidy3d.components.geometry.bound_ops import bounds_contains
2424
from tidy3d.components.types import Axis, Bound
2525
from tidy3d.constants import (
26+
AMP,
2627
HERTZ,
2728
MICROMETER,
29+
OHM,
2830
PICOSECOND_PER_NANOMETER_PER_KILOMETER,
2931
RADIAN,
3032
SECOND,
33+
VOLT,
3134
WATT,
3235
)
3336
from tidy3d.exceptions import DataError, FileError
@@ -1309,6 +1312,165 @@ class PerturbationCoefficientDataArray(DataArray):
13091312
_dims = ("wvl", "coeff")
13101313

13111314

1315+
class VoltageArray(DataArray):
1316+
# Always set __slots__ = () to avoid xarray warnings
1317+
__slots__ = ()
1318+
_data_attrs = {"units": VOLT, "long_name": "voltage"}
1319+
1320+
1321+
class CurrentArray(DataArray):
1322+
# Always set __slots__ = () to avoid xarray warnings
1323+
__slots__ = ()
1324+
_data_attrs = {"units": AMP, "long_name": "current"}
1325+
1326+
1327+
class ImpedanceArray(DataArray):
1328+
# Always set __slots__ = () to avoid xarray warnings
1329+
__slots__ = ()
1330+
_data_attrs = {"units": OHM, "long_name": "impedance"}
1331+
1332+
1333+
# Voltage arrays
1334+
class VoltageFreqDataArray(VoltageArray, FreqDataArray):
1335+
"""Voltage data array in frequency domain.
1336+
1337+
Example
1338+
-------
1339+
>>> import numpy as np
1340+
>>> f = [2e9, 3e9, 4e9]
1341+
>>> coords = dict(f=f)
1342+
>>> data = np.random.random(3) + 1j * np.random.random(3)
1343+
>>> vfd = VoltageFreqDataArray(data, coords=coords)
1344+
"""
1345+
1346+
__slots__ = ()
1347+
1348+
1349+
class VoltageTimeDataArray(VoltageArray, TimeDataArray):
1350+
"""Voltage data array in time domain.
1351+
1352+
Example
1353+
-------
1354+
>>> import numpy as np
1355+
>>> t = [0, 1e-9, 2e-9, 3e-9]
1356+
>>> coords = dict(t=t)
1357+
>>> data = np.sin(2 * np.pi * 1e9 * np.array(t))
1358+
>>> vtd = VoltageTimeDataArray(data, coords=coords)
1359+
"""
1360+
1361+
__slots__ = ()
1362+
1363+
1364+
class VoltageFreqModeDataArray(VoltageArray, FreqModeDataArray):
1365+
"""Voltage data array in frequency-mode domain.
1366+
1367+
Example
1368+
-------
1369+
>>> import numpy as np
1370+
>>> f = [2e9, 3e9]
1371+
>>> mode_index = [0, 1]
1372+
>>> coords = dict(f=f, mode_index=mode_index)
1373+
>>> data = np.random.random((2, 2)) + 1j * np.random.random((2, 2))
1374+
>>> vfmd = VoltageFreqModeDataArray(data, coords=coords)
1375+
"""
1376+
1377+
__slots__ = ()
1378+
1379+
1380+
# Current arrays
1381+
class CurrentFreqDataArray(CurrentArray, FreqDataArray):
1382+
"""Current data array in frequency domain.
1383+
1384+
Example
1385+
-------
1386+
>>> import numpy as np
1387+
>>> f = [2e9, 3e9, 4e9]
1388+
>>> coords = dict(f=f)
1389+
>>> data = np.random.random(3) + 1j * np.random.random(3)
1390+
>>> cfd = CurrentFreqDataArray(data, coords=coords)
1391+
"""
1392+
1393+
__slots__ = ()
1394+
1395+
1396+
class CurrentTimeDataArray(CurrentArray, TimeDataArray):
1397+
"""Current data array in time domain.
1398+
1399+
Example
1400+
-------
1401+
>>> import numpy as np
1402+
>>> t = [0, 1e-9, 2e-9, 3e-9]
1403+
>>> coords = dict(t=t)
1404+
>>> data = np.cos(2 * np.pi * 1e9 * np.array(t))
1405+
>>> ctd = CurrentTimeDataArray(data, coords=coords)
1406+
"""
1407+
1408+
__slots__ = ()
1409+
1410+
1411+
class CurrentFreqModeDataArray(CurrentArray, FreqModeDataArray):
1412+
"""Current data array in frequency-mode domain.
1413+
1414+
Example
1415+
-------
1416+
>>> import numpy as np
1417+
>>> f = [2e9, 3e9]
1418+
>>> mode_index = [0, 1]
1419+
>>> coords = dict(f=f, mode_index=mode_index)
1420+
>>> data = np.random.random((2, 2)) + 1j * np.random.random((2, 2))
1421+
>>> cfmd = CurrentFreqModeDataArray(data, coords=coords)
1422+
"""
1423+
1424+
__slots__ = ()
1425+
1426+
1427+
# Impedance arrays
1428+
class ImpedanceFreqDataArray(ImpedanceArray, FreqDataArray):
1429+
"""Impedance data array in frequency domain.
1430+
1431+
Example
1432+
-------
1433+
>>> import numpy as np
1434+
>>> f = [2e9, 3e9, 4e9]
1435+
>>> coords = dict(f=f)
1436+
>>> data = 50.0 + 1j * np.random.random(3)
1437+
>>> zfd = ImpedanceFreqDataArray(data, coords=coords)
1438+
"""
1439+
1440+
__slots__ = ()
1441+
1442+
1443+
class ImpedanceTimeDataArray(ImpedanceArray, TimeDataArray):
1444+
"""Impedance data array in time domain.
1445+
1446+
Example
1447+
-------
1448+
>>> import numpy as np
1449+
>>> t = [0, 1e-9, 2e-9, 3e-9]
1450+
>>> coords = dict(t=t)
1451+
>>> data = 50.0 * np.ones_like(t)
1452+
>>> ztd = ImpedanceTimeDataArray(data, coords=coords)
1453+
"""
1454+
1455+
__slots__ = ()
1456+
1457+
1458+
class ImpedanceFreqModeDataArray(ImpedanceArray, FreqModeDataArray):
1459+
"""Impedance data array in frequency-mode domain.
1460+
1461+
Example
1462+
-------
1463+
>>> import numpy as np
1464+
>>> f = [2e9, 3e9]
1465+
>>> mode_index = [0, 1]
1466+
>>> coords = dict(f=f, mode_index=mode_index)
1467+
>>> data = 50.0 + 10.0 * np.random.random((2, 2))
1468+
>>> zfmd = ImpedanceFreqModeDataArray(data, coords=coords)
1469+
"""
1470+
1471+
__slots__ = ()
1472+
1473+
13121474
DATA_ARRAY_TYPES = [
13131475
SpatialDataArray,
13141476
ScalarFieldDataArray,
@@ -1346,6 +1508,15 @@ class PerturbationCoefficientDataArray(DataArray):
13461508
SpatialVoltageDataArray,
13471509
PerturbationCoefficientDataArray,
13481510
IndexedTimeDataArray,
1511+
VoltageFreqDataArray,
1512+
VoltageTimeDataArray,
1513+
VoltageFreqModeDataArray,
1514+
CurrentFreqDataArray,
1515+
CurrentTimeDataArray,
1516+
CurrentFreqModeDataArray,
1517+
ImpedanceFreqDataArray,
1518+
ImpedanceTimeDataArray,
1519+
ImpedanceFreqModeDataArray,
13491520
]
13501521
DATA_ARRAY_MAP = {data_array.__name__: data_array for data_array in DATA_ARRAY_TYPES}
13511522

@@ -1356,3 +1527,14 @@ class PerturbationCoefficientDataArray(DataArray):
13561527
IndexedFieldVoltageDataArray,
13571528
PointDataArray,
13581529
]
1530+
1531+
IntegralResultTypes = Union[FreqDataArray, FreqModeDataArray, TimeDataArray]
1532+
VoltageIntegralResultTypes = Union[
1533+
VoltageFreqDataArray, VoltageFreqModeDataArray, VoltageTimeDataArray
1534+
]
1535+
CurrentIntegralResultTypes = Union[
1536+
CurrentFreqDataArray, CurrentFreqModeDataArray, CurrentTimeDataArray
1537+
]
1538+
ImpedanceResultTypes = Union[
1539+
ImpedanceFreqDataArray, ImpedanceFreqModeDataArray, ImpedanceTimeDataArray
1540+
]

tidy3d/plugins/microwave/custom_path_integrals.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@
1919
from .path_integrals import (
2020
AbstractAxesRH,
2121
AxisAlignedPathIntegral,
22-
CurrentIntegralAxisAligned,
22+
CurrentIntegralResultTypes,
2323
IntegralResultTypes,
2424
MonitorDataTypes,
25-
VoltageIntegralAxisAligned,
25+
VoltageIntegralResultTypes,
2626
)
2727
from .viz import (
2828
ARROW_CURRENT,
@@ -89,6 +89,10 @@ def compute_integral(
8989
Result of integral over remaining dimensions (frequency, time, mode indices).
9090
"""
9191

92+
from tidy3d.plugins.smatrix.utils import (
93+
_make_base_result_data_array,
94+
)
95+
9296
(dim1, dim2, dim3) = self.local_dims
9397

9498
h_field_name = f"{field}{dim1}"
@@ -130,7 +134,7 @@ def compute_integral(
130134
# Integrate along the path
131135
result = integrand.integrate(coord="s")
132136
result = result.reset_coords(drop=True)
133-
return AxisAlignedPathIntegral._make_result_data_array(result)
137+
return _make_base_result_data_array(result)
134138

135139
@staticmethod
136140
def _compute_dl_component(coord_array: xr.DataArray, closed_contour=False) -> np.array:
@@ -243,7 +247,7 @@ class CustomVoltageIntegral2D(CustomPathIntegral2D):
243247
244248
.. TODO Improve by including extrapolate_to_endpoints field, non-trivial extension."""
245249

246-
def compute_voltage(self, em_field: MonitorDataTypes) -> IntegralResultTypes:
250+
def compute_voltage(self, em_field: MonitorDataTypes) -> VoltageIntegralResultTypes:
247251
"""Compute voltage along path defined by a line.
248252
249253
Parameters
@@ -253,13 +257,16 @@ def compute_voltage(self, em_field: MonitorDataTypes) -> IntegralResultTypes:
253257
254258
Returns
255259
-------
256-
:class:`.IntegralResultTypes`
260+
:class:`.VoltageIntegralResultTypes`
257261
Result of voltage computation over remaining dimensions (frequency, time, mode indices).
258262
"""
263+
from tidy3d.plugins.smatrix.utils import (
264+
_make_voltage_data_array,
265+
)
266+
259267
AxisAlignedPathIntegral._check_monitor_data_supported(em_field=em_field)
260268
voltage = -1.0 * self.compute_integral(field="E", em_field=em_field)
261-
voltage = VoltageIntegralAxisAligned._set_data_array_attributes(voltage)
262-
return voltage
269+
return _make_voltage_data_array(voltage)
263270

264271
@add_ax_if_none
265272
def plot(
@@ -316,7 +323,7 @@ class CustomCurrentIntegral2D(CustomPathIntegral2D):
316323
To compute the current flowing in the positive ``axis`` direction, the vertices should be
317324
ordered in a counterclockwise direction."""
318325

319-
def compute_current(self, em_field: MonitorDataTypes) -> IntegralResultTypes:
326+
def compute_current(self, em_field: MonitorDataTypes) -> CurrentIntegralResultTypes:
320327
"""Compute current flowing in a custom loop.
321328
322329
Parameters
@@ -326,13 +333,16 @@ def compute_current(self, em_field: MonitorDataTypes) -> IntegralResultTypes:
326333
327334
Returns
328335
-------
329-
:class:`.IntegralResultTypes`
336+
:class:`.CurrentIntegralResultTypes`
330337
Result of current computation over remaining dimensions (frequency, time, mode indices).
331338
"""
339+
from tidy3d.plugins.smatrix.utils import (
340+
_make_current_data_array,
341+
)
342+
332343
AxisAlignedPathIntegral._check_monitor_data_supported(em_field=em_field)
333344
current = self.compute_integral(field="H", em_field=em_field)
334-
current = CurrentIntegralAxisAligned._set_data_array_attributes(current)
335-
return current
345+
return _make_current_data_array(current)
336346

337347
@add_ax_if_none
338348
def plot(

tidy3d/plugins/microwave/impedance_calculator.py

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,16 @@
88
import pydantic.v1 as pd
99

1010
from tidy3d.components.base import Tidy3dBaseModel
11-
from tidy3d.components.data.data_array import FreqDataArray, FreqModeDataArray, TimeDataArray
11+
from tidy3d.components.data.data_array import ImpedanceResultTypes
1212
from tidy3d.components.data.monitor_data import FieldTimeData
1313
from tidy3d.components.monitor import ModeMonitor, ModeSolverMonitor
14-
from tidy3d.constants import OHM
1514
from tidy3d.exceptions import ValidationError
1615
from tidy3d.log import log
1716

1817
from .custom_path_integrals import CustomCurrentIntegral2D, CustomVoltageIntegral2D
1918
from .path_integrals import (
2019
AxisAlignedPathIntegral,
2120
CurrentIntegralAxisAligned,
22-
IntegralResultTypes,
2321
MonitorDataTypes,
2422
VoltageIntegralAxisAligned,
2523
)
@@ -43,7 +41,7 @@ class ImpedanceCalculator(Tidy3dBaseModel):
4341
description="Definition of contour integral for computing current.",
4442
)
4543

46-
def compute_impedance(self, em_field: MonitorDataTypes) -> IntegralResultTypes:
44+
def compute_impedance(self, em_field: MonitorDataTypes) -> ImpedanceResultTypes:
4745
"""Compute impedance for the supplied ``em_field`` using ``voltage_integral`` and
4846
``current_integral``. If only a single integral has been defined, impedance is
4947
computed using the total flux in ``em_field``.
@@ -56,9 +54,11 @@ def compute_impedance(self, em_field: MonitorDataTypes) -> IntegralResultTypes:
5654
5755
Returns
5856
-------
59-
:class:`.IntegralResultTypes`
57+
:class:`.ImpedanceResultTypes`
6058
Result of impedance computation over remaining dimensions (frequency, time, mode indices).
6159
"""
60+
from tidy3d.plugins.smatrix.utils import _make_impedance_data_array
61+
6262
AxisAlignedPathIntegral._check_monitor_data_supported(em_field=em_field)
6363

6464
# If both voltage and current integrals have been defined then impedance is computed directly
@@ -98,7 +98,7 @@ def compute_impedance(self, em_field: MonitorDataTypes) -> IntegralResultTypes:
9898
impedance = np.real(voltage) / np.real(current)
9999
else:
100100
impedance = voltage / current
101-
impedance = ImpedanceCalculator._set_data_array_attributes(impedance)
101+
impedance = _make_impedance_data_array(impedance)
102102
return impedance
103103

104104
@pd.validator("current_integral", always=True)
@@ -111,19 +111,6 @@ def check_voltage_or_current(cls, val, values):
111111
)
112112
return val
113113

114-
@staticmethod
115-
def _set_data_array_attributes(data_array: IntegralResultTypes) -> IntegralResultTypes:
116-
"""Helper to set additional metadata for ``IntegralResultTypes``."""
117-
# Determine type based on coords present
118-
if "mode_index" in data_array.coords:
119-
data_array = FreqModeDataArray(data_array)
120-
elif "f" in data_array.coords:
121-
data_array = FreqDataArray(data_array)
122-
else:
123-
data_array = TimeDataArray(data_array)
124-
data_array.name = "Z0"
125-
return data_array.assign_attrs(units=OHM, long_name="characteristic impedance")
126-
127114
@pd.root_validator(pre=False)
128115
def _warn_rf_license(cls, values):
129116
log.warning(

0 commit comments

Comments
 (0)