-
Notifications
You must be signed in to change notification settings - Fork 63
Add support for EME monitor data in outer_dot #2683
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -438,14 +438,27 @@ def _grid_correction_dict(self): | |
"grid_dual_correction": self.grid_dual_correction, | ||
} | ||
|
||
@property | ||
def _normal_axis(self) -> int: | ||
"""For a 2D monitor data, return the normal axis. | ||
For an EMEModeSolverMonitor, return the propagation axis.""" | ||
# special treatment for EMEModeSolverMonitor | ||
test_field = list(self.field_components.values())[0] | ||
if "eme_cell_index" in test_field.coords: | ||
for axis, dim in enumerate(list("xyz")): | ||
if dim not in test_field.coords: | ||
return axis | ||
elif len(self.monitor.zero_dims) != 1: | ||
raise DataError("Data must be 2D to get tangential dimensions.") | ||
return self.monitor.zero_dims[0] | ||
|
||
@property | ||
def _tangential_dims(self) -> list[str]: | ||
"""For a 2D monitor data, return the names of the tangential dimensions. Raise if cannot | ||
confirm that the associated monitor is 2D.""" | ||
if len(self.monitor.zero_dims) != 1: | ||
raise DataError("Data must be 2D to get tangential dimensions.") | ||
|
||
tangential_dims = ["x", "y", "z"] | ||
tangential_dims.pop(self.monitor.zero_dims[0]) | ||
tangential_dims.pop(self._normal_axis) | ||
|
||
return tangential_dims | ||
|
||
|
@@ -514,7 +527,7 @@ def _diff_area(self) -> DataArray: | |
coords = [bs.copy() for bs in self._plane_grid_centers] | ||
|
||
# Append the first and last boundary | ||
_, plane_inds = self.monitor.pop_axis([0, 1, 2], self.monitor.size.index(0.0)) | ||
_, plane_inds = self.monitor.pop_axis([0, 1, 2], self._normal_axis) | ||
coords[0] = np.array([bounds[0][0], *coords[0].tolist(), bounds[0][-1]]) | ||
coords[1] = np.array([bounds[1][0], *coords[1].tolist(), bounds[1][-1]]) | ||
|
||
|
@@ -551,14 +564,11 @@ def _tangential_corrected(self, fields: dict[str, DataArray]) -> dict[str, DataA | |
poynting, flux, and dot-like methods. The normal coordinate is dropped from the field data. | ||
""" | ||
|
||
if len(self.monitor.zero_dims) != 1: | ||
raise DataError("Data must be 2D to get tangential fields.") | ||
|
||
# Tangential field components | ||
tan_dims = self._tangential_dims | ||
components = [fname + dim for fname in "EH" for dim in tan_dims] | ||
|
||
normal_dim = "xyz"[self.monitor.size.index(0)] | ||
normal_dim = "xyz"[self._normal_axis] | ||
|
||
tan_fields = {} | ||
for component in components: | ||
|
@@ -853,7 +863,6 @@ def outer_dot( | |
-------- | ||
:member:`dot` | ||
""" | ||
|
||
tan_dims = self._tangential_dims | ||
|
||
if not all(a == b for a, b in zip(tan_dims, field_data._tangential_dims)): | ||
|
@@ -905,7 +914,9 @@ def outer_dot( | |
dim={"mode_index_1": [0]}, axis=len(fields_other[key].shape) | ||
) | ||
|
||
d_area = self._diff_area.expand_dims(dim={"f": f}, axis=2).to_numpy() | ||
fields_self, fields_other, d_area = self._align_fields( | ||
fields_self, fields_other, self._diff_area | ||
) | ||
|
||
# function to apply at each pair of mode indices before integrating | ||
def fn(fields_1, fields_2): | ||
|
@@ -922,7 +933,7 @@ def fn(fields_1, fields_2): | |
e_self_x_h_other = e_self_1 * h_other_2 - e_self_2 * h_other_1 | ||
h_self_x_e_other = h_self_1 * e_other_2 - h_self_2 * e_other_1 | ||
|
||
summand = 0.25 * (e_self_x_h_other - h_self_x_e_other) * d_area | ||
summand = 0.25 * (e_self_x_h_other - h_self_x_e_other) | ||
return summand | ||
|
||
result = self._outer_fn_summation( | ||
|
@@ -932,6 +943,7 @@ def fn(fields_1, fields_2): | |
outer_dim_2="mode_index_1", | ||
sum_dims=tan_dims, | ||
fn=fn, | ||
d_area=d_area, | ||
) | ||
|
||
# Remove mode index coordinate if the input did not have it | ||
|
@@ -942,6 +954,49 @@ def fn(fields_1, fields_2): | |
|
||
return result | ||
|
||
@staticmethod | ||
def _align_fields( | ||
fields_1: dict[str, xr.DataArray], fields_2: dict[str, xr.DataArray], d_area: xr.DataArray | ||
) -> tuple[dict[str, xr.DataArray], dict[str, xr.DataArray], xr.DataArray]: | ||
"""Align the fields for dot or outer_dot, inserting any missing dimensions.""" | ||
exclude_dims = ["mode_index_0", "mode_index_1", "x", "y", "z"] | ||
for key, field_1 in fields_1.items(): | ||
field_2 = fields_2[key] | ||
# remove non-coord dims | ||
for dim in field_1.dims: | ||
if dim not in field_1.coords: | ||
field_1 = field_1.isel({dim: 0}) | ||
for dim in field_2.dims: | ||
if dim not in field_2.coords: | ||
field_2 = field_2.isel({dim: 0}) | ||
field_1, field_2 = xr.align( | ||
field_1, field_2, join="inner", copy=False, exclude=exclude_dims | ||
) | ||
field_1, field_2 = xr.broadcast(field_1, field_2, exclude=exclude_dims) | ||
dims1 = [] | ||
mode_index_dim1 = None | ||
mode_index_dim2 = None | ||
for dim in field_1.dims: | ||
if dim == "mode_index": | ||
mode_index_dim1 = "mode_index" | ||
mode_index_dim2 = "mode_index" | ||
elif dim == "mode_index_0": | ||
mode_index_dim1 = "mode_index_0" | ||
mode_index_dim2 = "mode_index_1" | ||
else: | ||
dims1.append(dim) | ||
dims2 = list(dims1) | ||
dims1.append(mode_index_dim1) | ||
dims2.append(mode_index_dim2) | ||
field_1 = field_1.transpose(*dims1) | ||
field_2 = field_2.transpose(*dims2) | ||
yaugenst-flex marked this conversation as resolved.
Show resolved
Hide resolved
|
||
fields_1[key] = field_1 | ||
fields_2[key] = field_2 | ||
Comment on lines
+993
to
+994
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it intentional that this function mutates the input? That seems potentially unsafe.. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see. It's fine with this usage but agreed not good practice unless made explicit in the function signature. |
||
d_area, _ = xr.align( | ||
d_area, list(fields_1.values())[0], join="inner", copy=False, exclude=exclude_dims | ||
) | ||
return fields_1, fields_2, d_area | ||
|
||
@staticmethod | ||
def _outer_fn_summation( | ||
fields_1: dict[str, xr.DataArray], | ||
|
@@ -950,6 +1005,7 @@ def _outer_fn_summation( | |
outer_dim_2: str, | ||
sum_dims: list[str], | ||
fn: Callable, | ||
d_area: Optional[xr.DataArray] = None, | ||
) -> DataArray: | ||
""" | ||
Loop over ``outer_dim_1`` and ``outer_dim_2``, apply ``fn`` to ``fields_1`` and ``fields_2``, and sum over ``sum_dims``. | ||
|
@@ -960,6 +1016,10 @@ def _outer_fn_summation( | |
# first, convert to numpy outside the loop to reduce xarray overhead | ||
fields_1_numpy = {key: val.to_numpy() for key, val in fields_1.items()} | ||
fields_2_numpy = {key: val.to_numpy() for key, val in fields_2.items()} | ||
if d_area is None: | ||
d_area_numpy = 1 | ||
else: | ||
d_area_numpy = d_area.to_numpy() | ||
|
||
# get one of the data arrays to look at for indexing | ||
# assuming all data arrays have the same structure | ||
|
@@ -1002,7 +1062,7 @@ def _outer_fn_summation( | |
fields_1_curr = {key: val[tuple(idx_1)] for key, val in fields_1_numpy.items()} | ||
fields_2_curr = {key: val[tuple(idx_2)] for key, val in fields_2_numpy.items()} | ||
summand_curr = fn(fields_1_curr, fields_2_curr) | ||
data_curr = np.sum(summand_curr, axis=tuple(sum_axes)) | ||
data_curr = np.sum(summand_curr * d_area_numpy, axis=tuple(sum_axes)) | ||
data[tuple(idx_data)] = data_curr | ||
|
||
return DataArray(data, coords=coords) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I left two comments but overall I find this function to be pretty difficult to follow. Maybe we could look at breaking the function down into smaller helpers to clarify the main workflow and better distinguish between the 'dot' and 'outer_dot' cases?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I can try to clean this up, comment it, split it up.
In the end I didn't include this in
dot
but maybe I can take another look at that.