Skip to content

Commit c98cfd9

Browse files
committed
improve plots + plot posterior mu rather than posterior predictive
1 parent 2364c4e commit c98cfd9

12 files changed

+393
-506
lines changed

causalpy/plot_utils.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import numpy as np
66
import pandas as pd
77
import xarray as xr
8+
from matplotlib.collections import PolyCollection
89

910

1011
def plot_xY(
@@ -13,28 +14,46 @@ def plot_xY(
1314
ax: plt.Axes,
1415
plot_hdi_kwargs: Optional[Dict[str, Any]] = None,
1516
hdi_prob: float = 0.94,
17+
label: Optional[str] = "",
1618
include_label: bool = True,
17-
) -> None:
19+
):
1820
"""Utility function to plot HDI intervals."""
1921

2022
if plot_hdi_kwargs is None:
2123
plot_hdi_kwargs = {}
2224

23-
az.plot_hdi(
25+
(h_line,) = ax.plot(
26+
x,
27+
Y.mean(dim=["chain", "draw"]),
28+
ls="-",
29+
**plot_hdi_kwargs,
30+
label=f"{label}" if include_label else None,
31+
)
32+
ax_hdi = az.plot_hdi(
2433
x,
2534
Y,
2635
hdi_prob=hdi_prob,
2736
fill_kwargs={
2837
"alpha": 0.25,
29-
"label": f"{hdi_prob*100}% HDI" if include_label else None,
38+
"label": " ", # f"{hdi_prob*100}% HDI" if include_label else None,
3039
},
3140
smooth=False,
3241
ax=ax,
3342
**plot_hdi_kwargs,
3443
)
35-
ax.plot(
36-
x,
37-
Y.mean(dim=["chain", "draw"]),
38-
color="k",
39-
label="Posterior mean" if include_label else None,
40-
)
44+
# Return handle to patch.
45+
# We get a list of the childen of the axis
46+
# Filter for just the PolyCollection objects
47+
# Take the last one
48+
h_patch = list(
49+
filter(lambda x: isinstance(x, PolyCollection), ax_hdi.get_children())
50+
)[-1]
51+
52+
# if include_label:
53+
# handles, labels = ax.get_legend_handles_labels()
54+
# ax.legend(
55+
# handles=[(h1, h2) for h1, h2 in zip(handles[::2], handles[1::2])],
56+
# # labels=[l1 + " + " + l2 for l1, l2 in zip(labels[::2], labels[1::2])],
57+
# labels=[l1 for l1 in labels[::2]],
58+
# )
59+
return h_line, h_patch

causalpy/pymc_experiments.py

Lines changed: 95 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -99,13 +99,13 @@ def __init__(
9999
# causal impact pre (ie the residuals of the model fit to observed)
100100
pre_data = xr.DataArray(self.pre_y[:, 0], dims=["obs_ind"])
101101
self.pre_impact = (
102-
pre_data - self.pre_pred["posterior_predictive"].y_hat
102+
pre_data - self.pre_pred["posterior_predictive"].mu
103103
).transpose(..., "obs_ind")
104104

105105
# causal impact post (ie the residuals of the model fit to observed)
106106
post_data = xr.DataArray(self.post_y[:, 0], dims=["obs_ind"])
107107
self.post_impact = (
108-
post_data - self.post_pred["posterior_predictive"].y_hat
108+
post_data - self.post_pred["posterior_predictive"].mu
109109
).transpose(..., "obs_ind")
110110

111111
# cumulative impact post
@@ -118,31 +118,47 @@ def plot(self):
118118

119119
# TOP PLOT --------------------------------------------------
120120
# pre-intervention period
121-
plot_xY(
121+
h_line, h_patch = plot_xY(
122122
self.datapre.index,
123-
self.pre_pred["posterior_predictive"].y_hat,
123+
self.pre_pred["posterior_predictive"].mu,
124124
ax=ax[0],
125+
include_label=False,
126+
plot_hdi_kwargs={"color": "C0"},
125127
)
126-
ax[0].plot(self.datapre.index, self.pre_y, "k.", label="Observations")
128+
handles = [(h_line, h_patch)]
129+
labels = ["Pre-intervention period"]
130+
131+
(h,) = ax[0].plot(self.datapre.index, self.pre_y, "k.", label="Observations")
132+
handles.append(h)
133+
labels.append("Observations")
134+
127135
# post intervention period
128-
plot_xY(
136+
h_line, h_patch = plot_xY(
129137
self.datapost.index,
130-
self.post_pred["posterior_predictive"].y_hat,
138+
self.post_pred["posterior_predictive"].mu,
131139
ax=ax[0],
132140
include_label=False,
141+
# label="Synthetic control",
142+
plot_hdi_kwargs={"color": "C1"},
133143
)
144+
handles.append((h_line, h_patch))
145+
labels.append("Synthetic control")
146+
134147
ax[0].plot(self.datapost.index, self.post_y, "k.")
135148
# Shaded causal effect
136-
ax[0].fill_between(
149+
h = ax[0].fill_between(
137150
self.datapost.index,
138151
y1=az.extract(
139-
self.post_pred, group="posterior_predictive", var_names="y_hat"
152+
self.post_pred, group="posterior_predictive", var_names="mu"
140153
).mean("sample"),
141154
y2=np.squeeze(self.post_y),
142-
color="C0",
155+
color="C2",
143156
alpha=0.25,
144-
label="Causal impact",
157+
# label="Causal impact",
145158
)
159+
handles.append(h)
160+
labels.append("Causal impact")
161+
146162
ax[0].set(
147163
title=f"""
148164
Pre-intervention Bayesian $R^2$: {self.score.r2:.3f}
@@ -155,30 +171,34 @@ def plot(self):
155171
self.datapre.index,
156172
self.pre_impact,
157173
ax=ax[1],
174+
include_label=False,
175+
plot_hdi_kwargs={"color": "C0"},
158176
)
159177
plot_xY(
160178
self.datapost.index,
161179
self.post_impact,
162180
ax=ax[1],
163181
include_label=False,
182+
plot_hdi_kwargs={"color": "C1"},
164183
)
165184
ax[1].axhline(y=0, c="k")
166185
ax[1].fill_between(
167186
self.datapost.index,
168187
y1=self.post_impact.mean(["chain", "draw"]),
169-
color="C0",
188+
color="C2",
170189
alpha=0.25,
171190
label="Causal impact",
172191
)
173192
ax[1].set(title="Causal Impact")
174193

175194
# BOTTOM PLOT -----------------------------------------------
176-
177195
ax[2].set(title="Cumulative Causal Impact")
178196
plot_xY(
179197
self.datapost.index,
180198
self.post_impact_cumulative,
181199
ax=ax[2],
200+
include_label=False,
201+
plot_hdi_kwargs={"color": "C1"},
182202
)
183203
ax[2].axhline(y=0, c="k")
184204

@@ -189,10 +209,14 @@ def plot(self):
189209
ls="-",
190210
lw=3,
191211
color="r",
192-
label="Treatment time",
212+
# label="Treatment time",
193213
)
194214

195-
ax[0].legend(fontsize=LEGEND_FONT_SIZE)
215+
ax[0].legend(
216+
handles=(h_tuple for h_tuple in handles),
217+
labels=labels,
218+
fontsize=LEGEND_FONT_SIZE,
219+
)
196220

197221
return (fig, ax)
198222

@@ -353,39 +377,46 @@ def __init__(
353377
)
354378

355379
def plot(self):
356-
"""Plot the results"""
380+
"""Plot the results.
381+
Creating the combined mean + HDI legend entries is a bit involved.
382+
"""
357383
fig, ax = plt.subplots()
358384

359385
# Plot raw data
360-
# NOTE: This will not work when there is just ONE unit in each group
361-
sns.lineplot(
386+
sns.scatterplot(
362387
self.data,
363388
x=self.time_variable_name,
364389
y=self.outcome_variable_name,
365390
hue=self.group_variable_name,
366-
units="unit", # NOTE: assumes we have a `unit` predictor variable
367-
estimator=None,
368-
alpha=0.5,
391+
alpha=1,
392+
legend=False,
393+
markers=True,
369394
ax=ax,
370395
)
371396

372397
# Plot model fit to control group
373398
time_points = self.x_pred_control[self.time_variable_name].values
374-
plot_xY(
399+
h_line, h_patch = plot_xY(
375400
time_points,
376-
self.y_pred_control.posterior_predictive.y_hat,
401+
self.y_pred_control.posterior_predictive.mu,
377402
ax=ax,
378403
plot_hdi_kwargs={"color": "C0"},
404+
label="Control group",
379405
)
406+
handles = [(h_line, h_patch)]
407+
labels = ["Control group"]
380408

381409
# Plot model fit to treatment group
382410
time_points = self.x_pred_control[self.time_variable_name].values
383-
plot_xY(
411+
h_line, h_patch = plot_xY(
384412
time_points,
385-
self.y_pred_treatment.posterior_predictive.y_hat,
413+
self.y_pred_treatment.posterior_predictive.mu,
386414
ax=ax,
387415
plot_hdi_kwargs={"color": "C1"},
416+
label="Treatment group",
388417
)
418+
handles.append((h_line, h_patch))
419+
labels.append("Treatment group")
389420

390421
# Plot counterfactual - post-test for treatment group IF no treatment
391422
# had occurred.
@@ -407,22 +438,30 @@ def plot(self):
407438
pc.set_edgecolor("None")
408439
pc.set_alpha(0.5)
409440
else:
410-
plot_xY(
441+
h_line, h_patch = plot_xY(
411442
time_points,
412-
self.y_pred_counterfactual.posterior_predictive.y_hat,
443+
self.y_pred_counterfactual.posterior_predictive.mu,
413444
ax=ax,
414445
plot_hdi_kwargs={"color": "C2"},
446+
label="Counterfactual",
415447
)
448+
handles.append((h_line, h_patch))
449+
labels.append("Counterfactual")
416450

417451
# arrow to label the causal impact
418452
self._plot_causal_impact_arrow(ax)
453+
419454
# formatting
420455
ax.set(
421456
xticks=self.x_pred_treatment[self.time_variable_name].values,
422457
title=self._causal_impact_summary_stat(),
423458
)
424-
ax.legend(fontsize=LEGEND_FONT_SIZE)
425-
return (fig, ax)
459+
ax.legend(
460+
handles=(h_tuple for h_tuple in handles),
461+
labels=labels,
462+
fontsize=LEGEND_FONT_SIZE,
463+
)
464+
return fig, ax
426465

427466
def _plot_causal_impact_arrow(self, ax):
428467
"""
@@ -582,12 +621,17 @@ def plot(self):
582621
c="k", # hue="treated",
583622
ax=ax,
584623
)
624+
585625
# Plot model fit to data
586-
plot_xY(
626+
h_line, h_patch = plot_xY(
587627
self.x_pred[self.running_variable_name],
588628
self.pred["posterior_predictive"].mu,
589629
ax=ax,
630+
plot_hdi_kwargs={"color": "C1"},
590631
)
632+
handles = [(h_line, h_patch)]
633+
labels = ["Posterior mean"]
634+
591635
# create strings to compose title
592636
title_info = f"{self.score.r2:.3f} (std = {self.score.r2_std:.3f})"
593637
r2 = f"Bayesian $R^2$ on all data = {title_info}"
@@ -605,7 +649,11 @@ def plot(self):
605649
color="r",
606650
label="treatment threshold",
607651
)
608-
ax.legend(fontsize=LEGEND_FONT_SIZE)
652+
ax.legend(
653+
handles=(h_tuple for h_tuple in handles),
654+
labels=labels,
655+
fontsize=LEGEND_FONT_SIZE,
656+
)
609657
return (fig, ax)
610658

611659
def summary(self):
@@ -710,27 +758,38 @@ def plot(self):
710758
hue="group",
711759
alpha=0.5,
712760
data=self.data,
761+
legend=True,
713762
ax=ax[0],
714763
)
715764
ax[0].set(xlabel="Pretest", ylabel="Posttest")
716765

717766
# plot posterior predictive of untreated
718-
plot_xY(
767+
h_line, h_patch = plot_xY(
719768
self.pred_xi,
720-
self.pred_untreated["posterior_predictive"].y_hat,
769+
self.pred_untreated["posterior_predictive"].mu,
721770
ax=ax[0],
722771
plot_hdi_kwargs={"color": "C0"},
772+
label="Control group",
723773
)
774+
handles = [(h_line, h_patch)]
775+
labels = ["Control group"]
724776

725777
# plot posterior predictive of treated
726-
plot_xY(
778+
h_line, h_patch = plot_xY(
727779
self.pred_xi,
728-
self.pred_treated["posterior_predictive"].y_hat,
780+
self.pred_treated["posterior_predictive"].mu,
729781
ax=ax[0],
730782
plot_hdi_kwargs={"color": "C1"},
783+
label="Treatment group",
731784
)
785+
handles.append((h_line, h_patch))
786+
labels.append("Treatment group")
732787

733-
ax[0].legend(fontsize=LEGEND_FONT_SIZE)
788+
ax[0].legend(
789+
handles=(h_tuple for h_tuple in handles),
790+
labels=labels,
791+
fontsize=LEGEND_FONT_SIZE,
792+
)
734793

735794
# Plot estimated caual impact / treatment effect
736795
az.plot_posterior(self.causal_impact, ref_val=0, ax=ax[1])

docs/notebooks/ancova_pymc.ipynb

Lines changed: 4 additions & 4 deletions
Large diffs are not rendered by default.

docs/notebooks/did_pymc.ipynb

Lines changed: 11 additions & 4 deletions
Large diffs are not rendered by default.

docs/notebooks/did_pymc_banks.ipynb

Lines changed: 28 additions & 26 deletions
Large diffs are not rendered by default.

docs/notebooks/geolift1.ipynb

Lines changed: 12 additions & 12 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)