Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 94 additions & 31 deletions src/emsarray/conventions/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@
if TYPE_CHECKING:
# Import these optional dependencies only during type checking
from cartopy.crs import CRS
from cartopy.mpl.geoaxes import GeoAxes
from matplotlib.animation import FuncAnimation
from matplotlib.axes import Axes
from matplotlib.collections import PolyCollection
from matplotlib.collections import Collection, PolyCollection
from matplotlib.figure import Figure
from matplotlib.quiver import Quiver

Expand Down Expand Up @@ -935,6 +935,7 @@ def data_crs(self) -> 'CRS':
def plot_on_figure(
self,
figure: 'Figure',
*variables: DataArrayOrName | tuple[DataArrayOrName, ...],
scalar: DataArrayOrName | None = None,
vector: tuple[DataArrayOrName, DataArrayOrName] | None = None,
title: str | None = None,
Expand Down Expand Up @@ -971,26 +972,30 @@ def plot_on_figure(
:func:`.plot.plot_on_figure` : The underlying implementation
"""
if scalar is not None:
kwargs['scalar'] = utils.name_to_data_array(self.dataset, scalar)
variables = variables + (scalar,)

if vector is not None:
kwargs['vector'] = tuple(utils.name_to_data_array(self.dataset, v) for v in vector)
variables = variables + (vector,)

mapped_variables: list[xarray.DataArray | tuple[xarray.DataArray, ...]] = []
for variable in variables:
if isinstance(variable, tuple):
mapped_variables.append(tuple(
utils.name_to_data_array(self.dataset, v)
for v in variable
))
else:
mapped_variables.append(utils.name_to_data_array(self.dataset, variable))

if title is not None:
kwargs['title'] = title
elif scalar is not None and vector is None:
# Make a title out of the scalar variable, but only if a title
# hasn't been supplied and we don't also have vectors to plot.
#
# We can't make a good name from vectors,
# as they are in two variables with names like
# 'u component of current' and 'v component of current'.
#
# Users can supply their own titles
# if this automatic behaviour is insufficient
kwargs['title'] = make_plot_title(self.dataset, kwargs['scalar'])

plot_on_figure(figure, self, **kwargs)
# Find a title if there is a single variable passed in
elif len(variables) == 1 and isinstance(variables[0], xarray.DataArray):
variable = variables[0]
kwargs['title'] = make_plot_title(self.dataset, variable)

plot_on_figure(figure, self, *mapped_variables, **kwargs)

@_requires_plot
def plot(self, *args: Any, **kwargs: Any) -> None:
Expand Down Expand Up @@ -1111,6 +1116,73 @@ def animate_on_figure(

return animate_on_figure(figure, self, coordinate=coordinate, **kwargs)

@_requires_plot
def plot_on_axes(
self,
axes: 'GeoAxes',
variable: xarray.DataArray | tuple[xarray.DataArray, ...],
**kwargs: Any,
) -> 'Collection':
if isinstance(variable, xarray.DataArray):
grid_kind = self.get_grid_kind(variable)
if grid_kind is self.default_grid_kind:
return self.plot_polygons_on_axes(axes, variable)

raise ValueError(f"Not able to plot variable {variable.name!r}, grid kind {grid_kind}")

else:
names = tuple(v.name for v in variable)
grid_kinds = tuple(self.get_grid_kind(v) for v in variable)
if len(variable) == 2 and grid_kinds == (self.default_grid_kind, self.default_grid_kind):
return self.plot_vector_components_on_axes(axes, variable)

raise ValueError(f"Not able to plot variables {names!r}, grid kinds {grid_kinds!r}")

@_requires_plot
def plot_polygons_on_axes(
self,
axes: 'GeoAxes',
data_array: xarray.DataArray | None,
colorbar: bool = True,
**kwargs: Any,
) -> 'Collection':
defaults = dict(cmap='jet', edgecolor='face')
kwargs = {**defaults, **kwargs}

collection = self.make_poly_collection(data_array, **kwargs)
axes.add_collection(collection)

if colorbar and data_array is not None:
units = data_array.attrs.get('units')
axes.get_figure().colorbar(collection, ax=axes, location='right', label=units)

return collection

@_requires_plot
def plot_vector_components_on_axes(
self,
axes: 'GeoAxes',
components: tuple[xarray.DataArray, xarray.DataArray],
**kwargs: Any,
) -> 'Collection':
u, v = components
quiver = self.make_quiver(axes, u, v, **kwargs)
axes.add_collection(quiver)
return quiver

@_requires_plot
def plot_geometry_on_axes(
self,
axes: 'GeoAxes',
**kwargs: Any,
) -> 'Collection':
"""
Plot the geometry of this dataset on to some Axes
"""
collection = self.make_poly_collection()
axes.add_collection(collection)
return collection

@_requires_plot
@utils.timed_func
def make_poly_collection(
Expand All @@ -1120,7 +1192,7 @@ def make_poly_collection(
) -> 'PolyCollection':
"""
Make a :class:`~matplotlib.collections.PolyCollection`
from the geometry of this :class:`~xarray.Dataset`.
from the polygon geometry of this :class:`~xarray.Dataset`.
This can be used to make custom matplotlib plots from your data.

If a :class:`~xarray.DataArray` is passed in,
Expand Down Expand Up @@ -1190,32 +1262,23 @@ def make_poly_collection(

return polygons_to_collection(self.polygons[self.mask], **kwargs)

def make_patch_collection(
self,
data_array: DataArrayOrName | None = None,
**kwargs: Any,
) -> 'PolyCollection':
warnings.warn(
"Convention.make_patch_collection has been renamed to "
"Convention.make_poly_collection, and now returns a PolyCollection",
category=DeprecationWarning,
)
return self.make_poly_collection(data_array, **kwargs)

@_requires_plot
def make_quiver(
self,
axes: 'Axes',
axes: 'GeoAxes',
u: DataArrayOrName | None = None,
v: DataArrayOrName | None = None,
**kwargs: Any,
) -> 'Quiver':
"""
Make a :class:`matplotlib.quiver.Quiver` instance to plot vector data.
The vectors will be placed at the centre of each polygon of the dataset.
The data arrays `u` and `v` represent the `x` and `y` vector components
for each polygon.

Parameters
----------
axes : matplotlib.axes.Axes
axes : matplotlib.axes.GeoAxes
The axes to make this quiver on.
u, v : xarray.DataArray or str, optional
The DataArrays or the names of DataArrays in this dataset
Expand Down
55 changes: 54 additions & 1 deletion src/emsarray/conventions/ugrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from contextlib import suppress
from dataclasses import dataclass
from functools import cached_property
from typing import Any, cast
from typing import TYPE_CHECKING, Any, cast

import numpy
import shapely
Expand All @@ -25,10 +25,16 @@
from emsarray.exceptions import (
ConventionViolationError, ConventionViolationWarning
)
from emsarray.operations import triangulate
from emsarray.plot import _requires_plot
from emsarray.types import Bounds, Pathish

from ._base import DimensionConvention, Specificity

if TYPE_CHECKING:
from matplotlib.axes import Axes
from matplotlib.collections import Collection

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -1345,3 +1351,50 @@ def drop_geometry(self) -> xarray.Dataset:
dataset = super().drop_geometry()
dataset.attrs.pop('Conventions', None)
return dataset

@_requires_plot
def plot_on_axes(
self,
axes: 'Axes',
variable: xarray.DataArray | tuple[xarray.DataArray, ...],
**kwargs: Any,
) -> 'Collection':
if isinstance(variable, xarray.DataArray):
grid_kind = self.get_grid_kind(variable)
if grid_kind is UGridKind.node:
return self.plot_tripcolor_on_axes(axes, variable, **kwargs)

return super().plot_on_axes(axes, variable, **kwargs)

def plot_tripcolor_on_axes(
self,
axes: 'Axes',
data_array: xarray.DataArray,
*,
shading: str = 'gouraud',
**kwargs: Any,
) -> 'Collection':
collection = self.make_tripcolor(axes, data_array, **kwargs)
axes.add_collection(collection)
return collection

def make_tripcolor(
self,
axes: 'Axes',
data_array: xarray.DataArray | None = None,
*,
shading: str = 'gouraud',
**kwargs: Any,
) -> 'Collection':
import matplotlib.tri

topology = self.topology
vertices = numpy.c_[topology.node_x.values, topology.node_y.values].T

_vertices, triangles, _face_indices = triangulate.triangulate_dataset(self.dataset, vertices=vertices)
triangulation = matplotlib.tri.Triangulation(vertices[:, 0], vertices[:, 1], triangles)

data_array = utils.name_to_data_array(self.dataset, data_array)

collection = axes.tripcolor(triangulation, data_array.values, shading='gouraud', **kwargs)
return collection
17 changes: 12 additions & 5 deletions src/emsarray/operations/triangulate.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""
Operations for making a triangular mesh out of the polygons of a dataset.
"""
from typing import Any

import numpy
import pandas
import shapely
Expand All @@ -14,6 +16,8 @@

def triangulate_dataset(
dataset: xarray.Dataset,
*,
vertices: numpy.ndarray[tuple[Any, ...], numpy.dtype[Any]] | None = None,
) -> tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]:
"""
Triangulate the polygon cells of a dataset
Expand Down Expand Up @@ -93,11 +97,14 @@ def triangulate_dataset(
"""
polygons = dataset.ems.polygons

# Find all the unique coordinates and assign them each a unique index
all_coords = shapely.get_coordinates(polygons)
vertex_index = pandas.MultiIndex.from_arrays(all_coords.T).drop_duplicates()
if vertices is None:
# Find all the unique coordinates and assign them each a unique index
all_coords = shapely.get_coordinates(polygons)
vertex_index = pandas.MultiIndex.from_arrays(all_coords.T).drop_duplicates()
vertices = numpy.array(vertex_index.to_list())
else:
vertex_index = pandas.MultiIndex.from_arrays(list(vertices))
vertex_series = pandas.Series(numpy.arange(len(vertex_index)), index=vertex_index)
vertex_coords = numpy.array(vertex_index.to_list())

polygon_length = shapely.get_num_coordinates(polygons)

Expand Down Expand Up @@ -191,7 +198,7 @@ def _add_triangles(face_index: int, vertex_triangles: numpy.ndarray) -> None:
faces = joined_df['face_indices'].to_numpy()
triangles = joined_df[['v0', 'v1', 'v2']].to_numpy()

return vertex_coords, triangles, faces
return vertices, triangles, faces


def _triangulate_polygons_by_length(polygons: numpy.ndarray) -> numpy.ndarray:
Expand Down
53 changes: 24 additions & 29 deletions src/emsarray/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,8 @@ def make_plot_title(
def plot_on_figure(
figure: Figure,
convention: 'conventions.Convention',
*,
*variables: xarray.DataArray | tuple[xarray.DataArray, ...],
axes: GeoAxes | None = None,
scalar: xarray.DataArray | None = None,
vector: tuple[xarray.DataArray, xarray.DataArray] | None = None,
title: str | None = None,
Expand Down Expand Up @@ -315,31 +316,30 @@ def plot_on_figure(
coast : bool, default True
Whether to add coastlines to the plot using :func:`add_coast()`.
"""
if projection is None:
projection = cartopy.crs.PlateCarree()

axes: GeoAxes = figure.add_subplot(projection=projection)
axes.set_aspect(aspect='equal', adjustable='datalim')

if scalar is None and vector is None:
# Plot the polygon shapes for want of anything else to draw
collection = convention.make_poly_collection()
axes.add_collection(collection)
if title is None:
title = 'Geometry'

# Support old syntax of separately specifying scalar and vector items to plot.
# This will be deprecated soon.
if scalar is not None:
# Plot a scalar variable on the polygons using a colour map
collection = convention.make_poly_collection(
scalar, cmap='jet', edgecolor='face')
axes.add_collection(collection)
units = scalar.attrs.get('units')
figure.colorbar(collection, ax=axes, location='right', label=units)

variables = variables + (scalar,)
if vector is not None:
# Plot a vector variable using a quiver
quiver = convention.make_quiver(axes, *vector)
axes.add_collection(quiver)
variables = variables + (vector,)

# Construct some axes if we don't have any
if axes is None:
if projection is None:
projection = cartopy.crs.PlateCarree()
axes = figure.add_subplot(projection=projection)
axes.set_aspect(aspect='equal', adjustable='datalim')

# Plot everything we have
for var in variables:
convention.plot_on_axes(axes, var)

# Plot the geometry of the dataset if there are no variables passed in
if len(variables) == 0:
convention.plot_geometry_on_axes(axes)
if title is None:
title = 'Geometry'

if title:
axes.set_title(title)
Expand All @@ -349,17 +349,12 @@ def plot_on_figure(

if coast:
add_coast(axes)

if gridlines:
add_gridlines(axes)

axes.autoscale()

# Work around for gridline positioning issues
# https://github.com/SciTools/cartopy/issues/2245#issuecomment-1732313921
layout_engine = figure.get_layout_engine()
if layout_engine is not None:
layout_engine.execute(figure)


@_requires_plot
def animate_on_figure(
Expand Down
Loading