Skip to content

Commit 16bf05f

Browse files
committed
fix setting data attributes
1 parent e445dd0 commit 16bf05f

File tree

2 files changed

+21
-17
lines changed

2 files changed

+21
-17
lines changed

tidy3d/plugins/microwave/impedance_calculator.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,11 +120,12 @@ def check_voltage_or_current(cls, val, values):
120120
@staticmethod
121121
def _make_result_data_array(result: DataArray) -> ImpedanceResultTypes:
122122
"""Helper for creating the proper result type."""
123+
cls = ImpedanceFreqDataArray
123124
if "t" in result.coords:
124-
return ImpedanceTimeDataArray(data=result.data, coords=result.coords)
125+
cls = ImpedanceTimeDataArray
125126
if "f" in result.coords and "mode_index" in result.coords:
126-
return ImpedanceFreqModeDataArray(data=result.data, coords=result.coords)
127-
return ImpedanceFreqDataArray(data=result.data, coords=result.coords)
127+
cls = ImpedanceFreqModeDataArray
128+
return cls.assign_data_attrs(cls(data=result.data, coords=result.coords))
128129

129130
@pd.root_validator(pre=False)
130131
def _warn_rf_license(cls, values):

tidy3d/plugins/microwave/path_integrals.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
import numpy as np
99
import pydantic.v1 as pd
10-
import xarray as xr
1110

1211
from tidy3d.components.base import Tidy3dBaseModel, cached_property
1312
from tidy3d.components.data.data_array import (
@@ -24,6 +23,7 @@
2423
CurrentFreqDataArray,
2524
CurrentFreqModeDataArray,
2625
CurrentTimeDataArray,
26+
DataArray,
2727
VoltageFreqDataArray,
2828
VoltageFreqModeDataArray,
2929
VoltageTimeDataArray,
@@ -220,13 +220,14 @@ def _check_monitor_data_supported(em_field: MonitorDataTypes):
220220
)
221221

222222
@staticmethod
223-
def _make_result_data_array(result: xr.DataArray) -> IntegralResultTypes:
223+
def _make_result_data_array(result: DataArray) -> IntegralResultTypes:
224224
"""Helper for creating the proper result type."""
225+
cls = FreqDataArray
225226
if "t" in result.coords:
226-
return TimeDataArray(data=result.data, coords=result.coords)
227+
cls = TimeDataArray
227228
if "f" in result.coords and "mode_index" in result.coords:
228-
return FreqModeDataArray(data=result.data, coords=result.coords)
229-
return FreqDataArray(data=result.data, coords=result.coords)
229+
cls = FreqModeDataArray
230+
return cls.assign_data_attrs(cls(data=result.data, coords=result.coords))
230231

231232

232233
class VoltageIntegralAxisAligned(AxisAlignedPathIntegral):
@@ -257,13 +258,14 @@ def compute_voltage(self, em_field: MonitorDataTypes) -> VoltageIntegralResultTy
257258
return voltage
258259

259260
@staticmethod
260-
def _make_result_data_array(result: xr.DataArray) -> VoltageIntegralResultTypes:
261+
def _make_result_data_array(result: DataArray) -> VoltageIntegralResultTypes:
261262
"""Helper for creating the proper result type."""
263+
cls = VoltageFreqDataArray
262264
if "t" in result.coords:
263-
return VoltageTimeDataArray(data=result.data, coords=result.coords)
265+
cls = VoltageTimeDataArray
264266
if "f" in result.coords and "mode_index" in result.coords:
265-
return VoltageFreqModeDataArray(data=result.data, coords=result.coords)
266-
return VoltageFreqDataArray(data=result.data, coords=result.coords)
267+
cls = VoltageFreqModeDataArray
268+
return cls.assign_data_attrs(cls(data=result.data, coords=result.coords))
267269

268270
@staticmethod
269271
def from_terminal_positions(
@@ -424,7 +426,7 @@ def compute_current(self, em_field: MonitorDataTypes) -> CurrentIntegralResultTy
424426

425427
if self.sign == "-":
426428
current *= -1
427-
current = CurrentIntegralAxisAligned._set_data_array_attributes(current)
429+
current = CurrentIntegralAxisAligned._make_result_data_array(current)
428430
return current
429431

430432
@cached_property
@@ -516,13 +518,14 @@ def _to_path_integrals(
516518
return (bottom, right, top, left)
517519

518520
@staticmethod
519-
def _make_result_data_array(result: xr.DataArray) -> CurrentIntegralResultTypes:
521+
def _make_result_data_array(result: DataArray) -> CurrentIntegralResultTypes:
520522
"""Helper for creating the proper result type."""
523+
cls = CurrentFreqDataArray
521524
if "t" in result.coords:
522-
return CurrentTimeDataArray(data=result.data, coords=result.coords)
525+
cls = CurrentTimeDataArray
523526
if "f" in result.coords and "mode_index" in result.coords:
524-
return CurrentFreqModeDataArray(data=result.data, coords=result.coords)
525-
return CurrentFreqDataArray(data=result.data, coords=result.coords)
527+
cls = CurrentFreqModeDataArray
528+
return cls.assign_data_attrs(cls(data=result.data, coords=result.coords))
526529

527530
@add_ax_if_none
528531
def plot(

0 commit comments

Comments
 (0)