diff --git a/CHANGELOG.md b/CHANGELOG.md index 997a2ce476..7fc2ff04c5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added +- Add support for computing `outer_dot` directly with monitor data from EME simulations. + ## [2.9.0rc2] - 2025-07-17 ### Added diff --git a/tests/test_components/test_eme.py b/tests/test_components/test_eme.py index bf30e940f1..73ed1eeefa 100644 --- a/tests/test_components/test_eme.py +++ b/tests/test_components/test_eme.py @@ -1,5 +1,7 @@ from __future__ import annotations +import itertools + import numpy as np import pydantic.v1 as pd import pytest @@ -48,6 +50,7 @@ def make_eme_sim(): # field monitor stores field on FDTD grid field_monitor = td.EMEFieldMonitor(size=(0, td.inf, td.inf), name="field", colocate=True) + field_monitor2 = td.EMEFieldMonitor(size=(td.inf, td.inf, 0), name="field2", colocate=True) coeff_monitor = td.EMECoefficientMonitor( size=monitor_size, @@ -74,7 +77,7 @@ def make_eme_sim(): name="modes_out", ) - monitors = [mode_monitor, coeff_monitor, field_monitor, modes_in, modes_out] + monitors = [mode_monitor, coeff_monitor, field_monitor, modes_in, modes_out, field_monitor2] structures = [waveguide] sim = td.EMESimulation( @@ -605,8 +608,8 @@ def test_eme_simulation(): def _get_eme_scalar_mode_field_data_array(num_sweep=0): - x = np.linspace(-1, 1, 35) - y = np.linspace(-1, 1, 38) + x = np.linspace(-1.5, 1.5, 35) + y = np.linspace(-1.5, 1.5, 38) z = [3] f = [td.C_0, 3e14] mode_index = np.arange(10) @@ -632,6 +635,7 @@ def _get_eme_scalar_mode_field_data_array(num_sweep=0): coords=coords, ) data[:, :, :, :, 0, :, 1] = np.nan + data = data.drop_vars("z") if num_sweep == 0: data = data.drop_vars("sweep_index") return data @@ -671,6 +675,36 @@ def _get_eme_scalar_field_data_array(num_sweep=0): return data +def _get_eme_scalar_field2_data_array(num_sweep=0): + x = np.linspace(-1.5, 1.5, 35) + y = np.linspace(-1.5, 1.5, 38) + z = [0] + f = [td.C_0, 3e14] + mode_index = np.arange(5) + eme_port_index = [0, 1] + if num_sweep != 0: + sweep_index = np.arange(num_sweep) + else: + sweep_index = [0] + coords = { + "x": x, + "y": y, + "z": z, + "f": f, + "sweep_index": sweep_index, + "eme_port_index": eme_port_index, + "mode_index": mode_index, + } + data = td.EMEScalarFieldDataArray( + (1 + 1j) * np.random.random((len(x), len(y), len(z), 2, len(sweep_index), 2, 5)), + coords=coords, + ) + data[:, :, :, :, 0, 0, 0] = np.nan + if num_sweep == 0: + data = data.drop_vars("sweep_index") + return data + + def test_eme_scalar_field_data_array(): _ = _get_eme_scalar_field_data_array() @@ -822,6 +856,12 @@ def _get_eme_field_dataset(num_sweep=0): return td.EMEFieldDataset(**fields) +def _get_eme_field2_dataset(num_sweep=0): + field = _get_eme_scalar_field2_data_array(num_sweep=num_sweep) + fields = dict.fromkeys(["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"], field) + return td.EMEFieldDataset(**fields) + + def test_eme_dataset(): # test s matrix _ = _get_eme_smatrix_dataset() @@ -877,19 +917,33 @@ def _get_eme_mode_solver_data(num_sweep=0): if num_sweep == 0: grid_primal_correction = grid_primal_correction.drop_vars("sweep_index") grid_dual_correction = grid_dual_correction.drop_vars("sweep_index") + sim = make_eme_sim() + grid_expanded = sim.discretize_monitor(monitor) return td.EMEModeSolverData( monitor=monitor, grid_primal_correction=grid_primal_correction, grid_dual_correction=grid_dual_correction, + grid_expanded=grid_expanded, **kwargs, ) def _get_eme_field_data(num_sweep=0): + sim = make_eme_sim() dataset = _get_eme_field_dataset(num_sweep=num_sweep) kwargs = dataset.field_components monitor = td.EMEFieldMonitor(size=(0, td.inf, td.inf), name="field", colocate=True) - return td.EMEFieldData(monitor=monitor, **kwargs) + grid_expanded = sim.discretize_monitor(monitor) + return td.EMEFieldData(monitor=monitor, **kwargs, grid_expanded=grid_expanded) + + +def _get_eme_field2_data(num_sweep=0): + sim = make_eme_sim() + dataset = _get_eme_field2_dataset(num_sweep=num_sweep) + kwargs = dataset.field_components + monitor = td.EMEFieldMonitor(size=(td.inf, td.inf, 0), name="field2", colocate=True) + grid_expanded = sim.discretize_monitor(monitor) + return td.EMEFieldData(monitor=monitor, **kwargs, grid_expanded=grid_expanded) def _get_eme_coeff_data(num_sweep=0): @@ -955,12 +1009,14 @@ def test_eme_sim_data(): mode_monitor_data = _get_eme_mode_solver_data() coeff_monitor_data = _get_eme_coeff_data() field_monitor_data = _get_eme_field_data() + field2_monitor_data = _get_eme_field2_data() modes_in_data = _get_mode_solver_data(modes_out=False, num_modes=3) modes_out_data = _get_mode_solver_data(modes_out=True, num_modes=2) data = [ mode_monitor_data, coeff_monitor_data, field_monitor_data, + field2_monitor_data, modes_in_data, modes_out_data, ] @@ -1221,6 +1277,14 @@ def test_eme_sim_data(): field_in_basis = sim_data.field_in_basis(field=sim_data["field"], modes=modes_in0, port_index=1) assert "mode_index" not in field_in_basis.Ex.coords + # test dot and outer dot with EME field and mode data + eme_field_data = sim_data["field2"] + mode_data = sim_data.port_modes_list_sweep[0][0] + eme_mode_data = sim_data["modes"] + datas = [eme_field_data, mode_data, eme_mode_data] + for data1, data2 in itertools.product(datas, datas): + _ = data1.outer_dot(data2) + def test_eme_sim_subsection(): eme_sim = td.EMESimulation( diff --git a/tidy3d/components/data/monitor_data.py b/tidy3d/components/data/monitor_data.py index 0557f0ec4b..39cd014e4f 100644 --- a/tidy3d/components/data/monitor_data.py +++ b/tidy3d/components/data/monitor_data.py @@ -438,14 +438,27 @@ def _grid_correction_dict(self): "grid_dual_correction": self.grid_dual_correction, } + @property + def _normal_axis(self) -> int: + """For a 2D monitor data, return the normal axis. + For an EMEModeSolverMonitor, return the propagation axis.""" + # special treatment for EMEModeSolverMonitor + test_field = list(self.field_components.values())[0] + if "eme_cell_index" in test_field.coords: + for axis, dim in enumerate(list("xyz")): + if dim not in test_field.coords: + return axis + elif len(self.monitor.zero_dims) != 1: + raise DataError("Data must be 2D to get tangential dimensions.") + return self.monitor.zero_dims[0] + @property def _tangential_dims(self) -> list[str]: """For a 2D monitor data, return the names of the tangential dimensions. Raise if cannot confirm that the associated monitor is 2D.""" - if len(self.monitor.zero_dims) != 1: - raise DataError("Data must be 2D to get tangential dimensions.") + tangential_dims = ["x", "y", "z"] - tangential_dims.pop(self.monitor.zero_dims[0]) + tangential_dims.pop(self._normal_axis) return tangential_dims @@ -514,7 +527,7 @@ def _diff_area(self) -> DataArray: coords = [bs.copy() for bs in self._plane_grid_centers] # Append the first and last boundary - _, plane_inds = self.monitor.pop_axis([0, 1, 2], self.monitor.size.index(0.0)) + _, plane_inds = self.monitor.pop_axis([0, 1, 2], self._normal_axis) coords[0] = np.array([bounds[0][0], *coords[0].tolist(), bounds[0][-1]]) coords[1] = np.array([bounds[1][0], *coords[1].tolist(), bounds[1][-1]]) @@ -551,14 +564,11 @@ def _tangential_corrected(self, fields: dict[str, DataArray]) -> dict[str, DataA poynting, flux, and dot-like methods. The normal coordinate is dropped from the field data. """ - if len(self.monitor.zero_dims) != 1: - raise DataError("Data must be 2D to get tangential fields.") - # Tangential field components tan_dims = self._tangential_dims components = [fname + dim for fname in "EH" for dim in tan_dims] - normal_dim = "xyz"[self.monitor.size.index(0)] + normal_dim = "xyz"[self._normal_axis] tan_fields = {} for component in components: @@ -853,7 +863,6 @@ def outer_dot( -------- :member:`dot` """ - tan_dims = self._tangential_dims if not all(a == b for a, b in zip(tan_dims, field_data._tangential_dims)): @@ -905,7 +914,9 @@ def outer_dot( dim={"mode_index_1": [0]}, axis=len(fields_other[key].shape) ) - d_area = self._diff_area.expand_dims(dim={"f": f}, axis=2).to_numpy() + fields_self, fields_other, d_area = self._align_fields( + fields_self, fields_other, self._diff_area + ) # function to apply at each pair of mode indices before integrating def fn(fields_1, fields_2): @@ -922,7 +933,7 @@ def fn(fields_1, fields_2): e_self_x_h_other = e_self_1 * h_other_2 - e_self_2 * h_other_1 h_self_x_e_other = h_self_1 * e_other_2 - h_self_2 * e_other_1 - summand = 0.25 * (e_self_x_h_other - h_self_x_e_other) * d_area + summand = 0.25 * (e_self_x_h_other - h_self_x_e_other) return summand result = self._outer_fn_summation( @@ -932,6 +943,7 @@ def fn(fields_1, fields_2): outer_dim_2="mode_index_1", sum_dims=tan_dims, fn=fn, + d_area=d_area, ) # Remove mode index coordinate if the input did not have it @@ -942,6 +954,49 @@ def fn(fields_1, fields_2): return result + @staticmethod + def _align_fields( + fields_1: dict[str, xr.DataArray], fields_2: dict[str, xr.DataArray], d_area: xr.DataArray + ) -> tuple[dict[str, xr.DataArray], dict[str, xr.DataArray], xr.DataArray]: + """Align the fields for dot or outer_dot, inserting any missing dimensions.""" + exclude_dims = ["mode_index_0", "mode_index_1", "x", "y", "z"] + for key, field_1 in fields_1.items(): + field_2 = fields_2[key] + # remove non-coord dims + for dim in field_1.dims: + if dim not in field_1.coords: + field_1 = field_1.isel({dim: 0}) + for dim in field_2.dims: + if dim not in field_2.coords: + field_2 = field_2.isel({dim: 0}) + field_1, field_2 = xr.align( + field_1, field_2, join="inner", copy=False, exclude=exclude_dims + ) + field_1, field_2 = xr.broadcast(field_1, field_2, exclude=exclude_dims) + dims1 = [] + mode_index_dim1 = None + mode_index_dim2 = None + for dim in field_1.dims: + if dim == "mode_index": + mode_index_dim1 = "mode_index" + mode_index_dim2 = "mode_index" + elif dim == "mode_index_0": + mode_index_dim1 = "mode_index_0" + mode_index_dim2 = "mode_index_1" + else: + dims1.append(dim) + dims2 = list(dims1) + dims1.append(mode_index_dim1) + dims2.append(mode_index_dim2) + field_1 = field_1.transpose(*dims1) + field_2 = field_2.transpose(*dims2) + fields_1[key] = field_1 + fields_2[key] = field_2 + d_area, _ = xr.align( + d_area, list(fields_1.values())[0], join="inner", copy=False, exclude=exclude_dims + ) + return fields_1, fields_2, d_area + @staticmethod def _outer_fn_summation( fields_1: dict[str, xr.DataArray], @@ -950,6 +1005,7 @@ def _outer_fn_summation( outer_dim_2: str, sum_dims: list[str], fn: Callable, + d_area: Optional[xr.DataArray] = None, ) -> DataArray: """ Loop over ``outer_dim_1`` and ``outer_dim_2``, apply ``fn`` to ``fields_1`` and ``fields_2``, and sum over ``sum_dims``. @@ -960,6 +1016,10 @@ def _outer_fn_summation( # first, convert to numpy outside the loop to reduce xarray overhead fields_1_numpy = {key: val.to_numpy() for key, val in fields_1.items()} fields_2_numpy = {key: val.to_numpy() for key, val in fields_2.items()} + if d_area is None: + d_area_numpy = 1 + else: + d_area_numpy = d_area.to_numpy() # get one of the data arrays to look at for indexing # assuming all data arrays have the same structure @@ -1002,7 +1062,7 @@ def _outer_fn_summation( fields_1_curr = {key: val[tuple(idx_1)] for key, val in fields_1_numpy.items()} fields_2_curr = {key: val[tuple(idx_2)] for key, val in fields_2_numpy.items()} summand_curr = fn(fields_1_curr, fields_2_curr) - data_curr = np.sum(summand_curr, axis=tuple(sum_axes)) + data_curr = np.sum(summand_curr * d_area_numpy, axis=tuple(sum_axes)) data[tuple(idx_data)] = data_curr return DataArray(data, coords=coords)