@@ -303,6 +303,63 @@ def ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, List[plt.Axes]]
303303
304304 return (fig , ax )
305305
306+ def get_plot_data (self ) -> pd .DataFrame :
307+ """Recover the data of a PrePostFit experiment along with the prediction and causal impact information.
308+
309+ Internally, this function dispatches to either `get_plot_data_bayesian` or `get_plot_data_ols`
310+ depending on the model type.
311+ """
312+ if isinstance (self .model , PyMCModel ):
313+ return self .get_plot_data_bayesian ()
314+ elif isinstance (self .model , RegressorMixin ):
315+ return self .get_plot_data_ols ()
316+ else :
317+ raise ValueError ("Unsupported model type" )
318+
319+ def get_plot_data_bayesian (self ) -> pd .DataFrame :
320+ """
321+ Recover the data of a PrePostFit experiment along with the prediction and causal impact information.
322+ """
323+ if isinstance (self .model , PyMCModel ):
324+ pre_data = self .datapre .copy ()
325+ post_data = self .datapost .copy ()
326+ pre_data ["prediction" ] = (
327+ az .extract (
328+ self .pre_pred , group = "posterior_predictive" , var_names = "mu"
329+ )
330+ .mean ("sample" )
331+ .values
332+ )
333+ post_data ["prediction" ] = (
334+ az .extract (
335+ self .post_pred , group = "posterior_predictive" , var_names = "mu"
336+ )
337+ .mean ("sample" )
338+ .values
339+ )
340+ pre_data ["impact" ] = self .pre_impact .mean (dim = ["chain" , "draw" ]).values
341+ post_data ["impact" ] = self .post_impact .mean (dim = ["chain" , "draw" ]).values
342+
343+ self .data_plot = pd .concat ([pre_data , post_data ])
344+
345+ return self .data_plot
346+ else :
347+ raise ValueError ("Unsupported model type" )
348+
349+ def get_plot_data_ols (self ) -> pd .DataFrame :
350+ """
351+ Recover the data of a PrePostFit experiment along with the prediction and causal impact information.
352+ """
353+ pre_data = self .datapre .copy ()
354+ post_data = self .datapost .copy ()
355+ pre_data ["prediction" ] = self .pre_pred
356+ post_data ["prediction" ] = self .post_pred
357+ pre_data ["impact" ] = self .pre_impact
358+ post_data ["impact" ] = self .post_impact
359+ self .data_plot = pd .concat ([pre_data , post_data ])
360+
361+ return self .data_plot
362+
306363
307364class InterruptedTimeSeries (PrePostFit ):
308365 """
0 commit comments