Skip to content

Add support for EME monitor data in outer_dot #2683

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Added
- Add support for computing `outer_dot` directly with monitor data from EME simulations.

## [2.9.0rc2] - 2025-07-17

### Added
Expand Down
72 changes: 68 additions & 4 deletions tests/test_components/test_eme.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import itertools

import numpy as np
import pydantic.v1 as pd
import pytest
Expand Down Expand Up @@ -48,6 +50,7 @@ def make_eme_sim():

# field monitor stores field on FDTD grid
field_monitor = td.EMEFieldMonitor(size=(0, td.inf, td.inf), name="field", colocate=True)
field_monitor2 = td.EMEFieldMonitor(size=(td.inf, td.inf, 0), name="field2", colocate=True)

coeff_monitor = td.EMECoefficientMonitor(
size=monitor_size,
Expand All @@ -74,7 +77,7 @@ def make_eme_sim():
name="modes_out",
)

monitors = [mode_monitor, coeff_monitor, field_monitor, modes_in, modes_out]
monitors = [mode_monitor, coeff_monitor, field_monitor, modes_in, modes_out, field_monitor2]
structures = [waveguide]

sim = td.EMESimulation(
Expand Down Expand Up @@ -605,8 +608,8 @@ def test_eme_simulation():


def _get_eme_scalar_mode_field_data_array(num_sweep=0):
x = np.linspace(-1, 1, 35)
y = np.linspace(-1, 1, 38)
x = np.linspace(-1.5, 1.5, 35)
y = np.linspace(-1.5, 1.5, 38)
z = [3]
f = [td.C_0, 3e14]
mode_index = np.arange(10)
Expand All @@ -632,6 +635,7 @@ def _get_eme_scalar_mode_field_data_array(num_sweep=0):
coords=coords,
)
data[:, :, :, :, 0, :, 1] = np.nan
data = data.drop_vars("z")
if num_sweep == 0:
data = data.drop_vars("sweep_index")
return data
Expand Down Expand Up @@ -671,6 +675,36 @@ def _get_eme_scalar_field_data_array(num_sweep=0):
return data


def _get_eme_scalar_field2_data_array(num_sweep=0):
x = np.linspace(-1.5, 1.5, 35)
y = np.linspace(-1.5, 1.5, 38)
z = [0]
f = [td.C_0, 3e14]
mode_index = np.arange(5)
eme_port_index = [0, 1]
if num_sweep != 0:
sweep_index = np.arange(num_sweep)
else:
sweep_index = [0]
coords = {
"x": x,
"y": y,
"z": z,
"f": f,
"sweep_index": sweep_index,
"eme_port_index": eme_port_index,
"mode_index": mode_index,
}
data = td.EMEScalarFieldDataArray(
(1 + 1j) * np.random.random((len(x), len(y), len(z), 2, len(sweep_index), 2, 5)),
coords=coords,
)
data[:, :, :, :, 0, 0, 0] = np.nan
if num_sweep == 0:
data = data.drop_vars("sweep_index")
return data


def test_eme_scalar_field_data_array():
_ = _get_eme_scalar_field_data_array()

Expand Down Expand Up @@ -822,6 +856,12 @@ def _get_eme_field_dataset(num_sweep=0):
return td.EMEFieldDataset(**fields)


def _get_eme_field2_dataset(num_sweep=0):
field = _get_eme_scalar_field2_data_array(num_sweep=num_sweep)
fields = dict.fromkeys(["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"], field)
return td.EMEFieldDataset(**fields)


def test_eme_dataset():
# test s matrix
_ = _get_eme_smatrix_dataset()
Expand Down Expand Up @@ -877,19 +917,33 @@ def _get_eme_mode_solver_data(num_sweep=0):
if num_sweep == 0:
grid_primal_correction = grid_primal_correction.drop_vars("sweep_index")
grid_dual_correction = grid_dual_correction.drop_vars("sweep_index")
sim = make_eme_sim()
grid_expanded = sim.discretize_monitor(monitor)
return td.EMEModeSolverData(
monitor=monitor,
grid_primal_correction=grid_primal_correction,
grid_dual_correction=grid_dual_correction,
grid_expanded=grid_expanded,
**kwargs,
)


def _get_eme_field_data(num_sweep=0):
sim = make_eme_sim()
dataset = _get_eme_field_dataset(num_sweep=num_sweep)
kwargs = dataset.field_components
monitor = td.EMEFieldMonitor(size=(0, td.inf, td.inf), name="field", colocate=True)
return td.EMEFieldData(monitor=monitor, **kwargs)
grid_expanded = sim.discretize_monitor(monitor)
return td.EMEFieldData(monitor=monitor, **kwargs, grid_expanded=grid_expanded)


