Skip to content

Commit 21b174e

Browse files
authored
Merge pull request #63 from pymc-labs/bunch-of-changes
Misc bunch of changes
2 parents a952dc7 + b31397d commit 21b174e

File tree

10 files changed

+4431
-3535
lines changed

10 files changed

+4431
-3535
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ This is appropriate when you have multiple units, one of which is treated. You b
5555
|--|--|
5656
| ![](img/synthetic_control_skl.svg) | ![](img/synthetic_control_pymc.svg) |
5757

58-
> The data (treated and untreated units), pre-treatment model fit, and counterfactual (i.e. the synthetic control) are plotted (top). The Frequentist analysis shows the causal impact as a blue shaded region, but this is not shown in the Bayesian analysis to avoid a cluttered chart. Instead, the Bayesian analysis shows shaded Bayesian credible regions of the model fit and counterfactual. Also shown is the causal impact (middle) and cumulative causal impact (bottom).
58+
> The data (treated and untreated units), pre-treatment model fit, and counterfactual (i.e. the synthetic control) are plotted (top). The causal impact is shown as a blue shaded region. The Bayesian analysis shows shaded Bayesian credible regions of the model fit and counterfactual. Also shown is the causal impact (middle) and cumulative causal impact (bottom).
5959
6060
### Interrupted time series
6161
This is appropriate when you have a single treated unit, and therefore a single time series, and do _not_ have a set of untreated units.
@@ -71,7 +71,7 @@ This is appropriate when you have a single treated unit, and therefore a single
7171
|--|--|
7272
| ![](img/interrupted_time_series_skl.svg) | ![](img/interrupted_time_series_pymc.svg) |
7373

74-
> The data, model fits, and counterfactual are plotted (top panels). The Frequentist analysis shows the causal impact with the blue shaded region, but this is not shown in the Bayesian analysis to avoid a cluttered chart. Instead, the Bayesian analysis shows shaded Bayesian credible regions of the model fits. Also shown is the causal impact (middle) and cumulative causal impact (bottom).
74+
> The data (treated and untreated units), pre-treatment model fit, and counterfactual (i.e. the synthetic control) are plotted (top). The causal impact is shown as a blue shaded region. The Bayesian analysis shows shaded Bayesian credible regions of the model fit and counterfactual. Also shown is the causal impact (middle) and cumulative causal impact (bottom).
7575
7676
### Difference in Differences
7777

causalpy/plot_utils.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,27 @@
11
import arviz as az
22

33

4-
def plot_xY(x, Y, ax, plot_hdi_kwargs=dict(), hdi_prob: float = 0.94) -> None:
4+
def plot_xY(
5+
x, Y, ax, plot_hdi_kwargs=dict(), hdi_prob: float = 0.94, include_label: bool = True
6+
) -> None:
57
"""Utility function to plot HDI intervals."""
68

79
Y = Y.stack(samples=["chain", "draw"]).T
810
az.plot_hdi(
911
x,
1012
Y,
1113
hdi_prob=hdi_prob,
12-
fill_kwargs={"alpha": 0.25, "label": f"{hdi_prob*100}% HDI"},
14+
fill_kwargs={
15+
"alpha": 0.25,
16+
"label": f"{hdi_prob*100}% HDI" if include_label else None,
17+
},
1318
smooth=False,
1419
ax=ax,
1520
**plot_hdi_kwargs,
1621
)
17-
ax.plot(x, Y.mean(dim="samples"), color="k", label="Posterior mean")
22+
ax.plot(
23+
x,
24+
Y.mean(dim="samples"),
25+
color="k",
26+
label="Posterior mean" if include_label else None,
27+
)

