Skip to content

Commit 2b1d82a

Browse files
authored
Merge pull request #145 from pymc-labs/improve-plot-legends
Improve plot clarity with combined mean and HDI legend elements
2 parents 2364c4e + 4a769b4 commit 2b1d82a

17 files changed

+8292
-9515
lines changed

causalpy/plot_utils.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1-
from typing import Any, Dict, Optional, Union
1+
from typing import Any, Dict, Optional, Tuple, Union
22

33
import arviz as az
44
import matplotlib.pyplot as plt
55
import numpy as np
66
import pandas as pd
77
import xarray as xr
8+
from matplotlib.collections import PolyCollection
9+
from matplotlib.lines import Line2D
810

911

1012
def plot_xY(
@@ -13,28 +15,35 @@ def plot_xY(
1315
ax: plt.Axes,
1416
plot_hdi_kwargs: Optional[Dict[str, Any]] = None,
1517
hdi_prob: float = 0.94,
16-
include_label: bool = True,
17-
) -> None:
18+
label: Union[str, None] = None,
19+
) -> Tuple[Line2D, PolyCollection]:
1820
"""Utility function to plot HDI intervals."""
1921

2022
if plot_hdi_kwargs is None:
2123
plot_hdi_kwargs = {}
2224

23-
az.plot_hdi(
25+
(h_line,) = ax.plot(
26+
x,
27+
Y.mean(dim=["chain", "draw"]),
28+
ls="-",
29+
**plot_hdi_kwargs,
30+
label=f"{label}",
31+
)
32+
ax_hdi = az.plot_hdi(
2433
x,
2534
Y,
2635
hdi_prob=hdi_prob,
2736
fill_kwargs={
2837
"alpha": 0.25,
29-
"label": f"{hdi_prob*100}% HDI" if include_label else None,
38+
"label": " ",
3039
},
3140
smooth=False,
3241
ax=ax,
3342
**plot_hdi_kwargs,
3443
)
35-
ax.plot(
36-
x,
37-
Y.mean(dim=["chain", "draw"]),
38-
color="k",
39-
label="Posterior mean" if include_label else None,
40-
)
44+
# Return handle to patch. We get a list of the children of the axis. Filter for just
45+
# the PolyCollection objects. Take the last one.
46+
h_patch = list(
47+
filter(lambda x: isinstance(x, PolyCollection), ax_hdi.get_children())
48+
)[-1]
49+
return (h_line, h_patch)

causalpy/pymc_experiments.py

Lines changed: 88 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -99,13 +99,13 @@ def __init__(
9999
# causal impact pre (ie the residuals of the model fit to observed)
100100
pre_data = xr.DataArray(self.pre_y[:, 0], dims=["obs_ind"])
101101
self.pre_impact = (
102-
pre_data - self.pre_pred["posterior_predictive"].y_hat
102+
pre_data - self.pre_pred["posterior_predictive"].mu
103103
).transpose(..., "obs_ind")
104104

105105
# causal impact post (ie the residuals of the model fit to observed)
106106
post_data = xr.DataArray(self.post_y[:, 0], dims=["obs_ind"])
107107
self.post_impact = (
108-
post_data - self.post_pred["posterior_predictive"].y_hat
108+
post_data - self.post_pred["posterior_predictive"].mu
109109
).transpose(..., "obs_ind")
110110

111111
# cumulative impact post
@@ -118,31 +118,43 @@ def plot(self):
118118

119119
# TOP PLOT --------------------------------------------------
120120
# pre-intervention period
121-
plot_xY(
121+
h_line, h_patch = plot_xY(
122122
self.datapre.index,
123-
self.pre_pred["posterior_predictive"].y_hat,
123+
self.pre_pred["posterior_predictive"].mu,
124124
ax=ax[0],
125+
plot_hdi_kwargs={"color": "C0"},
125126
)
126-
ax[0].plot(self.datapre.index, self.pre_y, "k.", label="Observations")
127+
handles = [(h_line, h_patch)]
128+
labels = ["Pre-intervention period"]
129+
130+
(h,) = ax[0].plot(self.datapre.index, self.pre_y, "k.", label="Observations")
131+
handles.append(h)
132+
labels.append("Observations")
133+
127134
# post intervention period
128-
plot_xY(
135+
h_line, h_patch = plot_xY(
129136
self.datapost.index,
130-
self.post_pred["posterior_predictive"].y_hat,
137+
self.post_pred["posterior_predictive"].mu,
131138
ax=ax[0],
132-
include_label=False,
139+
plot_hdi_kwargs={"color": "C1"},
133140
)
141+
handles.append((h_line, h_patch))
142+
labels.append("Synthetic control")
143+
134144
ax[0].plot(self.datapost.index, self.post_y, "k.")
135145
# Shaded causal effect
136-
ax[0].fill_between(
146+
h = ax[0].fill_between(
137147
self.datapost.index,
138148
y1=az.extract(
139-
self.post_pred, group="posterior_predictive", var_names="y_hat"
149+
self.post_pred, group="posterior_predictive", var_names="mu"
140150
).mean("sample"),
141151
y2=np.squeeze(self.post_y),
142152
color="C0",
143153
alpha=0.25,
144-
label="Causal impact",
145154
)
155+
handles.append(h)
156+
labels.append("Causal impact")
157+
146158
ax[0].set(
147159
title=f"""
148160
Pre-intervention Bayesian $R^2$: {self.score.r2:.3f}
@@ -155,12 +167,13 @@ def plot(self):
155167
self.datapre.index,
156168
self.pre_impact,
157169
ax=ax[1],
170+
plot_hdi_kwargs={"color": "C0"},
158171
)
159172
plot_xY(
160173
self.datapost.index,
161174
self.post_impact,
162175
ax=ax[1],
163-
include_label=False,
176+
plot_hdi_kwargs={"color": "C1"},
164177
)
165178
ax[1].axhline(y=0, c="k")
166179
ax[1].fill_between(
@@ -173,12 +186,12 @@ def plot(self):
173186
ax[1].set(title="Causal Impact")
174187

175188
# BOTTOM PLOT -----------------------------------------------
176-
177189
ax[2].set(title="Cumulative Causal Impact")
178190
plot_xY(
179191
self.datapost.index,
180192
self.post_impact_cumulative,
181193
ax=ax[2],
194+
plot_hdi_kwargs={"color": "C1"},
182195
)
183196
ax[2].axhline(y=0, c="k")
184197

@@ -189,10 +202,13 @@ def plot(self):
189202
ls="-",
190203
lw=3,
191204
color="r",
192-
label="Treatment time",
193205
)
194206

195-
ax[0].legend(fontsize=LEGEND_FONT_SIZE)
207+
ax[0].legend(
208+
handles=(h_tuple for h_tuple in handles),
209+
labels=labels,
210+
fontsize=LEGEND_FONT_SIZE,
211+
)
196212

197213
return (fig, ax)
198214

@@ -353,39 +369,46 @@ def __init__(
353369
)
354370

355371
def plot(self):
356-
"""Plot the results"""
372+
"""Plot the results.
373+
Creating the combined mean + HDI legend entries is a bit involved.
374+
"""
357375
fig, ax = plt.subplots()
358376

359377
# Plot raw data
360-
# NOTE: This will not work when there is just ONE unit in each group
361-
sns.lineplot(
378+
sns.scatterplot(
362379
self.data,
363380
x=self.time_variable_name,
364381
y=self.outcome_variable_name,
365382
hue=self.group_variable_name,
366-
units="unit", # NOTE: assumes we have a `unit` predictor variable
367-
estimator=None,
368-
alpha=0.5,
383+
alpha=1,
384+
legend=False,
385+
markers=True,
369386
ax=ax,
370387
)
371388

372389
# Plot model fit to control group
373390
time_points = self.x_pred_control[self.time_variable_name].values
374-
plot_xY(
391+
h_line, h_patch = plot_xY(
375392
time_points,
376-
self.y_pred_control.posterior_predictive.y_hat,
393+
self.y_pred_control.posterior_predictive.mu,
377394
ax=ax,
378395
plot_hdi_kwargs={"color": "C0"},
396+
label="Control group",
379397
)
398+
handles = [(h_line, h_patch)]
399+
labels = ["Control group"]
380400

381401
# Plot model fit to treatment group
382402
time_points = self.x_pred_control[self.time_variable_name].values
383-
plot_xY(
403+
h_line, h_patch = plot_xY(
384404
time_points,
385-
self.y_pred_treatment.posterior_predictive.y_hat,
405+
self.y_pred_treatment.posterior_predictive.mu,
386406
ax=ax,
387407
plot_hdi_kwargs={"color": "C1"},
408+
label="Treatment group",
388409
)
410+
handles.append((h_line, h_patch))
411+
labels.append("Treatment group")
389412

390413
# Plot counterfactual - post-test for treatment group IF no treatment
391414
# had occurred.
@@ -403,26 +426,34 @@ def plot(self):
403426
widths=0.2,
404427
)
405428
for pc in parts["bodies"]:
406-
pc.set_facecolor("C2")
429+
pc.set_facecolor("C0")
407430
pc.set_edgecolor("None")
408431
pc.set_alpha(0.5)
409432
else:
410-
plot_xY(
433+
h_line, h_patch = plot_xY(
411434
time_points,
412-
self.y_pred_counterfactual.posterior_predictive.y_hat,
435+
self.y_pred_counterfactual.posterior_predictive.mu,
413436
ax=ax,
414437
plot_hdi_kwargs={"color": "C2"},
438+
label="Counterfactual",
415439
)
440+
handles.append((h_line, h_patch))
441+
labels.append("Counterfactual")
416442

417443
# arrow to label the causal impact
418444
self._plot_causal_impact_arrow(ax)
445+
419446
# formatting
420447
ax.set(
421448
xticks=self.x_pred_treatment[self.time_variable_name].values,
422449
title=self._causal_impact_summary_stat(),
423450
)
424-
ax.legend(fontsize=LEGEND_FONT_SIZE)
425-
return (fig, ax)
451+
ax.legend(
452+
handles=(h_tuple for h_tuple in handles),
453+
labels=labels,
454+
fontsize=LEGEND_FONT_SIZE,
455+
)
456+
return fig, ax
426457

427458
def _plot_causal_impact_arrow(self, ax):
428459
"""
@@ -582,12 +613,17 @@ def plot(self):
582613
c="k", # hue="treated",
583614
ax=ax,
584615
)
616+
585617
# Plot model fit to data
586-
plot_xY(
618+
h_line, h_patch = plot_xY(
587619
self.x_pred[self.running_variable_name],
588620
self.pred["posterior_predictive"].mu,
589621
ax=ax,
622+
plot_hdi_kwargs={"color": "C1"},
590623
)
624+
handles = [(h_line, h_patch)]
625+
labels = ["Posterior mean"]
626+
591627
# create strings to compose title
592628
title_info = f"{self.score.r2:.3f} (std = {self.score.r2_std:.3f})"
593629
r2 = f"Bayesian $R^2$ on all data = {title_info}"
@@ -605,7 +641,11 @@ def plot(self):
605641
color="r",
606642
label="treatment threshold",
607643
)
608-
ax.legend(fontsize=LEGEND_FONT_SIZE)
644+
ax.legend(
645+
handles=(h_tuple for h_tuple in handles),
646+
labels=labels,
647+
fontsize=LEGEND_FONT_SIZE,
648+
)
609649
return (fig, ax)
610650

611651
def summary(self):
@@ -710,27 +750,38 @@ def plot(self):
710750
hue="group",
711751
alpha=0.5,
712752
data=self.data,
753+
legend=True,
713754
ax=ax[0],
714755
)
715756
ax[0].set(xlabel="Pretest", ylabel="Posttest")
716757

717758
# plot posterior predictive of untreated
718-
plot_xY(
759+
h_line, h_patch = plot_xY(
719760
self.pred_xi,
720-
self.pred_untreated["posterior_predictive"].y_hat,
761+
self.pred_untreated["posterior_predictive"].mu,
721762
ax=ax[0],
722763
plot_hdi_kwargs={"color": "C0"},
764+
label="Control group",
723765
)
766+
handles = [(h_line, h_patch)]
767+
labels = ["Control group"]
724768

725769
# plot posterior predictive of treated
726-
plot_xY(
770+
h_line, h_patch = plot_xY(
727771
self.pred_xi,
728-
self.pred_treated["posterior_predictive"].y_hat,
772+
self.pred_treated["posterior_predictive"].mu,
729773
ax=ax[0],
730774
plot_hdi_kwargs={"color": "C1"},
775+
label="Treatment group",
731776
)
777+
handles.append((h_line, h_patch))
778+
labels.append("Treatment group")
732779

733-
ax[0].legend(fontsize=LEGEND_FONT_SIZE)
780+
ax[0].legend(
781+
handles=(h_tuple for h_tuple in handles),
782+
labels=labels,
783+
fontsize=LEGEND_FONT_SIZE,
784+
)
734785

735786
# Plot estimated causal impact / treatment effect
736787
az.plot_posterior(self.causal_impact, ref_val=0, ax=ax[1])

docs/notebooks/ancova_pymc.ipynb

Lines changed: 4 additions & 4 deletions
Large diffs are not rendered by default.

docs/notebooks/did_pymc.ipynb

Lines changed: 11 additions & 4 deletions
Large diffs are not rendered by default.

docs/notebooks/did_pymc_banks.ipynb

Lines changed: 29 additions & 29 deletions
Large diffs are not rendered by default.

docs/notebooks/generate_plots.ipynb

Lines changed: 246 additions & 24 deletions
Large diffs are not rendered by default.

docs/notebooks/geolift1.ipynb

Lines changed: 14 additions & 16 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)