@@ -52,6 +52,60 @@ def _sample_posterior(all_trees, X, rng, size=None, excluded=None):
52
52
return pred
53
53
54
54
55
+ def plot_convergence (idata , var_name = None , kind = "ecdf" , figsize = None , ax = None ):
56
+ """
57
+ Plot convergence diagnostics.
58
+
59
+ Parameters
60
+ ----------
61
+ idata : InferenceData
62
+ InferenceData object containing the posterior samples.
63
+ var_name : str
64
+ Name of the BART variable to plot. Defaults to None.
65
+ kind : str
66
+ Type of plot to display. Options are "ecdf" (default) and "kde".
67
+ figsize : tuple
68
+ Figure size. Defaults to None.
69
+ ax : matplotlib axes
70
+ Axes on which to plot. Defaults to None.
71
+
72
+ Returns
73
+ -------
74
+ ax : matplotlib axes
75
+ """
76
+ ess_threshold = idata .posterior .chain .size * 100
77
+ ess = np .atleast_2d (az .ess (idata , method = "bulk" , var_names = var_name )[var_name ].values )
78
+ rhat = np .atleast_2d (az .rhat (idata , var_names = var_name )[var_name ].values )
79
+
80
+ if figsize is None :
81
+ figsize = (10 , 3 )
82
+
83
+ if kind == "ecdf" :
84
+ kind_func = az .plot_ecdf
85
+ sharey = True
86
+ elif kind == "kde" :
87
+ kind_func = az .plot_kde
88
+ sharey = False
89
+
90
+ if ax is None :
91
+ _ , ax = plt .subplots (1 , 2 , figsize = figsize , sharex = "col" , sharey = sharey )
92
+
93
+ for idx , (essi , rhati ) in enumerate (zip (ess , rhat )):
94
+ kind_func (essi , ax = ax [0 ], plot_kwargs = {"color" : f"C{ idx } " })
95
+ ax [0 ].axvline (ess_threshold , color = "k" , ls = "--" )
96
+ kind_func (rhati , ax = ax [1 ], plot_kwargs = {"color" : f"C{ idx } " })
97
+ ax [1 ].axvline (1.01 , color = "0.6" , ls = "--" )
98
+ ax [1 ].axvline (1.05 , color = "k" , ls = "--" )
99
+
100
+ ax [0 ].set_xlabel ("ESS" )
101
+ ax [1 ].set_xlabel ("R-hat" )
102
+ if kind == "kde" :
103
+ ax [0 ].set_yticks ([])
104
+ ax [1 ].set_yticks ([])
105
+
106
+ return ax
107
+
108
+
55
109
def plot_dependence (
56
110
bartrv ,
57
111
X ,
0 commit comments