1111import shapely
1212import xarray
1313from cartopy import crs
14- from matplotlib import animation , cm , pyplot
14+ from matplotlib import animation , colormaps , pyplot
1515from matplotlib .artist import Artist
1616from matplotlib .axes import Axes
1717from matplotlib .collections import PolyCollection
2020from matplotlib .ticker import EngFormatter , Formatter
2121
2222from emsarray .conventions import Convention , Index
23- from emsarray .plot import _requires_plot
23+ from emsarray .plot import _requires_plot , make_plot_title
2424from emsarray .types import Landmark
2525from emsarray .utils import move_dimensions_to_end
2626
@@ -62,7 +62,7 @@ def plot(
6262 figure = pyplot .figure (layout = "constrained" , figsize = figsize )
6363 transect = Transect (dataset , line )
6464 transect .plot_on_figure (figure , data_array , ** kwargs )
65- figure .show ()
65+ pyplot .show ()
6666 return figure
6767
6868
@@ -493,6 +493,10 @@ def prepare_data_array_for_transect(self, data_array: xarray.DataArray) -> xarra
493493 The input data array transformed to have the correct shape
494494 for plotting on the transect.
495495 """
496+ # Some of the following operations drop attrs,
497+ # so keep a reference to the original ones
498+ attrs = data_array .attrs
499+
496500 data_array = self .convention .ravel (data_array )
497501
498502 depth_dimension = self .transect_dataset .coords ['depth' ].dims [0 ]
@@ -502,6 +506,9 @@ def prepare_data_array_for_transect(self, data_array: xarray.DataArray) -> xarra
502506 linear_indices = self .transect_dataset ['linear_index' ].values
503507 data_array = data_array .isel ({index_dimension : linear_indices })
504508
509+ # Restore attrs after reformatting
510+ data_array .attrs .update (attrs )
511+
505512 return data_array
506513
507514 def _find_depth_bounds (self , data_array : xarray .DataArray ) -> Tuple [int , int ]:
@@ -749,11 +756,11 @@ def _plot_on_figure(
749756 axes .set_ylim (depth_limit_deep , depth_limit_shallow )
750757
751758 if title is None :
752- title = data_array . attrs . get ( 'long_name' )
759+ title = make_plot_title ( self . dataset , data_array )
753760 if title is not None :
754761 axes .set_title (title )
755762
756- cmap = cm . get_cmap ( cmap ) .copy ()
763+ cmap = colormaps [ cmap ] .copy ()
757764 cmap .set_bad (ocean_floor_colour )
758765 collection = self .make_poly_collection (
759766 cmap = cmap , clim = (numpy .nanmin (data_array ), numpy .nanmax (data_array )))
0 commit comments