Skip to content

Commit 8f9b9db

Browse files
committed
fix arviz warning, improve type hints
1 parent 33d92b5 commit 8f9b9db

File tree

2 files changed

+54
-23
lines changed

2 files changed

+54
-23
lines changed

causalpy/plot_utils.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,22 @@
1+
from typing import Any, Dict, Optional, Union
2+
13
import arviz as az
4+
import matplotlib.pyplot as plt
5+
import numpy as np
6+
import pandas as pd
7+
import xarray as xr
28

39

410
def plot_xY(
5-
x, Y, ax, plot_hdi_kwargs=dict(), hdi_prob: float = 0.94, include_label: bool = True
11+
x: Union[pd.DatetimeIndex, np.array],
12+
Y: xr.DataArray,
13+
ax: plt.Axes,
14+
plot_hdi_kwargs: Optional[Dict[str, Any]] = {},
15+
hdi_prob: Optional[float] = 0.94,
16+
include_label: Optional[bool] = True,
617
) -> None:
718
"""Utility function to plot HDI intervals."""
819

9-
Y = Y.stack(samples=["chain", "draw"]).T
1020
az.plot_hdi(
1121
x,
1222
Y,
@@ -21,7 +31,7 @@ def plot_xY(
2131
)
2232
ax.plot(
2333
x,
24-
Y.mean(dim="samples"),
34+
Y.mean(dim=["chain", "draw"]),
2535
color="k",
2636
label="Posterior mean" if include_label else None,
2737
)

causalpy/pymc_experiments.py

Lines changed: 41 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -103,11 +103,15 @@ def __init__(
103103

104104
# causal impact pre (ie the residuals of the model fit to observed)
105105
pre_data = xr.DataArray(self.pre_y[:, 0], dims=["obs_ind"])
106-
self.pre_impact = pre_data - self.pre_pred["posterior_predictive"].y_hat
106+
self.pre_impact = (
107+
pre_data - self.pre_pred["posterior_predictive"].y_hat
108+
).transpose(..., "obs_ind")
107109

108110
# causal impact post (ie the residuals of the model fit to observed)
109111
post_data = xr.DataArray(self.post_y[:, 0], dims=["obs_ind"])
110-
self.post_impact = post_data - self.post_pred["posterior_predictive"].y_hat
112+
self.post_impact = (
113+
post_data - self.post_pred["posterior_predictive"].y_hat
114+
).transpose(..., "obs_ind")
111115

112116
# cumulative impact post
113117
self.post_impact_cumulative = self.post_impact.cumsum(dim="obs_ind")
@@ -117,9 +121,12 @@ def plot(self):
117121
"""Plot the results"""
118122
fig, ax = plt.subplots(3, 1, sharex=True, figsize=(7, 8))
119123

124+
# TOP PLOT --------------------------------------------------
120125
# pre-intervention period
121126
plot_xY(
122-
self.datapre.index, self.pre_pred["posterior_predictive"].y_hat, ax=ax[0]
127+
self.datapre.index,
128+
self.pre_pred["posterior_predictive"].y_hat,
129+
ax=ax[0],
123130
)
124131
ax[0].plot(self.datapre.index, self.pre_y, "k.", label="Observations")
125132
# post intervention period
@@ -130,23 +137,6 @@ def plot(self):
130137
include_label=False,
131138
)
132139
ax[0].plot(self.datapost.index, self.post_y, "k.")
133-
134-
ax[0].set(
135-
title=f"""
136-
Pre-intervention Bayesian $R^2$: {self.score.r2:.3f}
137-
(std = {self.score.r2_std:.3f})
138-
"""
139-
)
140-
141-
plot_xY(self.datapre.index, self.pre_impact, ax=ax[1])
142-
plot_xY(self.datapost.index, self.post_impact, ax=ax[1], include_label=False)
143-
ax[1].axhline(y=0, c="k")
144-
ax[1].set(title="Causal Impact")
145-
146-
ax[2].set(title="Cumulative Causal Impact")
147-
plot_xY(self.datapost.index, self.post_impact_cumulative, ax=ax[2])
148-
ax[2].axhline(y=0, c="k")
149-
150140
# Shaded causal effect
151141
ax[0].fill_between(
152142
self.datapost.index,
@@ -158,13 +148,44 @@ def plot(self):
158148
alpha=0.25,
159149
label="Causal impact",
160150
)
151+
ax[0].set(
152+
title=f"""
153+
Pre-intervention Bayesian $R^2$: {self.score.r2:.3f}
154+
(std = {self.score.r2_std:.3f})
155+
"""
156+
)
157+
158+
# MIDDLE PLOT -----------------------------------------------
159+
plot_xY(
160+
self.datapre.index,
161+
self.pre_impact,
162+
ax=ax[1],
163+
)
164+
plot_xY(
165+
self.datapost.index,
166+
self.post_impact,
167+
ax=ax[1],
168+
include_label=False,
169+
)
170+
ax[1].axhline(y=0, c="k")
161171
ax[1].fill_between(
162172
self.datapost.index,
163173
y1=self.post_impact.mean(["chain", "draw"]),
164174
color="C0",
165175
alpha=0.25,
166176
label="Causal impact",
167177
)
178+
ax[1].set(title="Causal Impact")
179+
180+
# BOTTOM PLOT -----------------------------------------------
181+
182+
ax[2].set(title="Cumulative Causal Impact")
183+
plot_xY(
184+
self.datapost.index,
185+
self.post_impact_cumulative,
186+
ax=ax[2],
187+
)
188+
ax[2].axhline(y=0, c="k")
168189

169190
# Intervention line
170191
for i in [0, 1, 2]:

0 commit comments

Comments
 (0)