Skip to content

Commit 14b1fb7

Browse files
committed
#58 add legend to TimeSeriesExperiment plot
1 parent c755caa commit 14b1fb7

File tree

3 files changed

+24
-8
lines changed

3 files changed

+24
-8
lines changed

causalpy/plot_utils.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,27 @@
11
import arviz as az
22

33

4-
def plot_xY(x, Y, ax, plot_hdi_kwargs=dict(), hdi_prob: float = 0.94) -> None:
4+
def plot_xY(
5+
x, Y, ax, plot_hdi_kwargs=dict(), hdi_prob: float = 0.94, include_label: bool = True
6+
) -> None:
57
"""Utility function to plot HDI intervals."""
68

79
Y = Y.stack(samples=["chain", "draw"]).T
810
az.plot_hdi(
911
x,
1012
Y,
1113
hdi_prob=hdi_prob,
12-
fill_kwargs={"alpha": 0.25, "label": f"{hdi_prob*100}% HDI"},
14+
fill_kwargs={
15+
"alpha": 0.25,
16+
"label": f"{hdi_prob*100}% HDI" if include_label else None,
17+
},
1318
smooth=False,
1419
ax=ax,
1520
**plot_hdi_kwargs,
1621
)
17-
ax.plot(x, Y.mean(dim="samples"), color="k", label="Posterior mean")
22+
ax.plot(
23+
x,
24+
Y.mean(dim="samples"),
25+
color="k",
26+
label="Posterior mean" if include_label else None,
27+
)

causalpy/pymc_experiments.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,18 +89,21 @@ def plot(self):
8989
plot_xY(
9090
self.datapre.index, self.pre_pred["posterior_predictive"].y_hat, ax=ax[0]
9191
)
92-
ax[0].plot(self.datapre.index, self.pre_y, "k.")
92+
ax[0].plot(self.datapre.index, self.pre_y, "k.", label="Observations")
9393
# post intervention period
9494
plot_xY(
95-
self.datapost.index, self.post_pred["posterior_predictive"].y_hat, ax=ax[0]
95+
self.datapost.index,
96+
self.post_pred["posterior_predictive"].y_hat,
97+
ax=ax[0],
98+
include_label=False,
9699
)
97100
ax[0].plot(self.datapost.index, self.post_y, "k.")
98101
ax[0].set(
99102
title=f"Pre-intervention Bayesian $R^2$: {self.score.r2:.3f} (std = {self.score.r2_std:.3f})"
100103
)
101104

102105
plot_xY(self.datapre.index, self.pre_impact, ax=ax[1])
103-
plot_xY(self.datapost.index, self.post_impact, ax=ax[1])
106+
plot_xY(self.datapost.index, self.post_impact, ax=ax[1], include_label=False)
104107
ax[1].axhline(y=0, c="k")
105108
ax[1].set(title="Causal Impact")
106109

@@ -115,8 +118,11 @@ def plot(self):
115118
ls="-",
116119
lw=3,
117120
color="r",
118-
label="treatment time",
121+
label="Treatment time",
119122
)
123+
124+
ax[0].legend(fontsize=LEGEND_FONT_SIZE)
125+
120126
return (fig, ax)
121127

122128

causalpy/skl_experiments.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def plot(self):
117117
ls="-",
118118
lw=3,
119119
color="r",
120-
label="treatment time",
120+
label="Treatment time",
121121
)
122122

123123
ax[0].legend(fontsize=LEGEND_FONT_SIZE)

0 commit comments

Comments
 (0)