25
25
from sklearn .base import RegressorMixin
26
26
27
27
from causalpy .custom_exceptions import BadIndexException
28
- from causalpy .plot_utils import plot_xY
28
+ from causalpy .plot_utils import plot_xY , get_hdi_to_df
29
29
from causalpy .pymc_models import PyMCModel
30
30
from causalpy .utils import round_num
31
31
@@ -303,19 +303,6 @@ def ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, List[plt.Axes]]
303
303
304
304
return (fig , ax )
305
305
306
- # def get_plot_data(self) -> pd.DataFrame:
307
- # """Recover the data of a PrePostFit experiment along with the prediction and causal impact information.
308
-
309
- # Internally, this function dispatches to either `get_plot_data_bayesian` or `get_plot_data_ols`
310
- # depending on the model type.
311
- # """
312
- # if isinstance(self.model, PyMCModel):
313
- # return self.get_plot_data_bayesian()
314
- # elif isinstance(self.model, RegressorMixin):
315
- # return self.get_plot_data_ols()
316
- # else:
317
- # raise ValueError("Unsupported model type")
318
-
319
306
def get_plot_data_bayesian (self ) -> pd .DataFrame :
320
307
"""
321
308
Recover the data of a PrePostFit experiment along with the prediction and causal impact information.
@@ -335,23 +322,14 @@ def get_plot_data_bayesian(self) -> pd.DataFrame:
335
322
.values
336
323
)
337
324
# HDI
338
- pre_hdi = (
339
- az .hdi (self .pre_pred ["posterior_predictive" ].mu , hdi_prob = 0.94 )
340
- .to_dataframe ()
341
- .unstack (level = "hdi" )
342
- .droplevel (0 , axis = 1 )
343
- )
344
- post_hdi = (
345
- az .hdi (self .post_pred ["posterior_predictive" ].mu , hdi_prob = 0.94 )
346
- .to_dataframe ()
347
- .unstack (level = "hdi" )
348
- .droplevel (0 , axis = 1 )
349
- )
350
- pre_data [["pred_hdi_lower" , "pred_hdi_upper" ]] = pre_hdi
351
- post_data [["pred_hdi_lower" , "pred_hdi_upper" ]] = post_hdi
325
+ pre_data [["pred_hdi_lower" , "pred_hdi_upper" ]] = get_hdi_to_df (self .pre_pred ["posterior_predictive" ].mu )
326
+ post_data [["pred_hdi_lower" , "pred_hdi_upper" ]] = get_hdi_to_df (self .post_pred ["posterior_predictive" ].mu )
352
327
# IMPACT
353
328
pre_data ["impact" ] = self .pre_impact .mean (dim = ["chain" , "draw" ]).values
354
329
post_data ["impact" ] = self .post_impact .mean (dim = ["chain" , "draw" ]).values
330
+ # HDI IMPACT
331
+ pre_data [["impact_hdi_lower" , "impact_hdi_upper" ]] = get_hdi_to_df (self .pre_impact )
332
+ post_data [["impact_hdi_lower" , "impact_hdi_upper" ]] = get_hdi_to_df (self .post_impact )
355
333
356
334
self .data_plot = pd .concat ([pre_data , post_data ])
357
335
0 commit comments