Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 13 additions & 16 deletions cfspopcon/file_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@
else:
from typing_extensions import Self # type:ignore[attr-defined,unused-ignore]

from enum import Enum

import numpy as np
import xarray as xr

from .helpers import convert_named_options
from .shaping_and_selection.point_selection import build_mask_from_dict, find_coords_of_minimum
from .shaping_and_selection.point_selection import find_values_at_nearest_point
from .unit_handling import convert_to_default_units, set_default_units

ignored_keys = [
Expand All @@ -36,15 +38,21 @@ def sanitize_variable(val: xr.DataArray, key: str, coord: bool = False) -> Union
except KeyError:
pass

def get_name(val: Enum | str) -> str:
if isinstance(val, Enum):
return val.name
else:
return val

if val.dtype == object:
try:
if val.size == 1:
if not coord:
val = val.item().name
val = get_name(val.item()) # type:ignore[assignment]
else:
val = xr.DataArray([val.item().name])
val = xr.DataArray([get_name(val.item())])
else:
val = xr.DataArray([v.name for v in val.values], dims=val.dims)
val = xr.DataArray([get_name(v) for v in val.values], dims=val.dims)
except AttributeError:
warnings.warn(f"Cannot handle {key}. Dropping variable.", stacklevel=3)
return "UNHANDLED"
Expand Down Expand Up @@ -144,19 +152,8 @@ def __exit__(self, *args: Any) -> None:

def write_point_to_file(dataset: xr.Dataset, point_key: str, point_params: dict, output_dir: Path) -> None:
"""Write the analysis values at the named points to a json file."""
mask = build_mask_from_dict(dataset, point_params)

if "minimize" not in point_params.keys() and "maximize" not in point_params.keys():
raise ValueError(f"Need to provide either minimize or maximize in point specification. Keys were {point_params.keys()}")

if "minimize" in point_params.keys():
array = dataset[point_params["minimize"]]
else:
array = -dataset[point_params["maximize"]]

point_coords = find_coords_of_minimum(array, keep_dims=point_params.get("keep_dims", []), mask=mask)
point = find_values_at_nearest_point(dataset, point_params)

point = dataset.isel(point_coords)
for key in point.keys():
if key in ignored_keys:
assert isinstance(key, str)
Expand Down
47 changes: 23 additions & 24 deletions cfspopcon/plotting/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
import xarray as xr
from matplotlib.axes import Axes

from ..shaping_and_selection.point_selection import build_mask_from_dict, find_coords_of_minimum
from ..shaping_and_selection.point_selection import build_mask_from_dict, find_values_at_nearest_point
from ..shaping_and_selection.transform_coords import build_transform_function_from_dict
from ..unit_handling import Quantity, Unit, dimensionless_magnitude
from ..unit_handling import Quantity, Unit, default_unit, dimensionless_magnitude, magnitude_in_units
from .coordinate_formatter import CoordinateFormatter


Expand All @@ -24,7 +24,7 @@ def make_plot(
save_name: Optional[str] = None,
ax: Optional[Axes] = None,
output_dir: Path = Path("."),
) -> None:
):
"""Given a dictionary corresponding to a plotting style, build a standard plot from the results of the POPCON."""
if plot_params["type"] == "popcon":
if ax is None:
Expand All @@ -36,8 +36,10 @@ def make_plot(
if save_name is not None:
fig.savefig(output_dir / save_name)

return fig, ax


def make_popcon_plot(dataset: xr.Dataset, title: str, plot_params: dict, points: dict, ax: Axes):
def make_popcon_plot(dataset: xr.Dataset, title: str, plot_params: dict, points_params: dict, ax: Axes):
"""Make a plot."""
from cfspopcon import __version__

Expand Down Expand Up @@ -88,30 +90,27 @@ def make_popcon_plot(dataset: xr.Dataset, title: str, plot_params: dict, points:

legend_elements[subplot_params["label"]] = legend_entry

for key, point_params in points.items():
point_style = plot_params.get("points", dict()).get(key, dict())
for key, point_params in points_params.items():
if key not in plot_params.get("points", dict()):
continue
point_style = plot_params["points"][key]
label = point_style.get("label", key)

mask = build_mask_from_dict(dataset, point_params)

if "minimize" not in point_params.keys() and "maximize" not in point_params.keys():
raise ValueError(f"Need to provide either minimize or maximize in point specification. Keys were {point_params.keys()}")

if "minimize" in point_params.keys():
field = dataset[point_params["minimize"]]
else:
field = -dataset[point_params["maximize"]]
point = find_values_at_nearest_point(dataset, point_params)

transformed_field = transform_func(field.where(mask))
plotting_coords = []
for coord in [coords["x"], coords["y"]]:
dimension_name = coord["dimension"]

point_coords = find_coords_of_minimum(transformed_field, keep_dims=point_params.get("keep_dims", []))
if dimension_name not in point.coords and f"dim_{dimension_name}" in point.coords:
dimension_name = f"dim_{dimension_name}"

point = transformed_field.isel(point_coords)
plotting_coords = []
for dim in [coords["x"]["dimension"], coords["y"]["dimension"]]:
if dim not in point.coords and f"dim_{dim}" in point.coords:
dim = f"dim_{dim}" # noqa: PLW2901
plotting_coords.append(point[dim])
requested_units = coord.get("units", "")
if hasattr(point[dimension_name].pint, "units") and point[dimension_name].pint.units is not None:
plotting_coords.append(magnitude_in_units(point[dimension_name], requested_units))
else:
default_units = Quantity(1.0, default_unit(dimension_name.lstrip("dim_")))
plotting_coords.append(magnitude_in_units(point[dimension_name] * default_units, requested_units))

legend_elements[label] = ax.scatter(
*plotting_coords,
Expand All @@ -123,7 +122,7 @@ def make_popcon_plot(dataset: xr.Dataset, title: str, plot_params: dict, points:
ax.set_title(f"{title} [{__version__}]")
ax.set_xlabel(coords["x"]["label"])
ax.set_ylabel(coords["y"]["label"])
ax.legend(legend_elements.values(), legend_elements.keys())
ax.legend(legend_elements.values(), legend_elements.keys(), loc=plot_params.get("legend_loc", "best"))
plt.tight_layout()

return fig, ax
Expand Down
4 changes: 4 additions & 0 deletions cfspopcon/shaping_and_selection/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
build_mask_from_dict,
find_coords_of_maximum,
find_coords_of_minimum,
find_coords_of_nearest_point,
find_values_at_nearest_point,
)
from .transform_coords import (
build_transform_function_from_dict,
Expand All @@ -21,6 +23,8 @@
"find_coords_of_contour",
"find_coords_of_maximum",
"find_coords_of_minimum",
"find_coords_of_nearest_point",
"find_values_at_nearest_point",
"interpolate_array_onto_new_coords",
"interpolate_onto_line",
"order_dimensions",
Expand Down
124 changes: 123 additions & 1 deletion cfspopcon/shaping_and_selection/point_selection.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,135 @@
"""Routines to find the coordinates of the minimum or maximum value of a field."""

import warnings
from collections.abc import Sequence
from typing import Optional

import numpy as np
import xarray as xr
from xarray.core.coordinates import DataArrayCoordinates

from ..unit_handling import Quantity
from ..unit_handling import Quantity, dimensionless_magnitude, magnitude_in_default_units


def find_values_at_nearest_point(dataset: xr.Dataset, point_params: dict) -> xr.Dataset:
"""Return a dataset at a point point which best fulfills the conditions defined by point params."""
allowed_methods = ["minimize", "maximize", "nearest_to", "interp_to"]

method = [method for method in allowed_methods if method in point_params.keys()]
assert len(method) == 1, f"Must provide one of [{', '.join(allowed_methods)}] for a point. Keys were {list(point_params.keys())}"

if method[0] == "interp_to":
mask = build_mask_from_dict(dataset, point_params)

requested_coords = dict()

for dimension_name, request in point_params["interp_to"].items():
if dimension_name not in dataset.coords and f"dim_{dimension_name}" in dataset.coords:
dimension_name = f"dim_{dimension_name}" # noqa: PLW2901

assert dimension_name in dataset.coords, (
f"Cannot interpolate to {dimension_name} since it is not in the dataset coordinates {dataset.coords}."
)

value = Quantity(float(request["value"]), request.get("units", ""))
requested_coords[dimension_name] = magnitude_in_default_units(value, key=dimension_name.lstrip("dim_"))

with warnings.catch_warnings():
warnings.simplefilter("ignore")
# We dequantify and then requantify the dataset, since some interpolation methods cannot handle units
return ( # type:ignore[no-any-return]
dataset.where(mask)
.pint.dequantify()
.interp(**requested_coords, method=point_params.get("method", "linear"))
.pint.quantify()
)

elif method[0] in ["minimize", "maximize", "nearest_to"]:
return dataset.isel(find_coords_of_nearest_point(dataset, point_params))
else:
raise NotImplementedError(f"{method[0]} not recognized.")


def find_coords_of_nearest_point(dataset: xr.Dataset, point_params: dict) -> DataArrayCoordinates:
"""Find the coordinates of a point which best fulfills the conditions defined by point params.

The point parameters must have a 'minimize', 'maximize' or 'nearest_to' key.

'point_name': {
'maximize': 'Q'
}
will find the point with the highest 'Q' value

'point_name': {
'nearest_to': {
'average_electron_density': {
'value': 20.0,
'units': 'n19'
},
'max_flattop_duration': {
'value': 2.0,
'norm': 1.0,
'units': 's'
},
},
'tolerance': 1e-2,
}
will find the point which minimizes
d = sqrt(
((average_electron_density - (20 * n19)) / (20 * n19))**2
+ ((max_flattop_duration - (2 * s)) / (1 * s))**2
)
If the resulting point has d > 1e-2 (tolerance), an AssertionError will be raised.

'nearest_to' is intended to return the dataset at a given grid point. However, you
can use it to find points fulfilling non-scanned conditions (as in the example above).

A mask can also be provided.
'point_name': {
'maximize': 'Q',
'where': {
'P_auxiliary_launched': {
'min': 0.0,
'max': 25.0,
'units': 'MW',
},
'greenwald_fraction': {
'max': 0.9
}
}
}
will find the maximum value of Q in the region with P_auxiliary_launched between 0 and 25MW, and with a Greenwald
fraction up to 90%.
"""
allowed_methods = ["minimize", "maximize", "nearest_to"]

method = [method for method in allowed_methods if method in point_params.keys()]
assert len(method) == 1, f"Must provide one of [{', '.join(allowed_methods)}] for a point. Keys were {list(point_params.keys())}"

mask = build_mask_from_dict(dataset, point_params)

if method[0] == "minimize":
coords_of_point = find_coords_of_minimum(dataset[point_params["minimize"]], keep_dims=point_params.get("keep_dims", []), mask=mask)
elif method[0] == "maximize":
coords_of_point = find_coords_of_maximum(dataset[point_params["maximize"]], keep_dims=point_params.get("keep_dims", []), mask=mask)
elif method[0] == "nearest_to":
normalized_distance_squared = []
for variable_name, request in point_params["nearest_to"].items():
value = Quantity(float(request["value"]), request.get("units", ""))
normalization = Quantity(float(request.get("norm", request["value"])), request.get("units", ""))
normalized_distance_squared.append(dimensionless_magnitude(((dataset[variable_name] - value) / normalization) ** 2))

euclidean_distance = np.sqrt(xr.concat(xr.broadcast(*normalized_distance_squared), dim="dim_distance").sum(dim="dim_distance")) # type:ignore[call-overload]
coords_of_point = find_coords_of_minimum(euclidean_distance, keep_dims=point_params.get("keep_dims", []), mask=mask)

if "tolerance" in point_params:
assert np.all((tol := euclidean_distance.isel(coords_of_point)) < point_params["tolerance"]), (
f"Normalized distance at nearest point [{tol.values}] is greater than the requested tolerance [{point_params['tolerance']}]"
)
else:
raise NotImplementedError(f"{method[0]} not recognized.")

return coords_of_point


def find_coords_of_minimum(array: xr.DataArray, keep_dims: Sequence[str] = [], mask: Optional[xr.DataArray] = None) -> DataArrayCoordinates:
Expand Down
2 changes: 2 additions & 0 deletions example_cases/SPARC_PRD/plot_popcon.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ type: popcon

figsize: [8, 6]
show_dpi: 150
legend_loc: "upper right"

coords:
x:
Expand All @@ -15,6 +16,7 @@ coords:

fill:
variable: Q
cbar_label: Q
where:
Q:
min: 1.0
Expand Down