Skip to content

Commit fdc867f

Browse files
committed
add function to prepostfit class + fixes in plot_utils
1 parent 70110c0 commit fdc867f

File tree

2 files changed

+59
-2
lines changed

2 files changed

+59
-2
lines changed

causalpy/experiments/prepostfit.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,63 @@ def ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, List[plt.Axes]]
303303

304304
return (fig, ax)
305305

306+
def get_plot_data(self) -> pd.DataFrame:
307+
"""Recover the data of a PrePostFit experiment along with the prediction and causal impact information.
308+
309+
Internally, this function dispatches to either `get_plot_data_bayesian` or `get_plot_data_ols`
310+
depending on the model type.
311+
"""
312+
if isinstance(self.model, PyMCModel):
313+
return self.get_plot_data_bayesian()
314+
elif isinstance(self.model, RegressorMixin):
315+
return self.get_plot_data_ols()
316+
else:
317+
raise ValueError("Unsupported model type")
318+
319+
def get_plot_data_bayesian(self) -> pd.DataFrame:
320+
"""
321+
Recover the data of a PrePostFit experiment along with the prediction and causal impact information.
322+
"""
323+
if isinstance(self.model, PyMCModel):
324+
pre_data = self.datapre.copy()
325+
post_data = self.datapost.copy()
326+
pre_data["prediction"] = (
327+
az.extract(
328+
self.pre_pred, group="posterior_predictive", var_names="mu"
329+
)
330+
.mean("sample")
331+
.values
332+
)
333+
post_data["prediction"] = (
334+
az.extract(
335+
self.post_pred, group="posterior_predictive", var_names="mu"
336+
)
337+
.mean("sample")
338+
.values
339+
)
340+
pre_data["impact"] = self.pre_impact.mean(dim=["chain", "draw"]).values
341+
post_data["impact"] = self.post_impact.mean(dim=["chain", "draw"]).values
342+
343+
self.data_plot = pd.concat([pre_data, post_data])
344+
345+
return self.data_plot
346+
else:
347+
raise ValueError("Unsupported model type")
348+
349+
def get_plot_data_ols(self) -> pd.DataFrame:
350+
"""
351+
Recover the data of a PrePostFit experiment along with the prediction and causal impact information.
352+
"""
353+
pre_data = self.datapre.copy()
354+
post_data = self.datapost.copy()
355+
pre_data["prediction"] = self.pre_pred
356+
post_data["prediction"] = self.post_pred
357+
pre_data["impact"] = self.pre_impact
358+
post_data["impact"] = self.post_impact
359+
self.data_plot = pd.concat([pre_data, post_data])
360+
361+
return self.data_plot
362+
306363

307364
class InterruptedTimeSeries(PrePostFit):
308365
"""

causalpy/plot_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,9 @@ def plot_xY(
8282
return (h_line, h_patch)
8383

8484

85-
def export_prepostfit_data(result):
85+
def get_prepostfit_data(result) -> pd.DataFrame:
8686
"""
87-
Utility function to recover the data of a PrePostFit experiment along with prediction and causal impact information.
87+
Utility function to recover the data of a PrePostFit experiment along with the prediction and causal impact information.
8888
8989
:param result:
9090
The result of a PrePostFit experiment

0 commit comments

Comments
 (0)