Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions docstub.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
[tool.docstub.type_prefixes]
numbers = "numbers"
xarray = "xarray"
pandas = "pandas"
numpyro = "numpyro"
emcee = "emcee"
cmdstanpy = "cmdstanpy"
stan = "stan"
types = "types"

# Specify human-friendly aliases that can be used instead of Python-parsable
# annotations.
[tool.docstub.type_nicknames]
sequence = "Sequence"
iterable = "Iterable"
mapping = "Mapping"
hashable = "Hashable"
hashable_key = "Any"
iterator = "Iterator"
module = "types.ModuleType"
callable = "Callable"
labeller = "arviz_base.labels.Labeller"
Dataset = "xarray.Dataset"
DataTree = "xarray.DataTree"
# TODO: define a meta/pseudo type for DataTree-like
DataTree-like = "xarray.DataTree"
DataArray = "xarray.DataArray"
scalar = "numbers.Number"
any = "Any"
4 changes: 2 additions & 2 deletions src/arviz_plots/plots/forest_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def plot_forest(
var_names : str or list of str, optional
One or more variables to be plotted.
Prefix the variables by ~ when you want to exclude them from the plot.
filter_vars : {None, like”, “regex}, default None
filter_vars : {None, "like", "regex"}, default None
If None, interpret var_names as the real variables names.
If “like”, interpret var_names as substrings of the real variables names.
If “regex”, interpret var_names as regular expressions on the real variables names.
Expand All @@ -89,7 +89,7 @@ def plot_forest(
Which point estimate to plot. Defaults to rcParam :data:`stats.point_estimate`
ci_kind : {"eti", "hdi"}, optional
Which credible interval to use. Defaults to ``rcParams["stats.ci_kind"]``
ci_probs : (float, float), optional
ci_probs : array-like of shape (2,), optional
Indicates the probabilities that should be contained within the plotted credible intervals.
It should be sorted as the elements refer to the probabilities of the "trunk" and "twig"
elements. Defaults to ``(0.5, rcParams["stats.ci_prob"])``
Expand Down
63 changes: 63 additions & 0 deletions src/arviz_plots/plots/forest_plot.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# File generated with docstub

from collections.abc import Hashable, Mapping, Sequence
from importlib import import_module
from typing import Any, Literal

import arviz_stats
import numpy as np
import xarray
import xarray as xr
from _typeshed import Incomplete
from _typeshed import Incomplete as labeller
from arviz_base import rcParams
from arviz_base.labels import BaseLabeller
from numpy.typing import ArrayLike

from arviz_plots.plot_collection import PlotCollection, process_facet_dims
from arviz_plots.plots.utils import filter_aes, get_visual_kwargs, process_group_variables_coords
from arviz_plots.visuals import annotate_label, fill_between_y, line_x, remove_axis, scatter_x

def plot_forest(
dt: xarray.DataTree | dict[str, xarray.DataTree],
*,
var_names: str | list[str] | None = ...,
filter_vars: Literal[None, "like", "regex"] | None = ...,
group: str = ...,
coords: dict | None = ...,
sample_dims: str | Sequence[Hashable] | None = ...,
combined: bool = ...,
point_estimate: Literal["mean", "median", "mode"] | None = ...,
ci_kind: Literal["eti", "hdi"] | None = ...,
ci_probs: ArrayLike | None = ...,
labels: Sequence[str] | None = ...,
shade_label: str | None = ...,
plot_collection: PlotCollection | None = ...,
backend: Literal["matplotlib", "bokeh"] | None = ...,
labeller: labeller | None = ...,
aes_by_visuals: Mapping[
Literal[
"credible_interval",
"point_estimate",
"labels",
"shade",
],
Sequence[str],
] = ...,
visuals: Mapping[
Literal[
"trunk",
"twig",
"point_estimate",
"labels",
"shade",
"ticklabels",
"remove_axis",
],
Mapping[str, Any] | bool,
] = ...,
stats: Mapping[
Literal["trunk", "twig", "point_estimate"], Mapping[str, Any] | xr.Dataset
] = ...,
**pc_kwargs: Incomplete,
) -> PlotCollection: ...