@@ -303,14 +303,14 @@ def ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, List[plt.Axes]]
303
303
304
304
return (fig , ax )
305
305
306
- def get_plot_data_bayesian (self ) -> pd .DataFrame :
306
+ def get_plot_data_bayesian (self , hdi_prob = 0.94 ) -> pd .DataFrame :
307
307
"""
308
308
Recover the data of a PrePostFit experiment along with the prediction and causal impact information.
309
309
"""
310
310
if isinstance (self .model , PyMCModel ):
311
311
pre_data = self .datapre .copy ()
312
312
post_data = self .datapost .copy ()
313
- # PREDICTIONS
313
+
314
314
pre_data ["prediction" ] = (
315
315
az .extract (self .pre_pred , group = "posterior_predictive" , var_names = "mu" )
316
316
.mean ("sample" )
@@ -321,15 +321,13 @@ def get_plot_data_bayesian(self) -> pd.DataFrame:
321
321
.mean ("sample" )
322
322
.values
323
323
)
324
- # HDI
325
- pre_data [["pred_hdi_lower" , "pred_hdi_upper" ]] = get_hdi_to_df (self .pre_pred ["posterior_predictive" ].mu )
326
- post_data [["pred_hdi_lower" , "pred_hdi_upper" ]] = get_hdi_to_df (self .post_pred ["posterior_predictive" ].mu )
327
- # IMPACT
324
+ pre_data [["pred_hdi_lower" , "pred_hdi_upper" ]] = get_hdi_to_df (self .pre_pred ["posterior_predictive" ].mu , hdi_prob = hdi_prob )
325
+ post_data [["pred_hdi_lower" , "pred_hdi_upper" ]] = get_hdi_to_df (self .post_pred ["posterior_predictive" ].mu , hdi_prob = hdi_prob )
326
+
328
327
pre_data ["impact" ] = self .pre_impact .mean (dim = ["chain" , "draw" ]).values
329
328
post_data ["impact" ] = self .post_impact .mean (dim = ["chain" , "draw" ]).values
330
- # HDI IMPACT
331
- pre_data [["impact_hdi_lower" , "impact_hdi_upper" ]] = get_hdi_to_df (self .pre_impact )
332
- post_data [["impact_hdi_lower" , "impact_hdi_upper" ]] = get_hdi_to_df (self .post_impact )
329
+ pre_data [["impact_hdi_lower" , "impact_hdi_upper" ]] = get_hdi_to_df (self .pre_impact , hdi_prob = hdi_prob )
330
+ post_data [["impact_hdi_lower" , "impact_hdi_upper" ]] = get_hdi_to_df (self .post_impact , hdi_prob = hdi_prob )
333
331
334
332
self .data_plot = pd .concat ([pre_data , post_data ])
335
333
0 commit comments