diff --git a/docstub.toml b/docstub.toml new file mode 100644 index 00000000..6c9e7cc3 --- /dev/null +++ b/docstub.toml @@ -0,0 +1,29 @@ +[tool.docstub.type_prefixes] +numbers = "numbers" +xarray = "xarray" +pandas = "pandas" +numpyro = "numpyro" +emcee = "emcee" +cmdstanpy = "cmdstanpy" +stan = "stan" +types = "types" + +# Specify human-friendly aliases that can be used instead of Python-parsable +# annotations. +[tool.docstub.type_nicknames] +sequence = "Sequence" +iterable = "Iterable" +mapping = "Mapping" +hashable = "Hashable" +hashable_key = "Any" +iterator = "Iterator" +module = "types.ModuleType" +callable = "Callable" +labeller = "arviz_base.labels.Labeller" +Dataset = "xarray.Dataset" +DataTree = "xarray.DataTree" +# TODO: define a meta/pseudo type for DataTree-like +DataTree-like = "xarray.DataTree" +DataArray = "xarray.DataArray" +scalar = "numbers.Number" +any = "Any" diff --git a/src/arviz_plots/plots/forest_plot.py b/src/arviz_plots/plots/forest_plot.py index d6ec2e5d..73b7c37b 100644 --- a/src/arviz_plots/plots/forest_plot.py +++ b/src/arviz_plots/plots/forest_plot.py @@ -72,7 +72,7 @@ def plot_forest( var_names : str or list of str, optional One or more variables to be plotted. Prefix the variables by ~ when you want to exclude them from the plot. - filter_vars : {None, “like”, “regex”}, default None + filter_vars : {None, "like", "regex"}, default None If None, interpret var_names as the real variables names. If “like”, interpret var_names as substrings of the real variables names. If “regex”, interpret var_names as regular expressions on the real variables names. @@ -89,7 +89,7 @@ def plot_forest( Which point estimate to plot. Defaults to rcParam :data:`stats.point_estimate` ci_kind : {"eti", "hdi"}, optional Which credible interval to use. Defaults to ``rcParams["stats.ci_kind"]`` - ci_probs : (float, float), optional + ci_probs : array-like of shape (2,), optional Indicates the probabilities that should be contained within the plotted credible intervals. It should be sorted as the elements refer to the probabilities of the "trunk" and "twig" elements. Defaults to ``(0.5, rcParams["stats.ci_prob"])`` diff --git a/src/arviz_plots/plots/forest_plot.pyi b/src/arviz_plots/plots/forest_plot.pyi new file mode 100644 index 00000000..38bee97a --- /dev/null +++ b/src/arviz_plots/plots/forest_plot.pyi @@ -0,0 +1,63 @@ +# File generated with docstub + +from collections.abc import Hashable, Mapping, Sequence +from importlib import import_module +from typing import Any, Literal + +import arviz_stats +import numpy as np +import xarray +import xarray as xr +from _typeshed import Incomplete +from _typeshed import Incomplete as labeller +from arviz_base import rcParams +from arviz_base.labels import BaseLabeller +from numpy.typing import ArrayLike + +from arviz_plots.plot_collection import PlotCollection, process_facet_dims +from arviz_plots.plots.utils import filter_aes, get_visual_kwargs, process_group_variables_coords +from arviz_plots.visuals import annotate_label, fill_between_y, line_x, remove_axis, scatter_x + +def plot_forest( + dt: xarray.DataTree | dict[str, xarray.DataTree], + *, + var_names: str | list[str] | None = ..., + filter_vars: Literal[None, "like", "regex"] | None = ..., + group: str = ..., + coords: dict | None = ..., + sample_dims: str | Sequence[Hashable] | None = ..., + combined: bool = ..., + point_estimate: Literal["mean", "median", "mode"] | None = ..., + ci_kind: Literal["eti", "hdi"] | None = ..., + ci_probs: ArrayLike | None = ..., + labels: Sequence[str] | None = ..., + shade_label: str | None = ..., + plot_collection: PlotCollection | None = ..., + backend: Literal["matplotlib", "bokeh"] | None = ..., + labeller: labeller | None = ..., + aes_by_visuals: Mapping[ + Literal[ + "credible_interval", + "point_estimate", + "labels", + "shade", + ], + Sequence[str], + ] = ..., + visuals: Mapping[ + Literal[ + "trunk", + "twig", + "point_estimate", + "labels", + "shade", + "ticklabels", + "remove_axis", + ], + Mapping[str, Any] | bool, + ] = ..., + stats: Mapping[ + Literal["trunk", "twig", "point_estimate"], Mapping[str, Any] | xr.Dataset + ] = ..., + **pc_kwargs: Incomplete, +) -> PlotCollection: ...