Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ classifiers = [
dynamic = ["version", "description"]
dependencies = [
"arviz-base @ git+https://github.com/arviz-devs/arviz-base",
"arviz-stats[xarray] @ git+https://github.com/arviz-devs/arviz-stats",
"arviz-stats[xarray] @ git+https://github.com/arviz-devs/arviz-stats@top_level_viz",
]

[tool.flit.module]
Expand Down
80 changes: 75 additions & 5 deletions src/arviz_plots/plots/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@

import numpy as np
import xarray as xr
from arviz_base import references_to_dataset
from arviz_base import rcParams, references_to_dataset
from arviz_base.utils import _var_names
from arviz_stats import ecdf, histogram, kde

from arviz_plots.plot_collection import concat_model_dict, process_facet_dims
from arviz_plots.visuals import hline, hspan, vline, vspan
Expand Down Expand Up @@ -74,25 +75,33 @@ def process_group_variables_coords(dt, group, var_names, filter_vars, coords, al
distribution = distribution.sel(coords)
return distribution


def filter_aes(pc, aes_by_visuals, visual, sample_dims):
    """Thin wrapper around `filter_aes_new` that drops its second return value.

    All arguments are forwarded unchanged to `filter_aes_new`.

    Returns
    -------
    reduce_dims, artist_aes, ignore_aes
        The first, third and fourth elements returned by `filter_aes_new`.
    """
    dims_to_reduce, _unused_active, visual_aes, skipped_aes = filter_aes_new(
        pc, aes_by_visuals, visual, sample_dims
    )
    return dims_to_reduce, visual_aes, skipped_aes

def filter_aes_new(pc, aes_by_visuals, visual, sample_dims):
    """Split aesthetics and get relevant dimensions.

    Parameters
    ----------
    pc : object
        Must expose an ``aes_set`` attribute and an ``update_aes(ignore_aes=...)``
        method whose second return value is the collection of dimensions being
        looped over.
    aes_by_visuals : mapping
        Maps visual names to the aesthetics that apply to each visual.
    visual : hashable
        Key looked up in `aes_by_visuals`; a missing key means no aesthetics.
    sample_dims : iterable of hashable
        Candidate dimensions to reduce.

    Returns
    -------
    reduce_dims : list
        Dimensions that should be reduced for this visual.
        That is, all dimensions in `sample_dims` that are not
        mapped to any aesthetic.
    active_dims : list
        Dimensions that have either faceting or aesthetic mappings
        active for that visual. Should not be reduced and should have
        a groupby performed on them if computing summaries.
    artist_aes : iterable
        Aesthetics registered for `visual` in `aes_by_visuals`
        (an empty dict if there are none).
    ignore_aes : set
        Aesthetics in ``pc.aes_set`` that are not in `artist_aes`.
    """
    artist_aes = aes_by_visuals.get(visual, {})
    pc_aes = pc.aes_set
    ignore_aes = set(pc_aes).difference(artist_aes)
    _, all_loop_dims = pc.update_aes(ignore_aes=ignore_aes)
    # Partition dimensions: sample dims that are not looped over get reduced,
    # looped dims outside of the sample dims stay active (groupby candidates).
    reduce_dims = [dim for dim in sample_dims if dim not in all_loop_dims]
    active_dims = [dim for dim in all_loop_dims if dim not in sample_dims]
    return reduce_dims, active_dims, artist_aes, ignore_aes


def set_wrap_layout(pc_kwargs, plot_bknd, ds):
Expand Down Expand Up @@ -167,6 +176,67 @@ def set_grid_layout(pc_kwargs, plot_bknd, ds, num_rows=None, num_cols=None):
pc_kwargs["figure_kwargs"]["figsize_units"] = figsize_units
return pc_kwargs

def _compute_dist_subset(func, data, var_names, reduce_dims, active_dims, stats, key):
    """Apply `func` to the `var_names` subset of `data`, grouping over active dims."""
    if not var_names:
        return xr.Dataset()
    subset = data[var_names]
    groupby_dims = [dim for dim in active_dims if dim in subset.dims]
    if groupby_dims:
        subset = subset.groupby(groupby_dims)
    return func(subset, dim=reduce_dims, **stats.get(key, {}))


def compute_dist(data, reduce_dims, active_dims, kind=None, stats=None):
    """Compute ecdf, histogram and kde summaries for the variables in `data`.

    Parameters
    ----------
    data : xarray.Dataset
        Variables to summarize.
    reduce_dims : list
        Dimensions passed as ``dim`` to the summary functions.
    active_dims : list
        Dimensions to group by (when present in the selected variables)
        before computing the summaries. Must not overlap with `reduce_dims`.
    kind : str, optional
        One of "auto", "ecdf", "hist" or "kde". Defaults to
        ``rcParams["plot.density_kind"]``. With "auto" the choice is made per
        variable: ecdf if the number of values being reduced is below 100,
        otherwise kde for floating point variables and histogram for the rest.
        Any other value results in three empty datasets.
    stats : dict, optional
        May contain the keys "ecdf", "hist" and "kde". If any of their values
        is an :class:`xarray.Dataset` it is taken as pre-computed and the values
        in `stats` are returned as is. Otherwise the values are used as
        keyword arguments for the respective summary function.

    Returns
    -------
    ecdf_out, hist_out, kde_out : tuple of xarray.Dataset
        Summaries for the variables assigned to each kind; an empty dataset
        for kinds with no variables assigned.

    Raises
    ------
    ValueError
        If `reduce_dims` and `active_dims` share elements.
    """
    if stats is None:
        stats = {}
    # quick exit if pre-computed elements in `stats`
    if any(isinstance(stats.get(viz, None), xr.Dataset) for viz in ("ecdf", "hist", "kde")):
        # return a tuple (not a generator) to match the regular return type
        return tuple(stats.get(viz, xr.Dataset()) for viz in ("ecdf", "hist", "kde"))
    if kind is None:
        kind = rcParams["plot.density_kind"]
    if set(reduce_dims).intersection(active_dims):
        raise ValueError("'reduce_dims' and 'active_dims' can't share elements")
    ecdf_vars = []
    hist_vars = []
    kde_vars = []
    if kind == "auto":
        for var_name, da in data.items():
            reduced_size = np.prod([da.sizes[dim] for dim in reduce_dims if dim in da.dims])
            groupby_dims = [dim for dim in active_dims if dim in da.dims]
            if groupby_dims:
                # smallest group size along each groupby dim, taken from how
                # many times each coordinate value is repeated
                reduced_size *= np.prod(
                    [
                        np.min(np.unique(da.coords[dim], return_counts=True)[1])
                        for dim in groupby_dims
                    ]
                )
            if reduced_size < 100:
                ecdf_vars.append(var_name)
            elif da.dtype.kind == "f":
                kde_vars.append(var_name)
            else:
                hist_vars.append(var_name)
    elif kind == "ecdf":
        # assignment, not comparison: all variables go through ecdf
        ecdf_vars = list(data.data_vars)
    elif kind == "hist":
        hist_vars = list(data.data_vars)
    elif kind == "kde":
        kde_vars = list(data.data_vars)

    ecdf_out = _compute_dist_subset(
        ecdf, data, ecdf_vars, reduce_dims, active_dims, stats, "ecdf"
    )
    hist_out = _compute_dist_subset(
        histogram, data, hist_vars, reduce_dims, active_dims, stats, "hist"
    )
    kde_out = _compute_dist_subset(kde, data, kde_vars, reduce_dims, active_dims, stats, "kde")

    return ecdf_out, hist_out, kde_out


def add_lines(
plot_collection,
Expand Down
Loading