def _get_eme_field2_data(num_sweep=0):
sim = make_eme_sim()
dataset = _get_eme_field2_dataset(num_sweep=num_sweep)
kwargs = dataset.field_components
monitor = td.EMEFieldMonitor(size=(td.inf, td.inf, 0), name="field2", colocate=True)
grid_expanded = sim.discretize_monitor(monitor)
return td.EMEFieldData(monitor=monitor, **kwargs, grid_expanded=grid_expanded)


def _get_eme_coeff_data(num_sweep=0):
Expand Down Expand Up @@ -955,12 +1009,14 @@ def test_eme_sim_data():
mode_monitor_data = _get_eme_mode_solver_data()
coeff_monitor_data = _get_eme_coeff_data()
field_monitor_data = _get_eme_field_data()
field2_monitor_data = _get_eme_field2_data()
modes_in_data = _get_mode_solver_data(modes_out=False, num_modes=3)
modes_out_data = _get_mode_solver_data(modes_out=True, num_modes=2)
data = [
mode_monitor_data,
coeff_monitor_data,
field_monitor_data,
field2_monitor_data,
modes_in_data,
modes_out_data,
]
Expand Down Expand Up @@ -1221,6 +1277,14 @@ def test_eme_sim_data():
field_in_basis = sim_data.field_in_basis(field=sim_data["field"], modes=modes_in0, port_index=1)
assert "mode_index" not in field_in_basis.Ex.coords

# test dot and outer dot with EME field and mode data
eme_field_data = sim_data["field2"]
mode_data = sim_data.port_modes_list_sweep[0][0]
eme_mode_data = sim_data["modes"]
datas = [eme_field_data, mode_data, eme_mode_data]
for data1, data2 in itertools.product(datas, datas):
_ = data1.outer_dot(data2)


def test_eme_sim_subsection():
eme_sim = td.EMESimulation(
Expand Down
84 changes: 72 additions & 12 deletions tidy3d/components/data/monitor_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,14 +438,27 @@ def _grid_correction_dict(self):
"grid_dual_correction": self.grid_dual_correction,
}

@property
def _normal_axis(self) -> int:
"""For a 2D monitor data, return the normal axis.
For an EMEModeSolverMonitor, return the propagation axis."""
# special treatment for EMEModeSolverMonitor
test_field = list(self.field_components.values())[0]
if "eme_cell_index" in test_field.coords:
for axis, dim in enumerate(list("xyz")):
if dim not in test_field.coords:
return axis
elif len(self.monitor.zero_dims) != 1:
raise DataError("Data must be 2D to get tangential dimensions.")
return self.monitor.zero_dims[0]

@property
def _tangential_dims(self) -> list[str]:
"""For a 2D monitor data, return the names of the tangential dimensions. Raise if cannot
confirm that the associated monitor is 2D."""
if len(self.monitor.zero_dims) != 1:
raise DataError("Data must be 2D to get tangential dimensions.")

tangential_dims = ["x", "y", "z"]
tangential_dims.pop(self.monitor.zero_dims[0])
tangential_dims.pop(self._normal_axis)

return tangential_dims

Expand Down Expand Up @@ -514,7 +527,7 @@ def _diff_area(self) -> DataArray:
coords = [bs.copy() for bs in self._plane_grid_centers]

# Append the first and last boundary
_, plane_inds = self.monitor.pop_axis([0, 1, 2], self.monitor.size.index(0.0))
_, plane_inds = self.monitor.pop_axis([0, 1, 2], self._normal_axis)
coords[0] = np.array([bounds[0][0], *coords[0].tolist(), bounds[0][-1]])
coords[1] = np.array([bounds[1][0], *coords[1].tolist(), bounds[1][-1]])

