Skip to content
Draft
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion xarray/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,9 @@ def copy(
HueStyleOptions = Literal["continuous", "discrete"] | None
AspectOptions = Union[Literal["auto", "equal"], float, None]
ExtendOptions = Literal["neither", "both", "min", "max"] | None

NormOptions = Literal[
"asinh", "function", "functionlog", "linear", "log", "logit", "symlog"
]

_T_co = TypeVar("_T_co", covariant=True)

Expand Down
41 changes: 21 additions & 20 deletions xarray/plot/dataarray_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
AspectOptions,
ExtendOptions,
HueStyleOptions,
NormOptions,
ScaleOptions,
T_DataArray,
)
Expand Down Expand Up @@ -866,7 +867,7 @@ def newplotfunc(
cmap: str | Colormap | None = None,
vmin: float | None = None,
vmax: float | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
extend: ExtendOptions = None,
levels: ArrayLike | None = None,
**kwargs,
Expand Down Expand Up @@ -1142,7 +1143,7 @@ def scatter( # type: ignore[misc,unused-ignore] # None is hashable :(
cmap: str | Colormap | None = None,
vmin: float | None = None,
vmax: float | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
extend: ExtendOptions = None,
levels: ArrayLike | None = None,
**kwargs,
Expand Down Expand Up @@ -1183,7 +1184,7 @@ def scatter(
cmap: str | Colormap | None = None,
vmin: float | None = None,
vmax: float | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
extend: ExtendOptions = None,
levels: ArrayLike | None = None,
**kwargs,
Expand Down Expand Up @@ -1224,7 +1225,7 @@ def scatter(
cmap: str | Colormap | None = None,
vmin: float | None = None,
vmax: float | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
extend: ExtendOptions = None,
levels: ArrayLike | None = None,
**kwargs,
Expand Down Expand Up @@ -1438,7 +1439,7 @@ def newplotfunc(
yticks: ArrayLike | None = None,
xlim: tuple[float, float] | None = None,
ylim: tuple[float, float] | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
**kwargs: Any,
) -> Any:
# All 2d plots in xarray share this function signature.
Expand Down Expand Up @@ -1692,7 +1693,7 @@ def imshow( # type: ignore[misc,unused-ignore] # None is hashable :(
yticks: ArrayLike | None = None,
xlim: ArrayLike | None = None,
ylim: ArrayLike | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
**kwargs: Any,
) -> AxesImage: ...

Expand Down Expand Up @@ -1732,7 +1733,7 @@ def imshow(
yticks: ArrayLike | None = None,
xlim: ArrayLike | None = None,
ylim: ArrayLike | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
**kwargs: Any,
) -> FacetGrid[T_DataArray]: ...

Expand Down Expand Up @@ -1772,7 +1773,7 @@ def imshow(
yticks: ArrayLike | None = None,
xlim: ArrayLike | None = None,
ylim: ArrayLike | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
**kwargs: Any,
) -> FacetGrid[T_DataArray]: ...

Expand Down Expand Up @@ -1909,7 +1910,7 @@ def contour( # type: ignore[misc,unused-ignore] # None is hashable :(
yticks: ArrayLike | None = None,
xlim: ArrayLike | None = None,
ylim: ArrayLike | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
**kwargs: Any,
) -> QuadContourSet: ...

Expand Down Expand Up @@ -1949,7 +1950,7 @@ def contour(
yticks: ArrayLike | None = None,
xlim: ArrayLike | None = None,
ylim: ArrayLike | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
**kwargs: Any,
) -> FacetGrid[T_DataArray]: ...

Expand Down Expand Up @@ -1989,7 +1990,7 @@ def contour(
yticks: ArrayLike | None = None,
xlim: ArrayLike | None = None,
ylim: ArrayLike | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
**kwargs: Any,
) -> FacetGrid[T_DataArray]: ...

Expand Down Expand Up @@ -2042,7 +2043,7 @@ def contourf( # type: ignore[misc,unused-ignore] # None is hashable :(
yticks: ArrayLike | None = None,
xlim: ArrayLike | None = None,
ylim: ArrayLike | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
**kwargs: Any,
) -> QuadContourSet: ...

Expand Down Expand Up @@ -2082,7 +2083,7 @@ def contourf(
yticks: ArrayLike | None = None,
xlim: ArrayLike | None = None,
ylim: ArrayLike | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
**kwargs: Any,
) -> FacetGrid[T_DataArray]: ...

Expand Down Expand Up @@ -2122,7 +2123,7 @@ def contourf(
yticks: ArrayLike | None = None,
xlim: ArrayLike | None = None,
ylim: ArrayLike | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
**kwargs: Any,
) -> FacetGrid[T_DataArray]: ...

Expand Down Expand Up @@ -2175,7 +2176,7 @@ def pcolormesh( # type: ignore[misc,unused-ignore] # None is hashable :(
yticks: ArrayLike | None = None,
xlim: ArrayLike | None = None,
ylim: ArrayLike | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
**kwargs: Any,
) -> QuadMesh: ...

Expand Down Expand Up @@ -2215,7 +2216,7 @@ def pcolormesh(
yticks: ArrayLike | None = None,
xlim: ArrayLike | None = None,
ylim: ArrayLike | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
**kwargs: Any,
) -> FacetGrid[T_DataArray]: ...

Expand Down Expand Up @@ -2255,7 +2256,7 @@ def pcolormesh(
yticks: ArrayLike | None = None,
xlim: ArrayLike | None = None,
ylim: ArrayLike | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
**kwargs: Any,
) -> FacetGrid[T_DataArray]: ...

Expand Down Expand Up @@ -2359,7 +2360,7 @@ def surface(
yticks: ArrayLike | None = None,
xlim: ArrayLike | None = None,
ylim: ArrayLike | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
**kwargs: Any,
) -> Poly3DCollection: ...

Expand Down Expand Up @@ -2399,7 +2400,7 @@ def surface(
yticks: ArrayLike | None = None,
xlim: ArrayLike | None = None,
ylim: ArrayLike | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
**kwargs: Any,
) -> FacetGrid[T_DataArray]: ...

Expand Down Expand Up @@ -2439,7 +2440,7 @@ def surface(
yticks: ArrayLike | None = None,
xlim: ArrayLike | None = None,
ylim: ArrayLike | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
**kwargs: Any,
) -> FacetGrid[T_DataArray]: ...

Expand Down
23 changes: 12 additions & 11 deletions xarray/plot/dataset_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
AspectOptions,
ExtendOptions,
HueStyleOptions,
NormOptions,
ScaleOptions,
)
from xarray.plot.facetgrid import FacetGrid
Expand Down Expand Up @@ -183,7 +184,7 @@ def newplotfunc(
cmap: str | Colormap | None = None,
vmin: float | None = None,
vmax: float | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
infer_intervals: bool | None = None,
center: float | None = None,
robust: bool | None = None,
Expand Down Expand Up @@ -345,7 +346,7 @@ def quiver( # type: ignore[misc,unused-ignore] # None is hashable :(
cbar_ax: Axes | None = None,
vmin: float | None = None,
vmax: float | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
infer_intervals: bool | None = None,
center: float | None = None,
levels: ArrayLike | None = None,
Expand Down Expand Up @@ -382,7 +383,7 @@ def quiver(
cbar_ax: Axes | None = None,
vmin: float | None = None,
vmax: float | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
infer_intervals: bool | None = None,
center: float | None = None,
levels: ArrayLike | None = None,
Expand Down Expand Up @@ -419,7 +420,7 @@ def quiver(
cbar_ax: Axes | None = None,
vmin: float | None = None,
vmax: float | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
infer_intervals: bool | None = None,
center: float | None = None,
levels: ArrayLike | None = None,
Expand Down Expand Up @@ -496,7 +497,7 @@ def streamplot( # type: ignore[misc,unused-ignore] # None is hashable :(
cbar_ax: Axes | None = None,
vmin: float | None = None,
vmax: float | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
infer_intervals: bool | None = None,
center: float | None = None,
levels: ArrayLike | None = None,
Expand Down Expand Up @@ -533,7 +534,7 @@ def streamplot(
cbar_ax: Axes | None = None,
vmin: float | None = None,
vmax: float | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
infer_intervals: bool | None = None,
center: float | None = None,
levels: ArrayLike | None = None,
Expand Down Expand Up @@ -570,7 +571,7 @@ def streamplot(
cbar_ax: Axes | None = None,
vmin: float | None = None,
vmax: float | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
infer_intervals: bool | None = None,
center: float | None = None,
levels: ArrayLike | None = None,
Expand Down Expand Up @@ -783,7 +784,7 @@ def scatter( # type: ignore[misc,unused-ignore] # None is hashable :(
cmap: str | Colormap | None = None,
vmin: float | None = None,
vmax: float | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
extend: ExtendOptions = None,
levels: ArrayLike | None = None,
**kwargs: Any,
Expand Down Expand Up @@ -824,7 +825,7 @@ def scatter(
cmap: str | Colormap | None = None,
vmin: float | None = None,
vmax: float | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
extend: ExtendOptions = None,
levels: ArrayLike | None = None,
**kwargs: Any,
Expand Down Expand Up @@ -865,7 +866,7 @@ def scatter(
cmap: str | Colormap | None = None,
vmin: float | None = None,
vmax: float | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
extend: ExtendOptions = None,
levels: ArrayLike | None = None,
**kwargs: Any,
Expand Down Expand Up @@ -906,7 +907,7 @@ def scatter(
cmap: str | Colormap | None = None,
vmin: float | None = None,
vmax: float | None = None,
norm: Normalize | None = None,
norm: NormOptions | Normalize | None = None,
extend: ExtendOptions = None,
levels: ArrayLike | None = None,
**kwargs: Any,
Expand Down
26 changes: 24 additions & 2 deletions xarray/plot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@

from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.types import AspectOptions, ScaleOptions
from xarray.core.types import AspectOptions, NormOptions, ScaleOptions

try:
import matplotlib.pyplot as plt
Expand All @@ -58,6 +58,26 @@
_LINEWIDTH_RANGE = (1.5, 1.5, 6.0)


def _make_norm_from_string(
norm: NormOptions,
) -> type[Normalize]:
"""
Get norm from string.

Examples
--------
>>> _make_norm_from_string("log")
<class 'matplotlib.colors.LogScaleNorm'>

"""
from matplotlib.colors import Normalize, make_norm_from_scale
from matplotlib.scale import scale_factory

scale = type(scale_factory(norm, None)) # type: ignore [arg-type] # mpl issue, use of ax is discouraged

return make_norm_from_scale(scale, Normalize)


def _determine_extend(calc_data, vmin, vmax):
extend_min = calc_data.min() < vmin
extend_max = calc_data.max() > vmax
Expand Down Expand Up @@ -264,6 +284,8 @@ def _determine_cmap_params(

# now check norm and harmonize with vmin, vmax
if norm is not None:
norm = _make_norm_from_string(norm)() if isinstance(norm, str) else norm

if norm.vmin is None:
norm.vmin = vmin
else:
Expand All @@ -279,7 +301,7 @@ def _determine_cmap_params(
vmax = norm.vmax

# if BoundaryNorm, then set levels
if isinstance(norm, mpl.colors.BoundaryNorm):
if isinstance(norm, mpl.colors.BoundaryNorm):
levels = norm.boundaries

# Choose default colormaps if not provided
Expand Down
Loading