Skip to content

Commit e859ae3

Browse files
authored
add plot_converge (#67)
1 parent 6006d1d commit e859ae3

File tree

2 files changed

+55
-1
lines changed

2 files changed

+55
-1
lines changed

pymc_bart/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
from pymc_bart.bart import BART
1717
from pymc_bart.pgbart import PGBART
18-
from pymc_bart.utils import plot_dependence, plot_variable_importance
18+
from pymc_bart.utils import plot_convergence, plot_dependence, plot_variable_importance
1919

2020
__all__ = ["BART", "PGBART"]
2121
__version__ = "0.3.2"

pymc_bart/utils.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,60 @@ def _sample_posterior(all_trees, X, rng, size=None, excluded=None):
5252
return pred
5353

5454

55+
def plot_convergence(idata, var_name=None, kind="ecdf", figsize=None, ax=None):
56+
"""
57+
Plot convergence diagnostics.
58+
59+
Parameters
60+
----------
61+
idata : InferenceData
62+
InferenceData object containing the posterior samples.
63+
var_name : str
64+
Name of the BART variable to plot. Defaults to None.
65+
kind : str
66+
Type of plot to display. Options are "ecdf" (default) and "kde".
67+
figsize : tuple
68+
Figure size. Defaults to None.
69+
ax : matplotlib axes
70+
Axes on which to plot. Defaults to None.
71+
72+
Returns
73+
-------
74+
ax : matplotlib axes
75+
"""
76+
ess_threshold = idata.posterior.chain.size * 100
77+
ess = np.atleast_2d(az.ess(idata, method="bulk", var_names=var_name)[var_name].values)
78+
rhat = np.atleast_2d(az.rhat(idata, var_names=var_name)[var_name].values)
79+
80+
if figsize is None:
81+
figsize = (10, 3)
82+
83+
if kind == "ecdf":
84+
kind_func = az.plot_ecdf
85+
sharey = True
86+
elif kind == "kde":
87+
kind_func = az.plot_kde
88+
sharey = False
89+
90+
if ax is None:
91+
_, ax = plt.subplots(1, 2, figsize=figsize, sharex="col", sharey=sharey)
92+
93+
for idx, (essi, rhati) in enumerate(zip(ess, rhat)):
94+
kind_func(essi, ax=ax[0], plot_kwargs={"color": f"C{idx}"})
95+
ax[0].axvline(ess_threshold, color="k", ls="--")
96+
kind_func(rhati, ax=ax[1], plot_kwargs={"color": f"C{idx}"})
97+
ax[1].axvline(1.01, color="0.6", ls="--")
98+
ax[1].axvline(1.05, color="k", ls="--")
99+
100+
ax[0].set_xlabel("ESS")
101+
ax[1].set_xlabel("R-hat")
102+
if kind == "kde":
103+
ax[0].set_yticks([])
104+
ax[1].set_yticks([])
105+
106+
return ax
107+
108+
55109
def plot_dependence(
56110
bartrv,
57111
X,

0 commit comments

Comments
 (0)