Expand Down Expand Up @@ -551,14 +564,11 @@ def _tangential_corrected(self, fields: dict[str, DataArray]) -> dict[str, DataA
poynting, flux, and dot-like methods. The normal coordinate is dropped from the field data.
"""

if len(self.monitor.zero_dims) != 1:
raise DataError("Data must be 2D to get tangential fields.")

# Tangential field components
tan_dims = self._tangential_dims
components = [fname + dim for fname in "EH" for dim in tan_dims]

normal_dim = "xyz"[self.monitor.size.index(0)]
normal_dim = "xyz"[self._normal_axis]

tan_fields = {}
for component in components:
Expand Down Expand Up @@ -853,7 +863,6 @@ def outer_dot(
--------
:member:`dot`
"""

tan_dims = self._tangential_dims

if not all(a == b for a, b in zip(tan_dims, field_data._tangential_dims)):
Expand Down Expand Up @@ -905,7 +914,9 @@ def outer_dot(
dim={"mode_index_1": [0]}, axis=len(fields_other[key].shape)
)

d_area = self._diff_area.expand_dims(dim={"f": f}, axis=2).to_numpy()
fields_self, fields_other, d_area = self._align_fields(
fields_self, fields_other, self._diff_area
)

# function to apply at each pair of mode indices before integrating
def fn(fields_1, fields_2):
Expand All @@ -922,7 +933,7 @@ def fn(fields_1, fields_2):
e_self_x_h_other = e_self_1 * h_other_2 - e_self_2 * h_other_1
h_self_x_e_other = h_self_1 * e_other_2 - h_self_2 * e_other_1

summand = 0.25 * (e_self_x_h_other - h_self_x_e_other) * d_area
summand = 0.25 * (e_self_x_h_other - h_self_x_e_other)
return summand

result = self._outer_fn_summation(
Expand All @@ -932,6 +943,7 @@ def fn(fields_1, fields_2):
outer_dim_2="mode_index_1",
sum_dims=tan_dims,
fn=fn,
d_area=d_area,
)

# Remove mode index coordinate if the input did not have it
Expand All @@ -942,6 +954,49 @@ def fn(fields_1, fields_2):

return result

@staticmethod
def _align_fields(
fields_1: dict[str, xr.DataArray], fields_2: dict[str, xr.DataArray], d_area: xr.DataArray
) -> tuple[dict[str, xr.DataArray], dict[str, xr.DataArray], xr.DataArray]:
Comment on lines +958 to +960
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I left two comments but overall I find this function to be pretty difficult to follow. Maybe we could look at breaking the function down into smaller helpers to clarify the main workflow and better distinguish between the 'dot' and 'outer_dot' cases?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can try to clean this up, comment it, split it up.

In the end I didn't include this in dot but maybe I can take another look at that.

"""Align the fields for dot or outer_dot, inserting any missing dimensions."""
exclude_dims = ["mode_index_0", "mode_index_1", "x", "y", "z"]
for key, field_1 in fields_1.items():
field_2 = fields_2[key]
# remove non-coord dims
for dim in field_1.dims:
if dim not in field_1.coords:
field_1 = field_1.isel({dim: 0})
for dim in field_2.dims:
if dim not in field_2.coords:
field_2 = field_2.isel({dim: 0})
field_1, field_2 = xr.align(
field_1, field_2, join="inner", copy=False, exclude=exclude_dims
)
field_1, field_2 = xr.broadcast(field_1, field_2, exclude=exclude_dims)
dims1 = []
mode_index_dim1 = None
mode_index_dim2 = None
for dim in field_1.dims:
if dim == "mode_index":
mode_index_dim1 = "mode_index"
mode_index_dim2 = "mode_index"
elif dim == "mode_index_0":
mode_index_dim1 = "mode_index_0"
mode_index_dim2 = "mode_index_1"
else:
dims1.append(dim)
dims2 = list(dims1)
dims1.append(mode_index_dim1)
dims2.append(mode_index_dim2)
field_1 = field_1.transpose(*dims1)
field_2 = field_2.transpose(*dims2)
fields_1[key] = field_1
fields_2[key] = field_2
Comment on lines +993 to +994
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it intentional that this function mutates the input? That seems potentially unsafe..

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. It's fine with this usage but agreed not good practice unless made explicit in the function signature.

d_area, _ = xr.align(
d_area, list(fields_1.values())[0], join="inner", copy=False, exclude=exclude_dims
)
return fields_1, fields_2, d_area

@staticmethod
def _outer_fn_summation(
fields_1: dict[str, xr.DataArray],
Expand All @@ -950,6 +1005,7 @@ def _outer_fn_summation(
outer_dim_2: str,
sum_dims: list[str],
fn: Callable,
d_area: Optional[xr.DataArray] = None,
) -> DataArray:
"""
Loop over ``outer_dim_1`` and ``outer_dim_2``, apply ``fn`` to ``fields_1`` and ``fields_2``, and sum over ``sum_dims``.
Expand All @@ -960,6 +1016,10 @@ def _outer_fn_summation(
# first, convert to numpy outside the loop to reduce xarray overhead
fields_1_numpy = {key: val.to_numpy() for key, val in fields_1.items()}
fields_2_numpy = {key: val.to_numpy() for key, val in fields_2.items()}
if d_area is None:
d_area_numpy = 1
else:
d_area_numpy = d_area.to_numpy()

# get one of the data arrays to look at for indexing
# assuming all data arrays have the same structure
Expand Down Expand Up @@ -1002,7 +1062,7 @@ def _outer_fn_summation(
fields_1_curr = {key: val[tuple(idx_1)] for key, val in fields_1_numpy.items()}
fields_2_curr = {key: val[tuple(idx_2)] for key, val in fields_2_numpy.items()}
summand_curr = fn(fields_1_curr, fields_2_curr)
data_curr = np.sum(summand_curr, axis=tuple(sum_axes))
data_curr = np.sum(summand_curr * d_area_numpy, axis=tuple(sum_axes))
data[tuple(idx_data)] = data_curr

return DataArray(data, coords=coords)
Expand Down