Skip to content

Commit b79e6ee

Browse files
committed
tested for its and index alignment in recovering hdi
1 parent 1734486 commit b79e6ee

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

causalpy/experiments/prepostfit.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ def ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, List[plt.Axes]]
303303

304304
return (fig, ax)
305305

306-
def get_plot_data_bayesian(self, hdi_prob=0.94) -> pd.DataFrame:
306+
def get_plot_data_bayesian(self, hdi_prob: float = 0.94) -> pd.DataFrame:
307307
"""
308308
Recover the data of a PrePostFit experiment along with the prediction and causal impact information.
309309
"""
@@ -321,17 +321,17 @@ def get_plot_data_bayesian(self, hdi_prob=0.94) -> pd.DataFrame:
321321
.mean("sample")
322322
.values
323323
)
324-
pre_data[["pred_hdi_lower", "pred_hdi_upper"]] = get_hdi_to_df(self.pre_pred["posterior_predictive"].mu, hdi_prob=hdi_prob)
325-
post_data[["pred_hdi_lower", "pred_hdi_upper"]] = get_hdi_to_df(self.post_pred["posterior_predictive"].mu, hdi_prob=hdi_prob)
324+
pre_data[["pred_hdi_lower", "pred_hdi_upper"]] = get_hdi_to_df(self.pre_pred["posterior_predictive"].mu, hdi_prob=hdi_prob).set_index(pre_data.index)
325+
post_data[["pred_hdi_lower", "pred_hdi_upper"]] = get_hdi_to_df(self.post_pred["posterior_predictive"].mu, hdi_prob=hdi_prob).set_index(post_data.index)
326326

327327
pre_data["impact"] = self.pre_impact.mean(dim=["chain", "draw"]).values
328328
post_data["impact"] = self.post_impact.mean(dim=["chain", "draw"]).values
329-
pre_data[["impact_hdi_lower", "impact_hdi_upper"]] = get_hdi_to_df(self.pre_impact, hdi_prob=hdi_prob)
330-
post_data[["impact_hdi_lower", "impact_hdi_upper"]] = get_hdi_to_df(self.post_impact, hdi_prob=hdi_prob)
329+
pre_data[["impact_hdi_lower", "impact_hdi_upper"]] = get_hdi_to_df(self.pre_impact, hdi_prob=hdi_prob).set_index(pre_data.index)
330+
post_data[["impact_hdi_lower", "impact_hdi_upper"]] = get_hdi_to_df(self.post_impact, hdi_prob=hdi_prob).set_index(post_data.index)
331331

332-
self.data_plot = pd.concat([pre_data, post_data])
332+
self.plot_data = pd.concat([pre_data, post_data])
333333

334-
return self.data_plot
334+
return self.plot_data
335335
else:
336336
raise ValueError("Unsupported model type")
337337

@@ -345,9 +345,9 @@ def get_plot_data_ols(self) -> pd.DataFrame:
345345
post_data["prediction"] = self.post_pred
346346
pre_data["impact"] = self.pre_impact
347347
post_data["impact"] = self.post_impact
348-
self.data_plot = pd.concat([pre_data, post_data])
348+
self.plot_data = pd.concat([pre_data, post_data])
349349

350-
return self.data_plot
350+
return self.plot_data
351351

352352

353353
class InterruptedTimeSeries(PrePostFit):

0 commit comments

Comments
 (0)