@@ -303,18 +303,18 @@ def ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, List[plt.Axes]]
303
303
304
304
return (fig , ax )
305
305
306
- def get_plot_data (self ) -> pd .DataFrame :
307
- """Recover the data of a PrePostFit experiment along with the prediction and causal impact information.
308
-
309
- Internally, this function dispatches to either `get_plot_data_bayesian` or `get_plot_data_ols`
310
- depending on the model type.
311
- """
312
- if isinstance (self .model , PyMCModel ):
313
- return self .get_plot_data_bayesian ()
314
- elif isinstance (self .model , RegressorMixin ):
315
- return self .get_plot_data_ols ()
316
- else :
317
- raise ValueError ("Unsupported model type" )
306
+ # def get_plot_data(self) -> pd.DataFrame:
307
+ # """Recover the data of a PrePostFit experiment along with the prediction and causal impact information.
308
+
309
+ # Internally, this function dispatches to either `get_plot_data_bayesian` or `get_plot_data_ols`
310
+ # depending on the model type.
311
+ # """
312
+ # if isinstance(self.model, PyMCModel):
313
+ # return self.get_plot_data_bayesian()
314
+ # elif isinstance(self.model, RegressorMixin):
315
+ # return self.get_plot_data_ols()
316
+ # else:
317
+ # raise ValueError("Unsupported model type")
318
318
319
319
def get_plot_data_bayesian (self ) -> pd .DataFrame :
320
320
"""
@@ -323,29 +323,42 @@ def get_plot_data_bayesian(self) -> pd.DataFrame:
323
323
if isinstance (self .model , PyMCModel ):
324
324
pre_data = self .datapre .copy ()
325
325
post_data = self .datapost .copy ()
326
+ # PREDICTIONS
326
327
pre_data ["prediction" ] = (
327
- az .extract (
328
- self .pre_pred , group = "posterior_predictive" , var_names = "mu"
329
- )
328
+ az .extract (self .pre_pred , group = "posterior_predictive" , var_names = "mu" )
330
329
.mean ("sample" )
331
330
.values
332
331
)
333
332
post_data ["prediction" ] = (
334
- az .extract (
335
- self .post_pred , group = "posterior_predictive" , var_names = "mu"
336
- )
333
+ az .extract (self .post_pred , group = "posterior_predictive" , var_names = "mu" )
337
334
.mean ("sample" )
338
335
.values
339
336
)
337
+ # HDI
338
+ pre_hdi = (
339
+ az .hdi (self .pre_pred ["posterior_predictive" ].mu , hdi_prob = 0.94 )
340
+ .to_dataframe ()
341
+ .unstack (level = "hdi" )
342
+ .droplevel (0 , axis = 1 )
343
+ )
344
+ post_hdi = (
345
+ az .hdi (self .post_pred ["posterior_predictive" ].mu , hdi_prob = 0.94 )
346
+ .to_dataframe ()
347
+ .unstack (level = "hdi" )
348
+ .droplevel (0 , axis = 1 )
349
+ )
350
+ pre_data [["pred_hdi_lower" , "pred_hdi_upper" ]] = pre_hdi
351
+ post_data [["pred_hdi_lower" , "pred_hdi_upper" ]] = post_hdi
352
+ # IMPACT
340
353
pre_data ["impact" ] = self .pre_impact .mean (dim = ["chain" , "draw" ]).values
341
354
post_data ["impact" ] = self .post_impact .mean (dim = ["chain" , "draw" ]).values
342
-
355
+
343
356
self .data_plot = pd .concat ([pre_data , post_data ])
344
357
345
358
return self .data_plot
346
359
else :
347
360
raise ValueError ("Unsupported model type" )
348
-
361
+
349
362
def get_plot_data_ols (self ) -> pd .DataFrame :
350
363
"""
351
364
Recover the data of a PrePostFit experiment along with the prediction and causal impact information.
0 commit comments