Skip to content

Commit 59c2cc6

Browse files
authored
Merge pull request #77 from csiro-coasts/polycollection
Use PolyCollection over PatchCollection in matplotlib plots
2 parents 91a1bd4 + 75dd7df commit 59c2cc6

File tree

5 files changed

+81
-78
lines changed

5 files changed

+81
-78
lines changed

docs/releases/development.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,7 @@ Next release (in development)
44

55
* Fix an issue with negative coordinates in :func:`~emsarray.cli.utils.bounds_argument` (:pr:`74`).
66
* Add a new ``emsarray plot`` subcommand to the ``emsarray`` command line interface (:pr:`76`).
7+
* Use :class:`matplotlib.collections.PolyCollection`
8+
rather than :class:`~matplotlib.collections.PatchCollection`
9+
for significant speed improvements
10+
(:pr:`77`).

src/emsarray/conventions/_base.py

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import dataclasses
55
import enum
66
import logging
7+
import warnings
78
from functools import cached_property
89
from typing import (
910
TYPE_CHECKING, Any, Callable, Dict, FrozenSet, Generic, Hashable, List,
@@ -19,8 +20,7 @@
1920
from emsarray.compat.shapely import SpatialIndex
2021
from emsarray.operations import depth
2122
from emsarray.plot import (
22-
_requires_plot, animate_on_figure, plot_on_figure,
23-
polygons_to_patch_collection
23+
_requires_plot, animate_on_figure, plot_on_figure, polygons_to_collection
2424
)
2525
from emsarray.state import State
2626
from emsarray.types import Pathish
@@ -30,7 +30,7 @@
3030
from cartopy.crs import CRS
3131
from matplotlib.animation import FuncAnimation
3232
from matplotlib.axes import Axes
33-
from matplotlib.collections import PatchCollection
33+
from matplotlib.collections import PolyCollection
3434
from matplotlib.figure import Figure
3535
from matplotlib.quiver import Quiver
3636

@@ -552,7 +552,7 @@ def data_crs(self) -> CRS:
552552
"""
553553
The coordinate reference system that coordinates in this dataset are
554554
defined in.
555-
Used by :meth:`.make_patch_collection` and :meth:`.make_quiver`.
555+
Used by :meth:`.make_poly_collection` and :meth:`.make_quiver`.
556556
Defaults to :class:`cartopy.crs.PlateCarree`.
557557
"""
558558
# Lazily imported here as cartopy is an optional dependency
@@ -746,35 +746,35 @@ def animate_on_figure(
746746
return animate_on_figure(figure, self, coordinate=coordinate, **kwargs)
747747

748748
@_requires_plot
749-
def make_patch_collection(
749+
def make_poly_collection(
750750
self,
751751
data_array: Optional[DataArrayOrName] = None,
752752
**kwargs: Any,
753-
) -> PatchCollection:
753+
) -> PolyCollection:
754754
"""
755-
Make a :class:`~matplotlib.collections.PatchCollection`
755+
Make a :class:`~matplotlib.collections.PolyCollection`
756756
from the geometry of this :class:`~xarray.Dataset`.
757757
This can be used to make custom matplotlib plots from your data.
758758
759759
If a :class:`~xarray.DataArray` is passed in,
760-
the values of that are assigned to the PatchCollection `array` parameter.
760+
the values of that are assigned to the PolyCollection `array` parameter.
761761
762762
Parameters
763763
----------
764764
data_array : Hashable or :class:`xarray.DataArray`, optional
765765
A data array, or the name of a data variable in this dataset. Optional.
766766
If given, the data array is :meth:`linearised <.make_linear>`
767-
and passed to :meth:`PatchCollection.set_array() <matplotlib.cm.ScalarMappable.set_array>`.
767+
and passed to :meth:`PolyCollection.set_array() <matplotlib.cm.ScalarMappable.set_array>`.
768768
The data is used to colour the patches.
769769
Refer to the matplotlib documentation for more information on styling.
770770
**kwargs
771771
Any keyword arguments are passed to the
772-
:class:`~matplotlib.collections.PatchCollection` constructor.
772+
:class:`~matplotlib.collections.PolyCollection` constructor.
773773
774774
Returns
775775
-------
776-
:class:`~matplotlib.collections.PatchCollection`
777-
A PatchCollection constructed using the geometry of this dataset.
776+
:class:`~matplotlib.collections.PolyCollection`
777+
A PolyCollection constructed using the geometry of this dataset.
778778
779779
Example
780780
-------
@@ -791,7 +791,7 @@ def make_patch_collection(
791791
792792
ds = emsarray.open_dataset("./tests/datasets/ugrid_mesh2d.nc")
793793
ds = ds.isel(record=0, Mesh2_layers=-1)
794-
patches = ds.ems.make_patch_collection('temp')
794+
patches = ds.ems.make_poly_collection('temp')
795795
axes.add_collection(patches)
796796
figure.colorbar(patches, ax=axes, location='right', label='meters')
797797
@@ -802,7 +802,7 @@ def make_patch_collection(
802802
if data_array is not None:
803803
if 'array' in kwargs:
804804
raise TypeError(
805-
"Can not pass both `data_array` and `array` to make_patch_collection"
805+
"Can not pass both `data_array` and `array` to make_poly_collection"
806806
)
807807

808808
data_array = self._get_data_array(data_array)
@@ -821,7 +821,19 @@ def make_patch_collection(
821821
if 'transform' not in kwargs:
822822
kwargs['transform'] = self.data_crs
823823

824-
return polygons_to_patch_collection(self.polygons[self.mask], **kwargs)
824+
return polygons_to_collection(self.polygons[self.mask], **kwargs)
825+
826+
def make_patch_collection(
827+
self,
828+
data_array: Optional[DataArrayOrName] = None,
829+
**kwargs: Any,
830+
) -> PolyCollection:
831+
warnings.warn(
832+
"Convention.make_patch_collection has been renamed to "
833+
"Convention.make_poly_collection, and now returns a PolyCollection",
834+
category=DeprecationWarning,
835+
)
836+
return self.make_poly_collection(data_array, **kwargs)
825837

826838
@_requires_plot
827839
def make_quiver(

src/emsarray/plot.py

Lines changed: 35 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717
import cartopy.crs
1818
from cartopy.feature import GSHHSFeature
1919
from cartopy.mpl import gridliner
20-
from matplotlib import animation, patches
20+
from matplotlib import animation
2121
from matplotlib.artist import Artist
2222
from matplotlib.axes import Axes
23-
from matplotlib.collections import PatchCollection
23+
from matplotlib.collections import PolyCollection
2424
from matplotlib.figure import Figure
2525
from shapely.geometry import Polygon
2626
CAN_PLOT = True
@@ -30,7 +30,7 @@
3030
IMPORT_EXCEPTION = exc
3131

3232

33-
__all___ = ['CAN_PLOT', 'plot_on_figure', 'polygon_to_patch']
33+
__all___ = ['CAN_PLOT', 'plot_on_figure', 'polygons_to_collection']
3434

3535

3636
_requires_plot = requires_extra(extra='plot', import_error=IMPORT_EXCEPTION)
@@ -81,7 +81,7 @@ def bounds_to_extent(bounds: Tuple[float, float, float, float]) -> List[float]:
8181
8282
import cartopy.crs as ccrs
8383
import matplotlib.pyplot as plt
84-
from emsarray.plot import bounds_to_extent, polygon_to_patch
84+
from emsarray.plot import bounds_to_extent
8585
from shapely.geometry import Polygon
8686
8787
polygon = Polygon([
@@ -91,44 +91,40 @@ def bounds_to_extent(bounds: Tuple[float, float, float, float]) -> List[float]:
9191
figure = plt.figure(figsize=(10, 8), dpi=100)
9292
axes = plt.subplot(projection=ccrs.PlateCarree())
9393
axes.set_extent(bounds_to_extent(polygon.buffer(0.1).bounds))
94-
axes.add_patch(polygon_to_patch(polygon))
95-
figure.show()
9694
"""
9795
minx, miny, maxx, maxy = bounds
9896
return [minx, maxx, miny, maxy]
9997

10098

10199
@_requires_plot
102-
def polygon_to_patch(polygon: Polygon, **kwargs: Any) -> patches.Polygon:
103-
"""
104-
Convert a :class:`shapely.geometry.Polygon <Polygon>` to a
105-
:class:`matplotlib.patches.Polygon`.
106-
"""
107-
return patches.Polygon(np.transpose(polygon.exterior.xy), **kwargs)
108-
109-
110-
@_requires_plot
111-
def polygons_to_patch_collection(
100+
def polygons_to_collection(
112101
polygons: Iterable[Polygon],
113102
**kwargs: Any,
114-
) -> PatchCollection:
103+
) -> PolyCollection:
115104
"""
116105
Convert a list of Shapely :class:`Polygons <Polygon>`
117-
to a matplotlib :class:`~matplotlib.collections.PatchCollection`.
106+
to a matplotlib :class:`~matplotlib.collections.PolyCollection`.
118107
119108
Parameters
120109
----------
121-
polygons : iterable of `Polygon`
122-
The polygons for the patch collection
110+
polygons : iterable of Shapely :class:`Polygons <Polygon>`
111+
The polygons for the poly collection
123112
**kwargs : Any
124-
Keyword arguments to pass to the PatchCollection constructor.
113+
Keyword arguments to pass to the PolyCollection constructor.
125114
126115
Returns
127116
-------
128-
:class:`matplotlib.collections.PatchCollection`
129-
The PatchCollection made up of the polygons passed in.
117+
:class:`matplotlib.collections.PolyCollection`
118+
A PolyCollection made up of the polygons passed in.
130119
"""
131-
return PatchCollection(map(polygon_to_patch, polygons), **kwargs)
120+
return PolyCollection(
121+
verts=[
122+
np.asarray(polygon.exterior.coords)
123+
for polygon in polygons
124+
],
125+
closed=False,
126+
**kwargs
127+
)
132128

133129

134130
@_requires_plot
@@ -154,7 +150,7 @@ def plot_on_figure(
154150
This is used to build the polygons and vector quivers.
155151
scalar : :class:`xarray.DataArray`, optional
156152
The data to plot as an :class:`xarray.DataArray`.
157-
This will be passed to :meth:`.Convention.make_patch_collection`.
153+
This will be passed to :meth:`.Convention.make_poly_collection`.
158154
vector : tuple of :class:`numpy.ndarray`, optional
159155
The *u* and *v* components of a vector field
160156
as a tuple of :class:`xarray.DataArray`.
@@ -175,18 +171,18 @@ def plot_on_figure(
175171

176172
if scalar is None and vector is None:
177173
# Plot the polygon shapes for want of anything else to draw
178-
patches = convention.make_patch_collection()
179-
axes.add_collection(patches)
174+
collection = convention.make_poly_collection()
175+
axes.add_collection(collection)
180176
if title is None:
181177
title = 'Geometry'
182178

183179
if scalar is not None:
184180
# Plot a scalar variable on the polygons using a colour map
185-
patches = convention.make_patch_collection(
181+
collection = convention.make_poly_collection(
186182
scalar, cmap='jet', edgecolor='face')
187-
axes.add_collection(patches)
183+
axes.add_collection(collection)
188184
units = scalar.attrs.get('units')
189-
figure.colorbar(patches, ax=axes, location='right', label=units)
185+
figure.colorbar(collection, ax=axes, location='right', label=units)
190186

191187
if vector is not None:
192188
# Plot a vector variable using a quiver
@@ -230,7 +226,7 @@ def animate_on_figure(
230226
The coordinate values to vary across frames in the animation.
231227
scalar : :class:`xarray.DataArray`, optional
232228
The data to plot as an :class:`xarray.DataArray`.
233-
This will be passed to :meth:`.Convention.make_patch_collection`.
229+
This will be passed to :meth:`.Convention.make_poly_collection`.
234230
It should have horizontal dimensions appropriate for this convention,
235231
and a dimension matching the ``coordinate`` parameter.
236232
vector : tuple of :class:`numpy.ndarray`, optional
@@ -273,17 +269,17 @@ def animate_on_figure(
273269
axes.set_aspect(aspect='equal', adjustable='datalim')
274270
axes.title.set_animated(True)
275271

276-
patches = None
272+
collection = None
277273
if scalar is not None:
278274
# Plot a scalar variable on the polygons using a colour map
279275
scalar_values = convention.make_linear(scalar).values[:, convention.mask]
280-
patches = convention.make_patch_collection(
276+
collection = convention.make_poly_collection(
281277
cmap='jet', edgecolor='face',
282278
clim=(np.nanmin(scalar_values), np.nanmax(scalar_values)))
283-
axes.add_collection(patches)
284-
patches.set_animated(True)
279+
axes.add_collection(collection)
280+
collection.set_animated(True)
285281
units = scalar.attrs.get('units')
286-
figure.colorbar(patches, ax=axes, location='right', label=units)
282+
figure.colorbar(collection, ax=axes, location='right', label=units)
287283

288284
quiver = None
289285
if vector is not None:
@@ -333,9 +329,9 @@ def animate(index: int) -> Iterable[Artist]:
333329
changes.extend(gridlines.xline_artists)
334330
changes.extend(gridlines.yline_artists)
335331

336-
if patches is not None:
337-
patches.set_array(scalar_values[index])
338-
changes.append(patches)
332+
if collection is not None:
333+
collection.set_array(scalar_values[index])
334+
changes.append(collection)
339335

340336
if quiver is not None:
341337
quiver.set_UVC(vector_u_values[index], vector_v_values[index])

tests/conventions/test_base.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -325,36 +325,36 @@ def test_face_centres():
325325

326326

327327
@pytest.mark.matplotlib
328-
def test_make_patch_collection():
328+
def test_make_poly_collection():
329329
dataset = xr.Dataset({
330330
'temp': (['t', 'z', 'y', 'x'], np.random.standard_normal((5, 5, 10, 20))),
331331
'botz': (['y', 'x'], np.random.standard_normal((10, 20)) - 10),
332332
})
333333
convention = SimpleConvention(dataset)
334334

335-
patches = convention.make_patch_collection(cmap='plasma', edgecolor='black')
335+
patches = convention.make_poly_collection(cmap='plasma', edgecolor='black')
336336
assert len(patches.get_paths()) == len(convention.polygons[convention.mask])
337337
assert patches.get_cmap().name == 'plasma'
338338
# Colours get transformed in to RGBA arrays
339339
np.testing.assert_equal(patches.get_edgecolor(), [[0., 0., 0., 1.0]])
340340

341341

342-
def test_make_patch_collection_data_array():
342+
def test_make_poly_collection_data_array():
343343
dataset = xr.Dataset({
344344
'temp': (['t', 'z', 'y', 'x'], np.random.standard_normal((5, 5, 10, 20))),
345345
'botz': (['y', 'x'], np.random.standard_normal((10, 20)) - 10),
346346
})
347347
convention = SimpleConvention(dataset)
348348

349-
patches = convention.make_patch_collection(data_array='botz')
349+
patches = convention.make_poly_collection(data_array='botz')
350350
assert len(patches.get_paths()) == len(convention.polygons[convention.mask])
351351

352352
values = convention.make_linear(dataset.data_vars['botz'])[convention.mask]
353353
np.testing.assert_equal(patches.get_array(), values)
354354
assert patches.get_clim() == (np.nanmin(values), np.nanmax(values))
355355

356356

357-
def test_make_patch_collection_data_array_and_array():
357+
def test_make_poly_collection_data_array_and_array():
358358
dataset = xr.Dataset({
359359
'temp': (['t', 'z', 'y', 'x'], np.random.standard_normal((5, 5, 10, 20))),
360360
'botz': (['y', 'x'], np.random.standard_normal((10, 20)) - 10),
@@ -365,22 +365,22 @@ def test_make_patch_collection_data_array_and_array():
365365

366366
with pytest.raises(TypeError):
367367
# Passing both array and data_array is a TypeError
368-
convention.make_patch_collection(data_array='botz', array=array)
368+
convention.make_poly_collection(data_array='botz', array=array)
369369

370370

371-
def test_make_patch_collection_data_array_and_clim():
371+
def test_make_poly_collection_data_array_and_clim():
372372
dataset = xr.Dataset({
373373
'temp': (['t', 'z', 'y', 'x'], np.random.standard_normal((5, 5, 10, 20))),
374374
'botz': (['y', 'x'], np.random.standard_normal((10, 20)) - 10),
375375
})
376376
convention = SimpleConvention(dataset)
377377

378378
# You can override the default clim if you want
379-
patches = convention.make_patch_collection(data_array='botz', clim=(-12, -8))
379+
patches = convention.make_poly_collection(data_array='botz', clim=(-12, -8))
380380
assert patches.get_clim() == (-12, -8)
381381

382382

383-
def test_make_patch_collection_data_array_dimensions():
383+
def test_make_poly_collection_data_array_dimensions():
384384
dataset = xr.Dataset({
385385
'temp': (['t', 'z', 'y', 'x'], np.random.standard_normal((5, 5, 10, 20))),
386386
'botz': (['y', 'x'], np.random.standard_normal((10, 20)) - 10),
@@ -389,12 +389,12 @@ def test_make_patch_collection_data_array_dimensions():
389389

390390
with pytest.raises(ValueError):
391391
# temp needs subsetting first, so this should raise an error
392-
convention.make_patch_collection(data_array='temp')
392+
convention.make_poly_collection(data_array='temp')
393393

394394
# One way to avoid this is to isel the data array
395-
convention.make_patch_collection(data_array=dataset.data_vars['temp'].isel(z=0, t=0))
395+
convention.make_poly_collection(data_array=dataset.data_vars['temp'].isel(z=0, t=0))
396396

397397
# Another way to avoid this is to isel the dataset
398398
dataset_0 = dataset.isel(z=0, t=0)
399399
convention = SimpleConvention(dataset_0)
400-
convention.make_patch_collection(data_array='temp')
400+
convention.make_poly_collection(data_array='temp')

0 commit comments

Comments
 (0)