@@ -303,6 +303,63 @@ def ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, List[plt.Axes]]
303
303
304
304
return (fig , ax )
305
305
306
+ def get_plot_data (self ) -> pd .DataFrame :
307
+ """Recover the data of a PrePostFit experiment along with the prediction and causal impact information.
308
+
309
+ Internally, this function dispatches to either `get_plot_data_bayesian` or `get_plot_data_ols`
310
+ depending on the model type.
311
+ """
312
+ if isinstance (self .model , PyMCModel ):
313
+ return self .get_plot_data_bayesian ()
314
+ elif isinstance (self .model , RegressorMixin ):
315
+ return self .get_plot_data_ols ()
316
+ else :
317
+ raise ValueError ("Unsupported model type" )
318
+
319
+ def get_plot_data_bayesian (self ) -> pd .DataFrame :
320
+ """
321
+ Recover the data of a PrePostFit experiment along with the prediction and causal impact information.
322
+ """
323
+ if isinstance (self .model , PyMCModel ):
324
+ pre_data = self .datapre .copy ()
325
+ post_data = self .datapost .copy ()
326
+ pre_data ["prediction" ] = (
327
+ az .extract (
328
+ self .pre_pred , group = "posterior_predictive" , var_names = "mu"
329
+ )
330
+ .mean ("sample" )
331
+ .values
332
+ )
333
+ post_data ["prediction" ] = (
334
+ az .extract (
335
+ self .post_pred , group = "posterior_predictive" , var_names = "mu"
336
+ )
337
+ .mean ("sample" )
338
+ .values
339
+ )
340
+ pre_data ["impact" ] = self .pre_impact .mean (dim = ["chain" , "draw" ]).values
341
+ post_data ["impact" ] = self .post_impact .mean (dim = ["chain" , "draw" ]).values
342
+
343
+ self .data_plot = pd .concat ([pre_data , post_data ])
344
+
345
+ return self .data_plot
346
+ else :
347
+ raise ValueError ("Unsupported model type" )
348
+
349
+ def get_plot_data_ols (self ) -> pd .DataFrame :
350
+ """
351
+ Recover the data of a PrePostFit experiment along with the prediction and causal impact information.
352
+ """
353
+ pre_data = self .datapre .copy ()
354
+ post_data = self .datapost .copy ()
355
+ pre_data ["prediction" ] = self .pre_pred
356
+ post_data ["prediction" ] = self .post_pred
357
+ pre_data ["impact" ] = self .pre_impact
358
+ post_data ["impact" ] = self .post_impact
359
+ self .data_plot = pd .concat ([pre_data , post_data ])
360
+
361
+ return self .data_plot
362
+
306
363
307
364
class InterruptedTimeSeries (PrePostFit ):
308
365
"""
0 commit comments