diff --git a/src/emsarray/conventions/_base.py b/src/emsarray/conventions/_base.py index e35c38f1..22ef190c 100644 --- a/src/emsarray/conventions/_base.py +++ b/src/emsarray/conventions/_base.py @@ -30,9 +30,9 @@ if TYPE_CHECKING: # Import these optional dependencies only during type checking from cartopy.crs import CRS + from cartopy.mpl.geoaxes import GeoAxes from matplotlib.animation import FuncAnimation - from matplotlib.axes import Axes - from matplotlib.collections import PolyCollection + from matplotlib.collections import Collection, PolyCollection from matplotlib.figure import Figure from matplotlib.quiver import Quiver @@ -935,6 +935,7 @@ def data_crs(self) -> 'CRS': def plot_on_figure( self, figure: 'Figure', + *variables: DataArrayOrName | tuple[DataArrayOrName, ...], scalar: DataArrayOrName | None = None, vector: tuple[DataArrayOrName, DataArrayOrName] | None = None, title: str | None = None, @@ -971,26 +972,30 @@ def plot_on_figure( :func:`.plot.plot_on_figure` : The underlying implementation """ if scalar is not None: - kwargs['scalar'] = utils.name_to_data_array(self.dataset, scalar) + variables = variables + (scalar,) if vector is not None: - kwargs['vector'] = tuple(utils.name_to_data_array(self.dataset, v) for v in vector) + variables = variables + (vector,) + + mapped_variables: list[xarray.DataArray | tuple[xarray.DataArray, ...]] = [] + for variable in variables: + if isinstance(variable, tuple): + mapped_variables.append(tuple( + utils.name_to_data_array(self.dataset, v) + for v in variable + )) + else: + mapped_variables.append(utils.name_to_data_array(self.dataset, variable)) if title is not None: kwargs['title'] = title - elif scalar is not None and vector is None: - # Make a title out of the scalar variable, but only if a title - # hasn't been supplied and we don't also have vectors to plot. - # - # We can't make a good name from vectors, - # as they are in two variables with names like - # 'u component of current' and 'v component of current'. - # - # Users can supply their own titles - # if this automatic behaviour is insufficient - kwargs['title'] = make_plot_title(self.dataset, kwargs['scalar']) - plot_on_figure(figure, self, **kwargs) + # Find a title if there is a single variable passed in + elif len(variables) == 1 and isinstance(variables[0], xarray.DataArray): + variable = variables[0] + kwargs['title'] = make_plot_title(self.dataset, variable) + + plot_on_figure(figure, self, *mapped_variables, **kwargs) @_requires_plot def plot(self, *args: Any, **kwargs: Any) -> None: @@ -1111,6 +1116,73 @@ def animate_on_figure( return animate_on_figure(figure, self, coordinate=coordinate, **kwargs) + @_requires_plot + def plot_on_axes( + self, + axes: 'GeoAxes', + variable: xarray.DataArray | tuple[xarray.DataArray, ...], + **kwargs: Any, + ) -> 'Collection': + if isinstance(variable, xarray.DataArray): + grid_kind = self.get_grid_kind(variable) + if grid_kind is self.default_grid_kind: + return self.plot_polygons_on_axes(axes, variable) + + raise ValueError(f"Not able to plot variable {variable.name!r}, grid kind {grid_kind}") + + else: + names = tuple(v.name for v in variable) + grid_kinds = tuple(self.get_grid_kind(v) for v in variable) + if len(variable) == 2 and grid_kinds == (self.default_grid_kind, self.default_grid_kind): + return self.plot_vector_components_on_axes(axes, variable) + + raise ValueError(f"Not able to plot variables {names!r}, grid kinds {grid_kinds!r}") + + @_requires_plot + def plot_polygons_on_axes( + self, + axes: 'GeoAxes', + data_array: xarray.DataArray | None, + colorbar: bool = True, + **kwargs: Any, + ) -> 'Collection': + defaults = dict(cmap='jet', edgecolor='face') + kwargs = {**defaults, **kwargs} + + collection = self.make_poly_collection(data_array, **kwargs) + axes.add_collection(collection) + + if colorbar and data_array is not None: + units = data_array.attrs.get('units') + axes.get_figure().colorbar(collection, ax=axes, location='right', label=units) + + return collection + + @_requires_plot + def plot_vector_components_on_axes( + self, + axes: 'GeoAxes', + components: tuple[xarray.DataArray, xarray.DataArray], + **kwargs: Any, + ) -> 'Collection': + u, v = components + quiver = self.make_quiver(axes, u, v, **kwargs) + axes.add_collection(quiver) + return quiver + + @_requires_plot + def plot_geometry_on_axes( + self, + axes: 'GeoAxes', + **kwargs: Any, + ) -> 'Collection': + """ + Plot the geometry of this dataset on to some Axes + """ + collection = self.make_poly_collection() + axes.add_collection(collection) + return collection + @_requires_plot @utils.timed_func def make_poly_collection( @@ -1120,7 +1192,7 @@ def make_poly_collection( ) -> 'PolyCollection': """ Make a :class:`~matplotlib.collections.PolyCollection` - from the geometry of this :class:`~xarray.Dataset`. + from the polygon geometry of this :class:`~xarray.Dataset`. This can be used to make custom matplotlib plots from your data. If a :class:`~xarray.DataArray` is passed in, @@ -1190,32 +1262,23 @@ def make_poly_collection( return polygons_to_collection(self.polygons[self.mask], **kwargs) - def make_patch_collection( - self, - data_array: DataArrayOrName | None = None, - **kwargs: Any, - ) -> 'PolyCollection': - warnings.warn( - "Convention.make_patch_collection has been renamed to " - "Convention.make_poly_collection, and now returns a PolyCollection", - category=DeprecationWarning, - ) - return self.make_poly_collection(data_array, **kwargs) - @_requires_plot def make_quiver( self, - axes: 'Axes', + axes: 'GeoAxes', u: DataArrayOrName | None = None, v: DataArrayOrName | None = None, **kwargs: Any, ) -> 'Quiver': """ Make a :class:`matplotlib.quiver.Quiver` instance to plot vector data. + The vectors will be placed at the centre of each polygon of the dataset. + The data arrays `u` and `v` represent the `x` and `y` vector components + for each polygon. Parameters ---------- - axes : matplotlib.axes.Axes + axes : matplotlib.axes.GeoAxes The axes to make this quiver on. u, v : xarray.DataArray or str, optional The DataArrays or the names of DataArrays in this dataset diff --git a/src/emsarray/conventions/ugrid.py b/src/emsarray/conventions/ugrid.py index e3675a09..0306b00b 100644 --- a/src/emsarray/conventions/ugrid.py +++ b/src/emsarray/conventions/ugrid.py @@ -14,7 +14,7 @@ from contextlib import suppress from dataclasses import dataclass from functools import cached_property -from typing import Any, cast +from typing import TYPE_CHECKING, Any, cast import numpy import shapely @@ -25,10 +25,16 @@ from emsarray.exceptions import ( ConventionViolationError, ConventionViolationWarning ) +from emsarray.operations import triangulate +from emsarray.plot import _requires_plot from emsarray.types import Bounds, Pathish from ._base import DimensionConvention, Specificity +if TYPE_CHECKING: + from matplotlib.axes import Axes + from matplotlib.collections import Collection + logger = logging.getLogger(__name__) @@ -1345,3 +1351,50 @@ def drop_geometry(self) -> xarray.Dataset: dataset = super().drop_geometry() dataset.attrs.pop('Conventions', None) return dataset + + @_requires_plot + def plot_on_axes( + self, + axes: 'Axes', + variable: xarray.DataArray | tuple[xarray.DataArray, ...], + **kwargs: Any, + ) -> 'Collection': + if isinstance(variable, xarray.DataArray): + grid_kind = self.get_grid_kind(variable) + if grid_kind is UGridKind.node: + return self.plot_tripcolor_on_axes(axes, variable, **kwargs) + + return super().plot_on_axes(axes, variable, **kwargs) + + def plot_tripcolor_on_axes( + self, + axes: 'Axes', + data_array: xarray.DataArray, + *, + shading: str = 'gouraud', + **kwargs: Any, + ) -> 'Collection': + collection = self.make_tripcolor(axes, data_array, **kwargs) + axes.add_collection(collection) + return collection + + def make_tripcolor( + self, + axes: 'Axes', + data_array: xarray.DataArray | None = None, + *, + shading: str = 'gouraud', + **kwargs: Any, + ) -> 'Collection': + import matplotlib.tri + + topology = self.topology + vertices = numpy.c_[topology.node_x.values, topology.node_y.values].T + + _vertices, triangles, _face_indices = triangulate.triangulate_dataset(self.dataset, vertices=vertices) + triangulation = matplotlib.tri.Triangulation(vertices[:, 0], vertices[:, 1], triangles) + + data_array = utils.name_to_data_array(self.dataset, data_array) + + collection = axes.tripcolor(triangulation, data_array.values, shading='gouraud', **kwargs) + return collection diff --git a/src/emsarray/operations/triangulate.py b/src/emsarray/operations/triangulate.py index b0a2bcc4..aa47b318 100644 --- a/src/emsarray/operations/triangulate.py +++ b/src/emsarray/operations/triangulate.py @@ -1,6 +1,8 @@ """ Operations for making a triangular mesh out of the polygons of a dataset. """ +from typing import Any + import numpy import pandas import shapely @@ -14,6 +16,8 @@ def triangulate_dataset( dataset: xarray.Dataset, + *, + vertices: numpy.ndarray[tuple[Any, ...], numpy.dtype[Any]] | None = None, ) -> tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]: """ Triangulate the polygon cells of a dataset @@ -93,11 +97,14 @@ def triangulate_dataset( """ polygons = dataset.ems.polygons - # Find all the unique coordinates and assign them each a unique index - all_coords = shapely.get_coordinates(polygons) - vertex_index = pandas.MultiIndex.from_arrays(all_coords.T).drop_duplicates() + if vertices is None: + # Find all the unique coordinates and assign them each a unique index + all_coords = shapely.get_coordinates(polygons) + vertex_index = pandas.MultiIndex.from_arrays(all_coords.T).drop_duplicates() + vertices = numpy.array(vertex_index.to_list()) + else: + vertex_index = pandas.MultiIndex.from_arrays(list(vertices)) vertex_series = pandas.Series(numpy.arange(len(vertex_index)), index=vertex_index) - vertex_coords = numpy.array(vertex_index.to_list()) polygon_length = shapely.get_num_coordinates(polygons) @@ -191,7 +198,7 @@ def _add_triangles(face_index: int, vertex_triangles: numpy.ndarray) -> None: faces = joined_df['face_indices'].to_numpy() triangles = joined_df[['v0', 'v1', 'v2']].to_numpy() - return vertex_coords, triangles, faces + return vertices, triangles, faces def _triangulate_polygons_by_length(polygons: numpy.ndarray) -> numpy.ndarray: diff --git a/src/emsarray/plot.py b/src/emsarray/plot.py index 7acf2864..708738c8 100644 --- a/src/emsarray/plot.py +++ b/src/emsarray/plot.py @@ -274,7 +274,8 @@ def make_plot_title( def plot_on_figure( figure: Figure, convention: 'conventions.Convention', - *, + *variables: xarray.DataArray | tuple[xarray.DataArray, ...], + axes: GeoAxes | None = None, scalar: xarray.DataArray | None = None, vector: tuple[xarray.DataArray, xarray.DataArray] | None = None, title: str | None = None, @@ -315,31 +316,30 @@ def plot_on_figure( coast : bool, default True Whether to add coastlines to the plot using :func:`add_coast()`. """ - if projection is None: - projection = cartopy.crs.PlateCarree() - - axes: GeoAxes = figure.add_subplot(projection=projection) - axes.set_aspect(aspect='equal', adjustable='datalim') - - if scalar is None and vector is None: - # Plot the polygon shapes for want of anything else to draw - collection = convention.make_poly_collection() - axes.add_collection(collection) - if title is None: - title = 'Geometry' + # Support old syntax of separately specifying scalar and vector items to plot. + # This will be deprecated soon. if scalar is not None: - # Plot a scalar variable on the polygons using a colour map - collection = convention.make_poly_collection( - scalar, cmap='jet', edgecolor='face') - axes.add_collection(collection) - units = scalar.attrs.get('units') - figure.colorbar(collection, ax=axes, location='right', label=units) - + variables = variables + (scalar,) if vector is not None: - # Plot a vector variable using a quiver - quiver = convention.make_quiver(axes, *vector) - axes.add_collection(quiver) + variables = variables + (vector,) + + # Construct some axes if we don't have any + if axes is None: + if projection is None: + projection = cartopy.crs.PlateCarree() + axes = figure.add_subplot(projection=projection) + axes.set_aspect(aspect='equal', adjustable='datalim') + + # Plot everything we have + for var in variables: + convention.plot_on_axes(axes, var) + + # Plot the geometry of the dataset if there are no variables passed in + if len(variables) == 0: + convention.plot_geometry_on_axes(axes) + if title is None: + title = 'Geometry' if title: axes.set_title(title) @@ -349,17 +349,12 @@ def plot_on_figure( if coast: add_coast(axes) + if gridlines: add_gridlines(axes) axes.autoscale() - # Work around for gridline positioning issues - # https://github.com/SciTools/cartopy/issues/2245#issuecomment-1732313921 - layout_engine = figure.get_layout_engine() - if layout_engine is not None: - layout_engine.execute(figure) - @_requires_plot def animate_on_figure(