|
7 | 7 |
|
8 | 8 | import numpy as np
|
9 | 9 | import pydantic.v1 as pd
|
10 |
| -import xarray as xr |
11 | 10 |
|
12 | 11 | from tidy3d.components.base import Tidy3dBaseModel, cached_property
|
13 | 12 | from tidy3d.components.data.data_array import (
|
|
24 | 23 | CurrentFreqDataArray,
|
25 | 24 | CurrentFreqModeDataArray,
|
26 | 25 | CurrentTimeDataArray,
|
| 26 | + DataArray, |
27 | 27 | VoltageFreqDataArray,
|
28 | 28 | VoltageFreqModeDataArray,
|
29 | 29 | VoltageTimeDataArray,
|
@@ -220,13 +220,14 @@ def _check_monitor_data_supported(em_field: MonitorDataTypes):
|
220 | 220 | )
|
221 | 221 |
|
222 | 222 | @staticmethod
|
223 |
| - def _make_result_data_array(result: xr.DataArray) -> IntegralResultTypes: |
| 223 | + def _make_result_data_array(result: DataArray) -> IntegralResultTypes: |
224 | 224 | """Helper for creating the proper result type."""
|
| 225 | + cls = FreqDataArray |
225 | 226 | if "t" in result.coords:
|
226 |
| - return TimeDataArray(data=result.data, coords=result.coords) |
| 227 | + cls = TimeDataArray |
227 | 228 | if "f" in result.coords and "mode_index" in result.coords:
|
228 |
| - return FreqModeDataArray(data=result.data, coords=result.coords) |
229 |
| - return FreqDataArray(data=result.data, coords=result.coords) |
| 229 | + cls = FreqModeDataArray |
| 230 | + return cls.assign_data_attrs(cls(data=result.data, coords=result.coords)) |
230 | 231 |
|
231 | 232 |
|
232 | 233 | class VoltageIntegralAxisAligned(AxisAlignedPathIntegral):
|
@@ -257,13 +258,14 @@ def compute_voltage(self, em_field: MonitorDataTypes) -> VoltageIntegralResultTy
|
257 | 258 | return voltage
|
258 | 259 |
|
259 | 260 | @staticmethod
|
260 |
| - def _make_result_data_array(result: xr.DataArray) -> VoltageIntegralResultTypes: |
| 261 | + def _make_result_data_array(result: DataArray) -> VoltageIntegralResultTypes: |
261 | 262 | """Helper for creating the proper result type."""
|
| 263 | + cls = VoltageFreqDataArray |
262 | 264 | if "t" in result.coords:
|
263 |
| - return VoltageTimeDataArray(data=result.data, coords=result.coords) |
| 265 | + cls = VoltageTimeDataArray |
264 | 266 | if "f" in result.coords and "mode_index" in result.coords:
|
265 |
| - return VoltageFreqModeDataArray(data=result.data, coords=result.coords) |
266 |
| - return VoltageFreqDataArray(data=result.data, coords=result.coords) |
| 267 | + cls = VoltageFreqModeDataArray |
| 268 | + return cls.assign_data_attrs(cls(data=result.data, coords=result.coords)) |
267 | 269 |
|
268 | 270 | @staticmethod
|
269 | 271 | def from_terminal_positions(
|
@@ -424,7 +426,7 @@ def compute_current(self, em_field: MonitorDataTypes) -> CurrentIntegralResultTy
|
424 | 426 |
|
425 | 427 | if self.sign == "-":
|
426 | 428 | current *= -1
|
427 |
| - current = CurrentIntegralAxisAligned._set_data_array_attributes(current) |
| 429 | + current = CurrentIntegralAxisAligned._make_result_data_array(current) |
428 | 430 | return current
|
429 | 431 |
|
430 | 432 | @cached_property
|
@@ -516,13 +518,14 @@ def _to_path_integrals(
|
516 | 518 | return (bottom, right, top, left)
|
517 | 519 |
|
518 | 520 | @staticmethod
|
519 |
| - def _make_result_data_array(result: xr.DataArray) -> CurrentIntegralResultTypes: |
| 521 | + def _make_result_data_array(result: DataArray) -> CurrentIntegralResultTypes: |
520 | 522 | """Helper for creating the proper result type."""
|
| 523 | + cls = CurrentFreqDataArray |
521 | 524 | if "t" in result.coords:
|
522 |
| - return CurrentTimeDataArray(data=result.data, coords=result.coords) |
| 525 | + cls = CurrentTimeDataArray |
523 | 526 | if "f" in result.coords and "mode_index" in result.coords:
|
524 |
| - return CurrentFreqModeDataArray(data=result.data, coords=result.coords) |
525 |
| - return CurrentFreqDataArray(data=result.data, coords=result.coords) |
| 527 | + cls = CurrentFreqModeDataArray |
| 528 | + return cls.assign_data_attrs(cls(data=result.data, coords=result.coords)) |
526 | 529 |
|
527 | 530 | @add_ax_if_none
|
528 | 531 | def plot(
|
|
0 commit comments