Skip to content

Commit 70110c0

Browse files
committed
export plot data from plot utils
1 parent d31a033 commit 70110c0

File tree

1 file changed

+51
-0
lines changed

1 file changed

+51
-0
lines changed

causalpy/plot_utils.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import xarray as xr
2525
from matplotlib.collections import PolyCollection
2626
from matplotlib.lines import Line2D
27+
from sklearn.base import RegressorMixin
2728

2829

2930
def plot_xY(
@@ -79,3 +80,53 @@ def plot_xY(
7980
filter(lambda x: isinstance(x, PolyCollection), ax_hdi.get_children())
8081
)[-1]
8182
return (h_line, h_patch)
83+
84+
85+
def export_prepostfit_data(result):
86+
"""
87+
Utility function to recover the data of a PrePostFit experiment along with prediction and causal impact information.
88+
89+
:param result:
90+
The result of a PrePostFit experiment
91+
"""
92+
93+
from causalpy.experiments.prepostfit import PrePostFit
94+
from causalpy.pymc_models import PyMCModel
95+
96+
if isinstance(result, PrePostFit):
97+
pre_data = result.datapre.copy()
98+
post_data = result.datapost.copy()
99+
100+
if isinstance(result.model, PyMCModel):
101+
pre_data["prediction"] = (
102+
az.extract(
103+
result.pre_pred, group="posterior_predictive", var_names="mu"
104+
)
105+
.mean("sample")
106+
.values
107+
)
108+
post_data["prediction"] = (
109+
az.extract(
110+
result.post_pred, group="posterior_predictive", var_names="mu"
111+
)
112+
.mean("sample")
113+
.values
114+
)
115+
pre_data["impact"] = result.pre_impact.mean(dim=["chain", "draw"]).values
116+
post_data["impact"] = result.post_impact.mean(dim=["chain", "draw"]).values
117+
118+
elif isinstance(result.model, RegressorMixin):
119+
pre_data["prediction"] = result.pre_pred
120+
post_data["prediction"] = result.post_pred
121+
pre_data["impact"] = result.pre_impact
122+
post_data["impact"] = result.post_impact
123+
124+
else:
125+
raise ValueError("Other model types are not supported")
126+
127+
ppf_data = pd.concat([pre_data, post_data])
128+
129+
else:
130+
raise ValueError("Other experiments are not supported")
131+
132+
return ppf_data

0 commit comments

Comments
 (0)