Skip to content

Commit 1734486

Browse files
committed
hdi_prob specification in get_plot_data_bayesian
1 parent d7680f6 commit 1734486

File tree

2 files changed

+12
-14
lines changed

2 files changed

+12
-14
lines changed

causalpy/experiments/base.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,25 +80,25 @@ def ols_plot(self, *args, **kwargs):
8080
"""Abstract method for plotting the model."""
8181
raise NotImplementedError("ols_plot method not yet implemented")
8282

83-
def get_plot_data(self) -> pd.DataFrame:
83+
def get_plot_data(self, *args, **kwargs) -> pd.DataFrame:
8484
"""Recover the data of a PrePostFit experiment along with the prediction and causal impact information.
8585
8686
Internally, this function dispatches to either `get_plot_data_bayesian` or `get_plot_data_ols`
8787
depending on the model type.
8888
"""
8989
if isinstance(self.model, PyMCModel):
90-
return self.get_plot_data_bayesian()
90+
return self.get_plot_data_bayesian(*args, **kwargs)
9191
elif isinstance(self.model, RegressorMixin):
92-
return self.get_plot_data_ols()
92+
return self.get_plot_data_ols(*args, **kwargs)
9393
else:
9494
raise ValueError("Unsupported model type")
9595

9696
@abstractmethod
97-
def get_plot_data_bayesian(self):
97+
def get_plot_data_bayesian(self, *args, **kwargs):
9898
"""Abstract method for recovering plot data."""
9999
raise NotImplementedError("get_plot_data_bayesian method not yet implemented")
100100

101101
@abstractmethod
102-
def get_plot_data_ols(self):
102+
def get_plot_data_ols(self, *args, **kwargs):
103103
"""Abstract method for recovering plot data."""
104104
raise NotImplementedError("get_plot_data_ols method not yet implemented")

causalpy/experiments/prepostfit.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -303,14 +303,14 @@ def ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, List[plt.Axes]]
303303

304304
return (fig, ax)
305305

306-
def get_plot_data_bayesian(self) -> pd.DataFrame:
306+
def get_plot_data_bayesian(self, hdi_prob=0.94) -> pd.DataFrame:
307307
"""
308308
Recover the data of a PrePostFit experiment along with the prediction and causal impact information.
309309
"""
310310
if isinstance(self.model, PyMCModel):
311311
pre_data = self.datapre.copy()
312312
post_data = self.datapost.copy()
313-
# PREDICTIONS
313+
314314
pre_data["prediction"] = (
315315
az.extract(self.pre_pred, group="posterior_predictive", var_names="mu")
316316
.mean("sample")
@@ -321,15 +321,13 @@ def get_plot_data_bayesian(self) -> pd.DataFrame:
321321
.mean("sample")
322322
.values
323323
)
324-
# HDI
325-
pre_data[["pred_hdi_lower", "pred_hdi_upper"]] = get_hdi_to_df(self.pre_pred["posterior_predictive"].mu)
326-
post_data[["pred_hdi_lower", "pred_hdi_upper"]] = get_hdi_to_df(self.post_pred["posterior_predictive"].mu)
327-
# IMPACT
324+
pre_data[["pred_hdi_lower", "pred_hdi_upper"]] = get_hdi_to_df(self.pre_pred["posterior_predictive"].mu, hdi_prob=hdi_prob)
325+
post_data[["pred_hdi_lower", "pred_hdi_upper"]] = get_hdi_to_df(self.post_pred["posterior_predictive"].mu, hdi_prob=hdi_prob)
326+
328327
pre_data["impact"] = self.pre_impact.mean(dim=["chain", "draw"]).values
329328
post_data["impact"] = self.post_impact.mean(dim=["chain", "draw"]).values
330-
# HDI IMPACT
331-
pre_data[["impact_hdi_lower", "impact_hdi_upper"]] = get_hdi_to_df(self.pre_impact)
332-
post_data[["impact_hdi_lower", "impact_hdi_upper"]] = get_hdi_to_df(self.post_impact)
329+
pre_data[["impact_hdi_lower", "impact_hdi_upper"]] = get_hdi_to_df(self.pre_impact, hdi_prob=hdi_prob)
330+
post_data[["impact_hdi_lower", "impact_hdi_upper"]] = get_hdi_to_df(self.post_impact, hdi_prob=hdi_prob)
333331

334332
self.data_plot = pd.concat([pre_data, post_data])
335333

0 commit comments

Comments
 (0)