-
-
Notifications
You must be signed in to change notification settings - Fork 22
Adding new plotting functionality to plot_ridge #420
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
907bda6
db4c473
deadba7
87372d9
91cb02b
dfb3a00
3889280
3091c2e
11a0ec3
850154a
83dc50c
3b05009
084cc21
e825e16
e909022
66c108c
54413f2
1248c89
b3c6727
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -12,7 +12,14 @@ | |||||
|
|
||||||
| from arviz_plots.plot_collection import PlotCollection, process_facet_dims | ||||||
| from arviz_plots.plots.utils import filter_aes, get_visual_kwargs, process_group_variables_coords | ||||||
| from arviz_plots.visuals import annotate_label, fill_between_y, line_xy, remove_axis | ||||||
| from arviz_plots.visuals import ( | ||||||
| annotate_label, | ||||||
| fill_between_y, | ||||||
| hist, | ||||||
| line_xy, | ||||||
| remove_axis, | ||||||
| scatter_xy, | ||||||
| ) | ||||||
|
|
||||||
|
|
||||||
| def plot_ridge( | ||||||
|
|
@@ -29,6 +36,7 @@ def plot_ridge( | |||||
| plot_collection=None, | ||||||
| backend=None, | ||||||
| labeller=None, | ||||||
| kind=None, | ||||||
| aes_by_visuals: Mapping[ | ||||||
| Literal[ | ||||||
| "edge", | ||||||
|
|
@@ -89,6 +97,9 @@ def plot_ridge( | |||||
| except "chain" and "model" (if present). The order of `labels` is ignored, | ||||||
| only elements being present in it matters. | ||||||
| It can include the special "__variable__" indicator, and does so by default. | ||||||
| kind : {"kde", "ecdf", "hist", "dot"}, optional | ||||||
| How to represent the marginal density. | ||||||
| Defaults to ``rcParams["plot.density_kind"]`` | ||||||
| shade_label : str, default None | ||||||
| Element of `labels` that should be used to add shading horizontal strips to the plot. | ||||||
| Note that labels and credible intervals are plotted in different :term:`plots`. | ||||||
|
|
@@ -197,6 +208,10 @@ def plot_ridge( | |||||
| ] | ||||||
| if labels is None: | ||||||
| labels = labellable_dims | ||||||
| if kind is None: | ||||||
| kind = rcParams["plot.density_kind"] | ||||||
| if kind not in ("kde", "hist", "ecdf", "dot"): | ||||||
| raise ValueError("kind must be either 'kde', 'hist', 'ecdf' or 'dot'") | ||||||
| if not combined and "chain" not in distribution.dims: | ||||||
| combined = True | ||||||
|
|
||||||
|
|
@@ -346,30 +361,73 @@ def plot_ridge( | |||||
| with warnings.catch_warnings(): | ||||||
| if "model" in distribution: | ||||||
| warnings.filterwarnings("ignore", message="Your data appears to have a single") | ||||||
| density = distribution.azstats.kde(dim=edge_dims, **stats.get("dist", {})) | ||||||
| # rescaling kde | ||||||
| density.loc[{"plot_axis": "y"}] = ( | ||||||
| density.sel(plot_axis="y") | ||||||
| / density.sel(plot_axis="y").max().to_array().max() | ||||||
| * ridge_height | ||||||
| ) | ||||||
| if kind == "kde": | ||||||
| density = distribution.azstats.kde(dim=edge_dims, **stats.get("dist", {})) | ||||||
| elif kind == "hist": | ||||||
| density = distribution.azstats.histogram(dim=edge_dims, **stats.get("dist", {})) | ||||||
| elif kind == "ecdf": | ||||||
| density = distribution.azstats.ecdf(dim=edge_dims, **stats.get("dist", {})) | ||||||
| elif kind == "dot": | ||||||
| density = distribution.azstats.qds(dim=edge_dims, **stats.get("dist", {})) | ||||||
| else: | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I added a check before this |
||||||
| raise ValueError( | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This has already been checked |
||||||
| f"Unsupported kind '{kind}'. " | ||||||
| "Supported kinds are 'kde', 'hist', 'ecdf', 'qds'." | ||||||
| ) | ||||||
| if kind == "hist": | ||||||
| density.loc[{"plot_axis": "histogram"}] = ( | ||||||
| density.sel(plot_axis="histogram") | ||||||
| / density.sel(plot_axis="histogram").max().to_array().max() | ||||||
| * ridge_height | ||||||
| ) | ||||||
| else: | ||||||
| density.loc[{"plot_axis": "y"}] = ( | ||||||
| density.sel(plot_axis="y") | ||||||
| / density.sel(plot_axis="y").max().to_array().max() | ||||||
| * ridge_height | ||||||
| ) | ||||||
aloctavodia marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
|
|
||||||
| if face_kwargs is not False: # create face_density dataset only if required | ||||||
| _, face_aes, face_ignore = filter_aes(plot_collection, aes_by_visuals, "face", sample_dims) | ||||||
| face_density = density.rename({"plot_axis": "kwarg"}) | ||||||
| face_density = face_density.assign_coords( | ||||||
| kwarg=[ | ||||||
| "y_top" if coord == "y" else coord for coord in face_density.coords["kwarg"].values | ||||||
| ] | ||||||
| ) | ||||||
| # adding a new coord 'y_bottom' set to all zeros | ||||||
| zeros = xr.full_like(face_density.sel(kwarg="x"), 0) | ||||||
| zeros = zeros.assign_coords(kwarg=["y_bottom"]) | ||||||
| face_density = xr.concat([face_density, zeros], dim="kwarg") | ||||||
| if kind == "hist": | ||||||
| face_density = density | ||||||
| elif kind == "qds": | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the value for kind is |
||||||
| qds_face_kwargs = stats.get("dist", {}).copy() | ||||||
| qds_face_kwargs.setdefault("top_only", True) | ||||||
| face_density = distribution.azstats.qds(dim=edge_dims, **qds_face_kwargs) | ||||||
| face_density.loc[{"plot_axis": "y"}] = ( | ||||||
| face_density.sel(plot_axis="y") | ||||||
| / face_density.sel(plot_axis="y").max().to_array().max() | ||||||
| * ridge_height | ||||||
| ) | ||||||
| face_density = ( | ||||||
| face_density.rename(plot_axis="kwarg") | ||||||
| .sel(kwarg=["x", "y"]) | ||||||
| .pad(kwarg=(0, 1), constant_values=0) | ||||||
| .assign_coords(kwarg=["x", "y_top", "y_bottom"]) | ||||||
| ) | ||||||
| else: | ||||||
| face_density = density.rename({"plot_axis": "kwarg"}) | ||||||
| face_density = face_density.assign_coords( | ||||||
| kwarg=[ | ||||||
| "y_top" if coord == "y" else coord | ||||||
| for coord in face_density.coords["kwarg"].values | ||||||
| ] | ||||||
| ) | ||||||
| # adding a new coord 'y_bottom' set to all zeros | ||||||
| zeros = xr.full_like(face_density.sel(kwarg="x"), 0) | ||||||
| zeros = zeros.assign_coords(kwarg=["y_bottom"]) | ||||||
| face_density = xr.concat([face_density, zeros], dim="kwarg") | ||||||
|
|
||||||
| # computing x_range | ||||||
| if edge_kwargs is not False or face_kwargs is not False: | ||||||
| x_range = density.sel(plot_axis="x") | ||||||
| if kind == "hist": | ||||||
| x_range = xr.concat( | ||||||
| [density.sel(plot_axis="left_edges"), density.sel(plot_axis="right_edges")], | ||||||
| dim="edge", | ||||||
| ) | ||||||
| else: | ||||||
| x_range = density.sel(plot_axis="x") | ||||||
| else: | ||||||
| x_range = xr.ones_like(distribution) | ||||||
|
|
||||||
|
|
@@ -469,28 +527,57 @@ def plot_ridge( | |||||
| if edge_kwargs is not False: | ||||||
| if "color" not in edge_aes: | ||||||
| edge_kwargs.setdefault("color", "C0") | ||||||
| plot_collection.map( | ||||||
| line_xy, | ||||||
| "edge", | ||||||
| data=density, | ||||||
| ignore_aes=edge_ignore, | ||||||
| coords={"column": "ridge"}, | ||||||
| **edge_kwargs, | ||||||
| ) | ||||||
| if kind == "hist": | ||||||
| plot_collection.map( | ||||||
| hist, | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do we want to plot a step histogram? In that case, I will make relevant changes for
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, thanks. |
||||||
| "edge", | ||||||
| data=density, | ||||||
| ignore_aes=edge_ignore, | ||||||
| coords={"column": "ridge"}, | ||||||
aloctavodia marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
| **edge_kwargs, | ||||||
| ) | ||||||
| elif kind == "qds": | ||||||
| plot_collection.map( | ||||||
| scatter_xy, | ||||||
| "edge", | ||||||
| data=density, | ||||||
| ignore_aes=edge_ignore, | ||||||
| coords={"column": "ridge"}, | ||||||
| **edge_kwargs, | ||||||
| ) | ||||||
| else: | ||||||
| plot_collection.map( | ||||||
| line_xy, | ||||||
| "edge", | ||||||
| data=density, | ||||||
| ignore_aes=edge_ignore, | ||||||
| coords={"column": "ridge"}, | ||||||
| **edge_kwargs, | ||||||
| ) | ||||||
|
|
||||||
| if face_kwargs is not False: | ||||||
| if "color" not in face_aes: | ||||||
| face_kwargs.setdefault("color", "C0") | ||||||
| if "alpha" not in face_aes: | ||||||
| face_kwargs.setdefault("alpha", 0.4) | ||||||
| plot_collection.map( | ||||||
| fill_between_y, | ||||||
| "face", | ||||||
| data=face_density, | ||||||
| ignore_aes=face_ignore, | ||||||
| coords={"column": "ridge"}, | ||||||
| **face_kwargs, | ||||||
| ) | ||||||
| if kind in ["hist", "qds"]: | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
I was talking of the other branch, not this one. |
||||||
| plot_collection.map( | ||||||
| hist, | ||||||
| "face", | ||||||
| data=face_density, | ||||||
| ignore_aes=face_ignore, | ||||||
| coords={"column": "ridge"}, | ||||||
| **face_kwargs, | ||||||
| ) | ||||||
| else: | ||||||
| plot_collection.map( | ||||||
| fill_between_y, | ||||||
| "face", | ||||||
| data=face_density, | ||||||
| ignore_aes=face_ignore, | ||||||
| coords={"column": "ridge"}, | ||||||
| **face_kwargs, | ||||||
| ) | ||||||
|
|
||||||
| if shade_label is not None: | ||||||
| plot_bknd.xlim(xlim_labels, plot_collection.get_target(None, {"column": "labels"})) | ||||||
|
|
||||||
Uh oh!
There was an error while loading. Please reload this page.