Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/arviz_plots/backend/matplotlib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,19 +313,20 @@ def hist(
"""Interface to matplotlib for a histogram bar plot."""
artist_kws.setdefault("zorder", 2)
if np.any(bottom != 0):
height = y - bottom
height = y + bottom
else:
height = y
if color is not unset:
if facecolor is unset:
facecolor = color
if edgecolor is unset:
edgecolor = color

bottom = np.full_like(height, bottom)
kwargs = {"color": facecolor, "edgecolor": edgecolor, "alpha": alpha}
return target.fill_between(
np.r_[l_e, r_e[-1]],
np.r_[height, height[-1]],
np.r_[bottom, bottom[-1]],
step="post",
**_filter_kwargs(kwargs, None, artist_kws),
)
Expand Down
157 changes: 122 additions & 35 deletions src/arviz_plots/plots/ridge_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,14 @@

from arviz_plots.plot_collection import PlotCollection, process_facet_dims
from arviz_plots.plots.utils import filter_aes, get_visual_kwargs, process_group_variables_coords
from arviz_plots.visuals import annotate_label, fill_between_y, line_xy, remove_axis
from arviz_plots.visuals import (
annotate_label,
fill_between_y,
hist,
line_xy,
remove_axis,
scatter_xy,
)


def plot_ridge(
Expand All @@ -29,6 +36,7 @@ def plot_ridge(
plot_collection=None,
backend=None,
labeller=None,
kind=None,
aes_by_visuals: Mapping[
Literal[
"edge",
Expand Down Expand Up @@ -89,6 +97,9 @@ def plot_ridge(
except "chain" and "model" (if present). The order of `labels` is ignored,
only elements being present in it matters.
It can include the special "__variable__" indicator, and does so by default.
kind : {"kde", "ecdf", "hist", "dot"}, optional
How to represent the marginal density.
Defaults to ``rcParams["plot.density_kind"]``
shade_label : str, default None
Element of `labels` that should be used to add shading horizontal strips to the plot.
Note that labels and credible intervals are plotted in different :term:`plots`.
Expand Down Expand Up @@ -197,6 +208,10 @@ def plot_ridge(
]
if labels is None:
labels = labellable_dims
if kind is None:
kind = rcParams["plot.density_kind"]
if kind not in ("kde", "hist", "ecdf", "dot"):
raise ValueError("kind must be either 'kde', 'hist', 'ecdf' or 'dot'")
if not combined and "chain" not in distribution.dims:
combined = True

Expand Down Expand Up @@ -346,30 +361,73 @@ def plot_ridge(
with warnings.catch_warnings():
if "model" in distribution:
warnings.filterwarnings("ignore", message="Your data appears to have a single")
density = distribution.azstats.kde(dim=edge_dims, **stats.get("dist", {}))
# rescaling kde
density.loc[{"plot_axis": "y"}] = (
density.sel(plot_axis="y")
/ density.sel(plot_axis="y").max().to_array().max()
* ridge_height
)
if kind == "kde":
density = distribution.azstats.kde(dim=edge_dims, **stats.get("dist", {}))
elif kind == "hist":
density = distribution.azstats.histogram(dim=edge_dims, **stats.get("dist", {}))
elif kind == "ecdf":
density = distribution.azstats.ecdf(dim=edge_dims, **stats.get("dist", {}))
elif kind == "dot":
density = distribution.azstats.qds(dim=edge_dims, **stats.get("dist", {}))
else:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a check before this

raise ValueError(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This has already been checked

f"Unsupported kind '{kind}'. "
"Supported kinds are 'kde', 'hist', 'ecdf', 'qds'."
)
if kind == "hist":
density.loc[{"plot_axis": "histogram"}] = (
density.sel(plot_axis="histogram")
/ density.sel(plot_axis="histogram").max().to_array().max()
* ridge_height
)
else:
density.loc[{"plot_axis": "y"}] = (
density.sel(plot_axis="y")
/ density.sel(plot_axis="y").max().to_array().max()
* ridge_height
)

if face_kwargs is not False: # create face_density dataset only if required
_, face_aes, face_ignore = filter_aes(plot_collection, aes_by_visuals, "face", sample_dims)
face_density = density.rename({"plot_axis": "kwarg"})
face_density = face_density.assign_coords(
kwarg=[
"y_top" if coord == "y" else coord for coord in face_density.coords["kwarg"].values
]
)
# adding a new coord 'y_bottom' set to all zeros
zeros = xr.full_like(face_density.sel(kwarg="x"), 0)
zeros = zeros.assign_coords(kwarg=["y_bottom"])
face_density = xr.concat([face_density, zeros], dim="kwarg")
if kind == "hist":
face_density = density
elif kind == "qds":
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the value for kind is dot, qds is the function in arviz-stats (similarly, kind is hist but arviz-stats function is histogram)

qds_face_kwargs = stats.get("dist", {}).copy()
qds_face_kwargs.setdefault("top_only", True)
face_density = distribution.azstats.qds(dim=edge_dims, **qds_face_kwargs)
face_density.loc[{"plot_axis": "y"}] = (
face_density.sel(plot_axis="y")
/ face_density.sel(plot_axis="y").max().to_array().max()
* ridge_height
)
face_density = (
face_density.rename(plot_axis="kwarg")
.sel(kwarg=["x", "y"])
.pad(kwarg=(0, 1), constant_values=0)
.assign_coords(kwarg=["x", "y_top", "y_bottom"])
)
else:
face_density = density.rename({"plot_axis": "kwarg"})
face_density = face_density.assign_coords(
kwarg=[
"y_top" if coord == "y" else coord
for coord in face_density.coords["kwarg"].values
]
)
# adding a new coord 'y_bottom' set to all zeros
zeros = xr.full_like(face_density.sel(kwarg="x"), 0)
zeros = zeros.assign_coords(kwarg=["y_bottom"])
face_density = xr.concat([face_density, zeros], dim="kwarg")

# computing x_range
if edge_kwargs is not False or face_kwargs is not False:
x_range = density.sel(plot_axis="x")
if kind == "hist":
x_range = xr.concat(
[density.sel(plot_axis="left_edges"), density.sel(plot_axis="right_edges")],
dim="edge",
)
else:
x_range = density.sel(plot_axis="x")
else:
x_range = xr.ones_like(distribution)

Expand Down Expand Up @@ -469,28 +527,57 @@ def plot_ridge(
if edge_kwargs is not False:
if "color" not in edge_aes:
edge_kwargs.setdefault("color", "C0")
plot_collection.map(
line_xy,
"edge",
data=density,
ignore_aes=edge_ignore,
coords={"column": "ridge"},
**edge_kwargs,
)
if kind == "hist":
plot_collection.map(
hist,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
hist,
step_hist,

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we want to plot a step histogram? In that case, I will make relevant changes for step_hist function instead of hist

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, thanks.

"edge",
data=density,
ignore_aes=edge_ignore,
coords={"column": "ridge"},
**edge_kwargs,
)
elif kind == "qds":
plot_collection.map(
scatter_xy,
"edge",
data=density,
ignore_aes=edge_ignore,
coords={"column": "ridge"},
**edge_kwargs,
)
else:
plot_collection.map(
line_xy,
"edge",
data=density,
ignore_aes=edge_ignore,
coords={"column": "ridge"},
**edge_kwargs,
)

if face_kwargs is not False:
if "color" not in face_aes:
face_kwargs.setdefault("color", "C0")
if "alpha" not in face_aes:
face_kwargs.setdefault("alpha", 0.4)
plot_collection.map(
fill_between_y,
"face",
data=face_density,
ignore_aes=face_ignore,
coords={"column": "ridge"},
**face_kwargs,
)
if kind in ["hist", "qds"]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if kind in ["hist", "qds"]:
if kind == "hist":

I was talking of the other branch, not this one.

plot_collection.map(
hist,
"face",
data=face_density,
ignore_aes=face_ignore,
coords={"column": "ridge"},
**face_kwargs,
)
else:
plot_collection.map(
fill_between_y,
"face",
data=face_density,
ignore_aes=face_ignore,
coords={"column": "ridge"},
**face_kwargs,
)

if shade_label is not None:
plot_bknd.xlim(xlim_labels, plot_collection.get_target(None, {"column": "labels"}))
Expand Down
7 changes: 5 additions & 2 deletions src/arviz_plots/visuals/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,14 @@ def hist(da, target, **kwargs):
The input argument `da` is split into l_e, r_e and y using the dimension ``plot_axis``.
"""
plot_backend = backend_from_object(target)
bottom = kwargs.pop("y", 0)
y = da.sel(plot_axis="histogram").values
return plot_backend.hist(
da.sel(plot_axis="histogram"),
y,
da.sel(plot_axis="left_edges"),
da.sel(plot_axis="right_edges"),
target,
bottom=bottom,
**kwargs,
)

Expand All @@ -34,7 +37,7 @@ def step_hist(da, target, **kwargs):
r_e = da.sel(plot_axis="right_edges").values
y = da.sel(plot_axis="histogram").values

bottom = kwargs.pop("bottom", 0)
bottom = kwargs.pop("y", 0)
if np.any(bottom != 0):
height = y - bottom
else:
Expand Down