@@ -303,7 +303,7 @@ def ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, List[plt.Axes]]
303
303
304
304
return (fig , ax )
305
305
306
- def get_plot_data_bayesian (self , hdi_prob = 0.94 ) -> pd .DataFrame :
306
+ def get_plot_data_bayesian (self , hdi_prob : float = 0.94 ) -> pd .DataFrame :
307
307
"""
308
308
Recover the data of a PrePostFit experiment along with the prediction and causal impact information.
309
309
"""
@@ -321,17 +321,17 @@ def get_plot_data_bayesian(self, hdi_prob=0.94) -> pd.DataFrame:
321
321
.mean ("sample" )
322
322
.values
323
323
)
324
- pre_data [["pred_hdi_lower" , "pred_hdi_upper" ]] = get_hdi_to_df (self .pre_pred ["posterior_predictive" ].mu , hdi_prob = hdi_prob )
325
- post_data [["pred_hdi_lower" , "pred_hdi_upper" ]] = get_hdi_to_df (self .post_pred ["posterior_predictive" ].mu , hdi_prob = hdi_prob )
324
+ pre_data [["pred_hdi_lower" , "pred_hdi_upper" ]] = get_hdi_to_df (self .pre_pred ["posterior_predictive" ].mu , hdi_prob = hdi_prob ). set_index ( pre_data . index )
325
+ post_data [["pred_hdi_lower" , "pred_hdi_upper" ]] = get_hdi_to_df (self .post_pred ["posterior_predictive" ].mu , hdi_prob = hdi_prob ). set_index ( post_data . index )
326
326
327
327
pre_data ["impact" ] = self .pre_impact .mean (dim = ["chain" , "draw" ]).values
328
328
post_data ["impact" ] = self .post_impact .mean (dim = ["chain" , "draw" ]).values
329
- pre_data [["impact_hdi_lower" , "impact_hdi_upper" ]] = get_hdi_to_df (self .pre_impact , hdi_prob = hdi_prob )
330
- post_data [["impact_hdi_lower" , "impact_hdi_upper" ]] = get_hdi_to_df (self .post_impact , hdi_prob = hdi_prob )
329
+ pre_data [["impact_hdi_lower" , "impact_hdi_upper" ]] = get_hdi_to_df (self .pre_impact , hdi_prob = hdi_prob ). set_index ( pre_data . index )
330
+ post_data [["impact_hdi_lower" , "impact_hdi_upper" ]] = get_hdi_to_df (self .post_impact , hdi_prob = hdi_prob ). set_index ( post_data . index )
331
331
332
- self .data_plot = pd .concat ([pre_data , post_data ])
332
+ self .plot_data = pd .concat ([pre_data , post_data ])
333
333
334
- return self .data_plot
334
+ return self .plot_data
335
335
else :
336
336
raise ValueError ("Unsupported model type" )
337
337
@@ -345,9 +345,9 @@ def get_plot_data_ols(self) -> pd.DataFrame:
345
345
post_data ["prediction" ] = self .post_pred
346
346
pre_data ["impact" ] = self .pre_impact
347
347
post_data ["impact" ] = self .post_impact
348
- self .data_plot = pd .concat ([pre_data , post_data ])
348
+ self .plot_data = pd .concat ([pre_data , post_data ])
349
349
350
- return self .data_plot
350
+ return self .plot_data
351
351
352
352
353
353
class InterruptedTimeSeries (PrePostFit ):
0 commit comments