|
24 | 24 | import xarray as xr
|
25 | 25 | from matplotlib.collections import PolyCollection
|
26 | 26 | from matplotlib.lines import Line2D
|
| 27 | +from sklearn.base import RegressorMixin |
27 | 28 |
|
28 | 29 |
|
29 | 30 | def plot_xY(
|
@@ -79,3 +80,53 @@ def plot_xY(
|
79 | 80 | filter(lambda x: isinstance(x, PolyCollection), ax_hdi.get_children())
|
80 | 81 | )[-1]
|
81 | 82 | return (h_line, h_patch)
|
| 83 | + |
| 84 | + |
| 85 | +def export_prepostfit_data(result): |
| 86 | + """ |
| 87 | + Utility function to recover the data of a PrePostFit experiment along with prediction and causal impact information. |
| 88 | +
|
| 89 | + :param result: |
| 90 | + The result of a PrePostFit experiment |
| 91 | + """ |
| 92 | + |
| 93 | + from causalpy.experiments.prepostfit import PrePostFit |
| 94 | + from causalpy.pymc_models import PyMCModel |
| 95 | + |
| 96 | + if isinstance(result, PrePostFit): |
| 97 | + pre_data = result.datapre.copy() |
| 98 | + post_data = result.datapost.copy() |
| 99 | + |
| 100 | + if isinstance(result.model, PyMCModel): |
| 101 | + pre_data["prediction"] = ( |
| 102 | + az.extract( |
| 103 | + result.pre_pred, group="posterior_predictive", var_names="mu" |
| 104 | + ) |
| 105 | + .mean("sample") |
| 106 | + .values |
| 107 | + ) |
| 108 | + post_data["prediction"] = ( |
| 109 | + az.extract( |
| 110 | + result.post_pred, group="posterior_predictive", var_names="mu" |
| 111 | + ) |
| 112 | + .mean("sample") |
| 113 | + .values |
| 114 | + ) |
| 115 | + pre_data["impact"] = result.pre_impact.mean(dim=["chain", "draw"]).values |
| 116 | + post_data["impact"] = result.post_impact.mean(dim=["chain", "draw"]).values |
| 117 | + |
| 118 | + elif isinstance(result.model, RegressorMixin): |
| 119 | + pre_data["prediction"] = result.pre_pred |
| 120 | + post_data["prediction"] = result.post_pred |
| 121 | + pre_data["impact"] = result.pre_impact |
| 122 | + post_data["impact"] = result.post_impact |
| 123 | + |
| 124 | + else: |
| 125 | + raise ValueError("Other model types are not supported") |
| 126 | + |
| 127 | + ppf_data = pd.concat([pre_data, post_data]) |
| 128 | + |
| 129 | + else: |
| 130 | + raise ValueError("Other experiments are not supported") |
| 131 | + |
| 132 | + return ppf_data |
0 commit comments