causalpy/pymc_experiments.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,23 +89,46 @@ def plot(self):
8989
plot_xY(
9090
self.datapre.index, self.pre_pred["posterior_predictive"].y_hat, ax=ax[0]
9191
)
92-
ax[0].plot(self.datapre.index, self.pre_y, "k.")
92+
ax[0].plot(self.datapre.index, self.pre_y, "k.", label="Observations")
9393
# post intervention period
9494
plot_xY(
95-
self.datapost.index, self.post_pred["posterior_predictive"].y_hat, ax=ax[0]
95+
self.datapost.index,
96+
self.post_pred["posterior_predictive"].y_hat,
97+
ax=ax[0],
98+
include_label=False,
9699
)
97100
ax[0].plot(self.datapost.index, self.post_y, "k.")
98101
ax[0].set(
99102
title=f"Pre-intervention Bayesian $R^2$: {self.score.r2:.3f} (std = {self.score.r2_std:.3f})"
100103
)
101104

102105
plot_xY(self.datapre.index, self.pre_impact, ax=ax[1])
103-
plot_xY(self.datapost.index, self.post_impact, ax=ax[1])
106+
plot_xY(self.datapost.index, self.post_impact, ax=ax[1], include_label=False)
104107
ax[1].axhline(y=0, c="k")
105108
ax[1].set(title="Causal Impact")
106109

107110
ax[2].set(title="Cumulative Causal Impact")
108111
plot_xY(self.datapost.index, self.post_impact_cumulative, ax=ax[2])
112+
ax[2].axhline(y=0, c="k")
113+
114+
# Shaded causal effect
115+
ax[0].fill_between(
116+
self.datapost.index,
117+
y1=az.extract(
118+
self.post_pred, group="posterior_predictive", var_names="y_hat"
119+
).mean("sample"),
120+
y2=np.squeeze(self.post_y),
121+
color="C0",
122+
alpha=0.25,
123+
label="Causal impact",
124+
)
125+
ax[1].fill_between(
126+
self.datapost.index,
127+
y1=self.post_impact.mean(["chain", "draw"]),
128+
color="C0",
129+
alpha=0.25,
130+
label="Causal impact",
131+
)
109132

110133
# Intervention line
111134
for i in [0, 1, 2]:
@@ -114,8 +137,11 @@ def plot(self):
114137
ls="-",
115138
lw=3,
116139
color="r",
117-
label="treatment time",
140+
label="Treatment time",
118141
)
142+
143+
ax[0].legend(fontsize=LEGEND_FONT_SIZE)
144+
119145
return (fig, ax)
120146

121147

causalpy/skl_experiments.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,14 +99,14 @@ def plot(self):
9999
y2=np.squeeze(self.post_y),
100100
color="C0",
101101
alpha=0.25,
102-
label="causal impact",
102+
label="Causal impact",
103103
)
104104
ax[1].fill_between(
105105
self.datapost.index,
106106
y1=np.squeeze(self.post_impact),
107107
color="C0",
108108
alpha=0.25,
109-
label="causal impact",
109+
label="Causal impact",
110110
)
111111

112112
# Intervention line
@@ -117,7 +117,7 @@ def plot(self):
117117
ls="-",
118118
lw=3,
119119
color="r",
120-
label="treatment time",
120+
label="Treatment time",
121121
)
122122

123123
ax[0].legend(fontsize=LEGEND_FONT_SIZE)

docs/index.rst

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,17 @@ CausalPy - Causal inference in quasi-experimental settings
88

99
A Python package focussing on causal inference in quasi-experimental settings. The package allows for sophisticated Bayesian model fitting methods to be used in addition to traditional OLS.
1010

11+
Support
12+
-------
13+
14+
This repository is supported by `PyMC Labs <https://www.pymc-labs.io>`_.
15+
16+
.. image:: ../img/pymc-labs-log.png
17+
:align: center
18+
:target: https://www.pymc-labs.io
19+
:scale: 50 %
20+
21+
1122
.. toctree::
1223
:maxdepth: 2
1324
:caption: Contents:
@@ -33,9 +44,7 @@ Documentation outline
3344
api_plot_utils
3445

3546

36-
Indices and tables
37-
==================
47+
Index
48+
=====
3849

3950
* :ref:`genindex`
40-
* :ref:`modindex`
41-
* :ref:`search`

docs/notebooks/pymc_demos.ipynb

Lines changed: 15 additions & 15 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)