Skip to content

Commit d7680f6

Browse files
committed
Add utility function to retrieve HDI and clean up get_plot_data_bayesian
1 parent deace7f commit d7680f6

File tree

2 files changed

+21
-72
lines changed

2 files changed

+21
-72
lines changed

causalpy/experiments/prepostfit.py

Lines changed: 6 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from sklearn.base import RegressorMixin
2626

2727
from causalpy.custom_exceptions import BadIndexException
28-
from causalpy.plot_utils import plot_xY
28+
from causalpy.plot_utils import plot_xY, get_hdi_to_df
2929
from causalpy.pymc_models import PyMCModel
3030
from causalpy.utils import round_num
3131

@@ -303,19 +303,6 @@ def ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, List[plt.Axes]]
303303

304304
return (fig, ax)
305305

306-
# def get_plot_data(self) -> pd.DataFrame:
307-
# """Recover the data of a PrePostFit experiment along with the prediction and causal impact information.
308-
309-
# Internally, this function dispatches to either `get_plot_data_bayesian` or `get_plot_data_ols`
310-
# depending on the model type.
311-
# """
312-
# if isinstance(self.model, PyMCModel):
313-
# return self.get_plot_data_bayesian()
314-
# elif isinstance(self.model, RegressorMixin):
315-
# return self.get_plot_data_ols()
316-
# else:
317-
# raise ValueError("Unsupported model type")
318-
319306
def get_plot_data_bayesian(self) -> pd.DataFrame:
320307
"""
321308
Recover the data of a PrePostFit experiment along with the prediction and causal impact information.
@@ -335,23 +322,14 @@ def get_plot_data_bayesian(self) -> pd.DataFrame:
335322
.values
336323
)
337324
# HDI
338-
pre_hdi = (
339-
az.hdi(self.pre_pred["posterior_predictive"].mu, hdi_prob=0.94)
340-
.to_dataframe()
341-
.unstack(level="hdi")
342-
.droplevel(0, axis=1)
343-
)
344-
post_hdi = (
345-
az.hdi(self.post_pred["posterior_predictive"].mu, hdi_prob=0.94)
346-
.to_dataframe()
347-
.unstack(level="hdi")
348-
.droplevel(0, axis=1)
349-
)
350-
pre_data[["pred_hdi_lower", "pred_hdi_upper"]] = pre_hdi
351-
post_data[["pred_hdi_lower", "pred_hdi_upper"]] = post_hdi
325+
pre_data[["pred_hdi_lower", "pred_hdi_upper"]] = get_hdi_to_df(self.pre_pred["posterior_predictive"].mu)
326+
post_data[["pred_hdi_lower", "pred_hdi_upper"]] = get_hdi_to_df(self.post_pred["posterior_predictive"].mu)
352327
# IMPACT
353328
pre_data["impact"] = self.pre_impact.mean(dim=["chain", "draw"]).values
354329
post_data["impact"] = self.post_impact.mean(dim=["chain", "draw"]).values
330+
# HDI IMPACT
331+
pre_data[["impact_hdi_lower", "impact_hdi_upper"]] = get_hdi_to_df(self.pre_impact)
332+
post_data[["impact_hdi_lower", "impact_hdi_upper"]] = get_hdi_to_df(self.post_impact)
355333

356334
self.data_plot = pd.concat([pre_data, post_data])
357335

causalpy/plot_utils.py

Lines changed: 15 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -82,51 +82,22 @@ def plot_xY(
8282
return (h_line, h_patch)
8383

8484

85-
def get_prepostfit_data(result) -> pd.DataFrame:
85+
def get_hdi_to_df(
86+
x: xr.DataArray,
87+
hdi_prob: float = 0.94,
88+
) -> pd.DataFrame:
8689
"""
87-
Utility function to recover the data of a PrePostFit experiment along with the prediction and causal impact information.
90+
Utility function to compute the highest density interval (HDI) of a data array and return it as a DataFrame.
8891
89-
:param result:
90-
The result of a PrePostFit experiment
92+
:param x:
93+
xarray DataArray for which the HDI is computed
94+
:param hdi_prob:
95+
The probability passed to ``az.hdi`` as ``hdi_prob``; defaults to 0.94
9196
"""
92-
93-
from causalpy.experiments.prepostfit import PrePostFit
94-
from causalpy.pymc_models import PyMCModel
95-
96-
if isinstance(result, PrePostFit):
97-
pre_data = result.datapre.copy()
98-
post_data = result.datapost.copy()
99-
100-
if isinstance(result.model, PyMCModel):
101-
pre_data["prediction"] = (
102-
az.extract(
103-
result.pre_pred, group="posterior_predictive", var_names="mu"
104-
)
105-
.mean("sample")
106-
.values
97+
hdi = (
98+
az.hdi(x, hdi_prob=hdi_prob)
99+
.to_dataframe()
100+
.unstack(level="hdi")
101+
.droplevel(0, axis=1)
107102
)
108-
post_data["prediction"] = (
109-
az.extract(
110-
result.post_pred, group="posterior_predictive", var_names="mu"
111-
)
112-
.mean("sample")
113-
.values
114-
)
115-
pre_data["impact"] = result.pre_impact.mean(dim=["chain", "draw"]).values
116-
post_data["impact"] = result.post_impact.mean(dim=["chain", "draw"]).values
117-
118-
elif isinstance(result.model, RegressorMixin):
119-
pre_data["prediction"] = result.pre_pred
120-
post_data["prediction"] = result.post_pred
121-
pre_data["impact"] = result.pre_impact
122-
post_data["impact"] = result.post_impact
123-
124-
else:
125-
raise ValueError("Other model types are not supported")
126-
127-
ppf_data = pd.concat([pre_data, post_data])
128-
129-
else:
130-
raise ValueError("Other experiments are not supported")
131-
132-
return ppf_data
103+
return hdi

0 commit comments

Comments
 (0)