Skip to content

Commit deace7f

Browse files
committed
generic get_plot_data in base.py and update prepostfit code to get hdi
1 parent fdc867f commit deace7f

File tree

2 files changed

+57
-20
lines changed

2 files changed

+57
-20
lines changed

causalpy/experiments/base.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
from abc import abstractmethod
1919

20+
import pandas as pd
2021
from sklearn.base import RegressorMixin
2122

2223
from causalpy.pymc_models import PyMCModel
@@ -78,3 +79,26 @@ def bayesian_plot(self, *args, **kwargs):
7879
def ols_plot(self, *args, **kwargs):
7980
"""Abstract method for plotting the model."""
8081
raise NotImplementedError("ols_plot method not yet implemented")
82+
83+
def get_plot_data(self) -> pd.DataFrame:
84+
"""Recover the data of a PrePostFit experiment along with the prediction and causal impact information.
85+
86+
Internally, this function dispatches to either `get_plot_data_bayesian` or `get_plot_data_ols`
87+
depending on the model type.
88+
"""
89+
if isinstance(self.model, PyMCModel):
90+
return self.get_plot_data_bayesian()
91+
elif isinstance(self.model, RegressorMixin):
92+
return self.get_plot_data_ols()
93+
else:
94+
raise ValueError("Unsupported model type")
95+
96+
@abstractmethod
97+
def get_plot_data_bayesian(self):
98+
"""Abstract method for recovering plot data."""
99+
raise NotImplementedError("get_plot_data_bayesian method not yet implemented")
100+
101+
@abstractmethod
102+
def get_plot_data_ols(self):
103+
"""Abstract method for recovering plot data."""
104+
raise NotImplementedError("get_plot_data_ols method not yet implemented")

causalpy/experiments/prepostfit.py

Lines changed: 33 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -303,18 +303,18 @@ def ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, List[plt.Axes]]
303303

304304
return (fig, ax)
305305

306-
def get_plot_data(self) -> pd.DataFrame:
307-
"""Recover the data of a PrePostFit experiment along with the prediction and causal impact information.
308-
309-
Internally, this function dispatches to either `get_plot_data_bayesian` or `get_plot_data_ols`
310-
depending on the model type.
311-
"""
312-
if isinstance(self.model, PyMCModel):
313-
return self.get_plot_data_bayesian()
314-
elif isinstance(self.model, RegressorMixin):
315-
return self.get_plot_data_ols()
316-
else:
317-
raise ValueError("Unsupported model type")
306+
# def get_plot_data(self) -> pd.DataFrame:
307+
# """Recover the data of a PrePostFit experiment along with the prediction and causal impact information.
308+
309+
# Internally, this function dispatches to either `get_plot_data_bayesian` or `get_plot_data_ols`
310+
# depending on the model type.
311+
# """
312+
# if isinstance(self.model, PyMCModel):
313+
# return self.get_plot_data_bayesian()
314+
# elif isinstance(self.model, RegressorMixin):
315+
# return self.get_plot_data_ols()
316+
# else:
317+
# raise ValueError("Unsupported model type")
318318

319319
def get_plot_data_bayesian(self) -> pd.DataFrame:
320320
"""
@@ -323,29 +323,42 @@ def get_plot_data_bayesian(self) -> pd.DataFrame:
323323
if isinstance(self.model, PyMCModel):
324324
pre_data = self.datapre.copy()
325325
post_data = self.datapost.copy()
326+
# PREDICTIONS
326327
pre_data["prediction"] = (
327-
az.extract(
328-
self.pre_pred, group="posterior_predictive", var_names="mu"
329-
)
328+
az.extract(self.pre_pred, group="posterior_predictive", var_names="mu")
330329
.mean("sample")
331330
.values
332331
)
333332
post_data["prediction"] = (
334-
az.extract(
335-
self.post_pred, group="posterior_predictive", var_names="mu"
336-
)
333+
az.extract(self.post_pred, group="posterior_predictive", var_names="mu")
337334
.mean("sample")
338335
.values
339336
)
337+
# HDI
338+
pre_hdi = (
339+
az.hdi(self.pre_pred["posterior_predictive"].mu, hdi_prob=0.94)
340+
.to_dataframe()
341+
.unstack(level="hdi")
342+
.droplevel(0, axis=1)
343+
)
344+
post_hdi = (
345+
az.hdi(self.post_pred["posterior_predictive"].mu, hdi_prob=0.94)
346+
.to_dataframe()
347+
.unstack(level="hdi")
348+
.droplevel(0, axis=1)
349+
)
350+
pre_data[["pred_hdi_lower", "pred_hdi_upper"]] = pre_hdi
351+
post_data[["pred_hdi_lower", "pred_hdi_upper"]] = post_hdi
352+
# IMPACT
340353
pre_data["impact"] = self.pre_impact.mean(dim=["chain", "draw"]).values
341354
post_data["impact"] = self.post_impact.mean(dim=["chain", "draw"]).values
342-
355+
343356
self.data_plot = pd.concat([pre_data, post_data])
344357

345358
return self.data_plot
346359
else:
347360
raise ValueError("Unsupported model type")
348-
361+
349362
def get_plot_data_ols(self) -> pd.DataFrame:
350363
"""
351364
Recover the data of a PrePostFit experiment along with the prediction and causal impact information.

0 commit comments

Comments
 (0)