Skip to content

Commit 8a9800f

Browse files
committed
pymc models and experiments done and docs checked
1 parent fd92613 commit 8a9800f

File tree

2 files changed

+329
-57
lines changed

2 files changed

+329
-57
lines changed

causalpy/pymc_experiments.py

Lines changed: 230 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
"""
22
Experiment routines for PyMC models.
33
4-
Includes:
5-
1. ExperimentalDesign base class
6-
2. Pre-Post Fit
7-
3. Synthetic Control
8-
4. Difference in differences
9-
5. Regression Discontinuity
4+
- ExperimentalDesign base class
5+
- Pre-Post Fit
6+
- Interrupted Time Series
7+
- Synthetic Control
8+
- Difference in differences
9+
- Regression Discontinuity
10+
- Pretest/Posttest Nonequivalent Group Design
11+
1012
"""
13+
1114
import warnings
1215
from typing import Optional, Union
1316

@@ -30,7 +33,11 @@
3033

3134

3235
class ExperimentalDesign:
33-
"""Base class for other experiment types"""
36+
"""
37+
Base class for other experiment types
38+
39+
See subclasses for examples of most methods
40+
"""
3441

3542
model = None
3643
expt_type = None
@@ -43,11 +50,63 @@ def __init__(self, model=None, **kwargs):
4350

4451
@property
4552
def idata(self):
46-
"""Access to the models InferenceData object"""
53+
"""
54+
Access to the model's InferenceData object
55+
56+
Example
57+
--------
58+
If `result` is the result of the Difference in Differences experiment example
59+
60+
>>> result.idata
61+
Inference data with groups:
62+
> posterior
63+
> posterior_predictive
64+
> sample_stats
65+
> prior
66+
> prior_predictive
67+
> observed_data
68+
> constant_data
69+
>>> result.idata.posterior
70+
<xarray.Dataset>
71+
Dimensions: (chain: 4, draw: 1000, coeffs: 4, obs_ind: 40)
72+
Coordinates:
73+
* chain (chain) int64 0 1 2 3
74+
* draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998
75+
999
76+
* coeffs (coeffs) <U28 'Intercept' ... 'group:post_treatment[T.True]'
77+
* obs_ind (obs_ind) int64 0 1 2 3 4 5 6 7 8 9 ... 31 32 33 34 35 36 37
78+
38 39
79+
Data variables:
80+
beta (chain, draw, coeffs) float64 1.04 1.013 0.173 ... 0.1873 0.5225
81+
sigma (chain, draw) float64 0.09331 0.1031 0.1024 ... 0.0824 0.06907
82+
mu (chain, draw, obs_ind) float64 1.04 2.053 1.213 ... 1.265 2.747
83+
Attributes:
84+
created_at: 2023-08-23T20:03:45.709265
85+
arviz_version: 0.16.1
86+
inference_library: pymc
87+
inference_library_version: 5.7.2
88+
sampling_time: 0.8851289749145508
89+
tuning_steps: 1000
90+
"""
91+
4792
return self.model.idata
4893

4994
def print_coefficients(self) -> None:
50-
"""Prints the model coefficients"""
95+
"""
96+
Prints the model coefficients
97+
98+
Example
99+
--------
100+
If `result` is from the Difference in Differences experiment example
101+
102+
>>> result.print_coefficients()
103+
Model coefficients:
104+
Intercept 1.08, 94% HDI [1.03, 1.13]
105+
post_treatment[T.True] 0.98, 94% HDI [0.91, 1.06]
106+
group 0.16, 94% HDI [0.09, 0.23]
107+
group:post_treatment[T.True] 0.51, 94% HDI [0.41, 0.61]
108+
sigma 0.08, 94% HDI [0.07, 0.10]
109+
"""
51110
print("Model coefficients:")
52111
coeffs = az.extract(self.idata.posterior, var_names="beta")
53112
# Note: f"{name: <30}" pads the name with spaces so that we have alignment of
@@ -82,6 +141,7 @@ class PrePostFit(ExperimentalDesign):
82141
Example
83142
--------
84143
>>> sc = cp.load_data("sc")
144+
>>> treatment_time = 70
85145
>>> seed = 42
86146
>>> result = cp.pymc_experiments.PrePostFit(
87147
... sc,
@@ -91,6 +151,17 @@ class PrePostFit(ExperimentalDesign):
91151
... sample_kwargs={"target_accept": 0.95, "random_seed": seed}
92152
... ),
93153
... )
154+
Auto-assigning NUTS sampler...
155+
Initializing NUTS using jitter+adapt_diag...
156+
Multiprocess sampling (4 chains in 4 jobs)
157+
NUTS: [beta, sigma]
158+
Sampling 4 chains for 1_000 tune and 1_000 draw iterations
159+
(4_000 + 4_000 draws total) took 11 seconds.
160+
Sampling: [beta, sigma, y_hat]
161+
Sampling: [y_hat]
162+
Sampling: [y_hat]
163+
Sampling: [y_hat]
164+
Sampling: [y_hat]
94165
"""
95166

96167
def __init__(
@@ -105,6 +176,8 @@ def __init__(
105176
self._input_validation(data, treatment_time)
106177

107178
self.treatment_time = treatment_time
179+
# set experiment type - usually done in subclasses
180+
self.expt_type = "Pre-Post Fit"
108181
# split data in to pre and post intervention
109182
self.datapre = data[data.index <= self.treatment_time]
110183
self.datapost = data[data.index > self.treatment_time]
@@ -171,7 +244,14 @@ def _input_validation(self, data, treatment_time):
171244
)
172245

173246
def plot(self, counterfactual_label="Counterfactual", **kwargs):
174-
"""Plot the results"""
247+
"""
248+
Plot the results
249+
250+
Example
251+
--------
252+
>>> result.plot()
253+
254+
"""
175255
fig, ax = plt.subplots(3, 1, sharex=True, figsize=(7, 8))
176256

177257
# TOP PLOT --------------------------------------------------
@@ -271,7 +351,24 @@ def plot(self, counterfactual_label="Counterfactual", **kwargs):
271351
return (fig, ax)
272352

273353
def summary(self) -> None:
274-
"""Print text output summarising the results"""
354+
"""
355+
Print text output summarising the results
356+
357+
Example
358+
---------
359+
>>> result.summary()
360+
===============================Synthetic Control===============================
361+
Formula: actual ~ 0 + a + b + c + d + e + f + g
362+
Model coefficients:
363+
a 0.33, 94% HDI [0.30, 0.38]
364+
b 0.05, 94% HDI [0.01, 0.09]
365+
c 0.31, 94% HDI [0.26, 0.35]
366+
d 0.06, 94% HDI [0.01, 0.10]
367+
e 0.02, 94% HDI [0.00, 0.06]
368+
f 0.20, 94% HDI [0.12, 0.26]
369+
g 0.04, 94% HDI [0.00, 0.08]
370+
sigma 0.26, 94% HDI [0.22, 0.30]
371+
"""
275372

276373
print(f"{self.expt_type:=^80}")
277374
print(f"Formula: {self.formula}")
@@ -307,6 +404,17 @@ class InterruptedTimeSeries(PrePostFit):
307404
... formula="y ~ 1 + t + C(month)",
308405
... model=cp.pymc_models.LinearRegression(sample_kwargs={"random_seed": seed}),
309406
... )
407+
Auto-assigning NUTS sampler...
408+
Initializing NUTS using jitter+adapt_diag...
409+
Multiprocess sampling (4 chains in 4 jobs)
410+
NUTS: [beta, sigma]
411+
Sampling 4 chains for 1_000 tune and 1_000 draw iterations
412+
(4_000 + 4_000 draws total) took 3 seconds.
413+
Sampling: [beta, sigma, y_hat]
414+
Sampling: [y_hat]
415+
Sampling: [y_hat]
416+
Sampling: [y_hat]
417+
Sampling: [y_hat]
310418
"""
311419

312420
expt_type = "Interrupted Time Series"
@@ -337,6 +445,17 @@ class SyntheticControl(PrePostFit):
337445
... sample_kwargs={"target_accept": 0.95, "random_seed": seed}
338446
... ),
339447
... )
448+
Auto-assigning NUTS sampler...
449+
Initializing NUTS using jitter+adapt_diag...
450+
Multiprocess sampling (4 chains in 4 jobs)
451+
NUTS: [beta, sigma]
452+
Sampling 4 chains for 1_000 tune and 1_000 draw iterations
453+
(4_000 + 4_000 draws total) took 11 seconds.
454+
Sampling: [beta, sigma, y_hat]
455+
Sampling: [y_hat]
456+
Sampling: [y_hat]
457+
Sampling: [y_hat]
458+
Sampling: [y_hat]
340459
"""
341460

342461
expt_type = "Synthetic Control"
@@ -382,7 +501,17 @@ class DifferenceInDifferences(ExperimentalDesign):
382501
... group_variable_name="group",
383502
... model=cp.pymc_models.LinearRegression(sample_kwargs={"random_seed": seed}),
384503
... )
385-
504+
Auto-assigning NUTS sampler...
505+
Initializing NUTS using jitter+adapt_diag...
506+
Multiprocess sampling (4 chains in 4 jobs)
507+
NUTS: [beta, sigma]
508+
Sampling 4 chains for 1_000 tune and 1_000 draw iterations
509+
(4_000 + 4_000 draws total) took 1 seconds.
510+
Sampling: [beta, sigma, y_hat]
511+
Sampling: [y_hat]
512+
Sampling: [y_hat]
513+
Sampling: [y_hat]
514+
Sampling: [y_hat]
386515
"""
387516

388517
def __init__(
@@ -503,6 +632,12 @@ def _input_validation(self):
503632
def plot(self):
504633
"""Plot the results.
505634
Creating the combined mean + HDI legend entries is a bit involved.
635+
636+
Example
637+
--------
638+
Assuming `result` is the result of a DiD experiment:
639+
640+
>>> result.plot()
506641
"""
507642
fig, ax = plt.subplots()
508643

@@ -639,7 +774,25 @@ def _causal_impact_summary_stat(self) -> str:
639774
return f"Causal impact = {causal_impact + ci}"
640775

641776
def summary(self) -> None:
642-
"""Print text output summarising the results"""
777+
"""
778+
Print text output summarising the results
779+
780+
Example
781+
--------
782+
Assuming `result` is a DiD experiment
783+
784+
>>> result.summary()
785+
==========================Difference in Differences=========================
786+
Formula: y ~ 1 + group*post_treatment
787+
Results:
788+
Causal impact = 0.51, $CI_{94%}$[0.41, 0.61]
789+
Model coefficients:
790+
Intercept 1.08, 94% HDI [1.03, 1.13]
791+
post_treatment[T.True] 0.98, 94% HDI [0.91, 1.06]
792+
group 0.16, 94% HDI [0.09, 0.23]
793+
group:post_treatment[T.True] 0.51, 94% HDI [0.41, 0.61]
794+
sigma 0.08, 94% HDI [0.07, 0.10]
795+
"""
643796

644797
print(f"{self.expt_type:=^80}")
645798
print(f"Formula: {self.formula}")
@@ -680,7 +833,17 @@ class RegressionDiscontinuity(ExperimentalDesign):
680833
... model=cp.pymc_models.LinearRegression(sample_kwargs={"random_seed": seed}),
681834
... treatment_threshold=0.5,
682835
... )
683-
836+
Auto-assigning NUTS sampler...
837+
Initializing NUTS using jitter+adapt_diag...
838+
Multiprocess sampling (4 chains in 4 jobs)
839+
NUTS: [beta, sigma]
840+
Sampling 4 chains for 1_000 tune and 1_000 draw iterations
841+
(4_000 + 4_000 draws total) took 2 seconds.
842+
Sampling: [beta, sigma, y_hat]
843+
Sampling: [y_hat]
844+
Sampling: [y_hat]
845+
Sampling: [y_hat]
846+
Sampling: [y_hat]
684847
"""
685848

686849
def __init__(
@@ -791,7 +954,13 @@ def _is_treated(self, x):
791954
return np.greater_equal(x, self.treatment_threshold)
792955

793956
def plot(self):
794-
"""Plot the results"""
957+
"""
958+
Plot the results
959+
960+
Example
961+
--------
962+
>>> result.plot()
963+
"""
795964
fig, ax = plt.subplots()
796965
# Plot raw data
797966
sns.scatterplot(
@@ -837,7 +1006,25 @@ def plot(self):
8371006
return (fig, ax)
8381007

8391008
def summary(self) -> None:
840-
"""Print text output summarising the results"""
1009+
"""
1010+
Print text output summarising the results
1011+
1012+
Example
1013+
--------
1014+
>>> result.summary()
1015+
============================Regression Discontinuity==========================
1016+
Formula: y ~ 1 + x + treated + x:treated
1017+
Running variable: x
1018+
Threshold on running variable: 0.5
1019+
Results:
1020+
Discontinuity at threshold = 0.92
1021+
Model coefficients:
1022+
Intercept 0.09, 94% HDI [0.00, 0.17]
1023+
treated[T.True] 2.48, 94% HDI [1.66, 3.27]
1024+
x 1.32, 94% HDI [1.14, 1.50]
1025+
x:treated[T.True] -3.12, 94% HDI [-4.17, -2.05]
1026+
sigma 0.35, 94% HDI [0.31, 0.41]
1027+
"""
8411028

8421029
print(f"{self.expt_type:=^80}")
8431030
print(f"Formula: {self.formula}")
@@ -876,7 +1063,16 @@ class PrePostNEGD(ExperimentalDesign):
8761063
... pretreatment_variable_name="pre",
8771064
... model=cp.pymc_models.LinearRegression(sample_kwargs={"random_seed": seed}),
8781065
... )
879-
1066+
Auto-assigning NUTS sampler...
1067+
Initializing NUTS using jitter+adapt_diag...
1068+
Multiprocess sampling (4 chains in 4 jobs)
1069+
NUTS: [beta, sigma]
1070+
Sampling 4 chains for 1_000 tune and 1_000 draw iterations
1071+
(4_000 + 4_000 draws total) took 3 seconds.
1072+
Sampling: [beta, sigma, y_hat]
1073+
Sampling: [y_hat]
1074+
Sampling: [y_hat]
1075+
Sampling: [y_hat]
8801076
"""
8811077

8821078
def __init__(
@@ -1010,7 +1206,23 @@ def _causal_impact_summary_stat(self) -> str:
10101206
return f"Causal impact = {causal_impact + ci}"
10111207

10121208
def summary(self) -> None:
1013-
"""Print text output summarising the results"""
1209+
"""
1210+
Print text output summarising the results
1211+
1212+
Example
1213+
--------
1214+
>>> result.summary()
1215+
=================Pretest/posttest Nonequivalent Group Design================
1216+
Formula: post ~ 1 + C(group) + pre
1217+
Results:
1218+
Causal impact = 1.89, $CI_{94%}$[1.70, 2.07]
1219+
Model coefficients:
1220+
Intercept -0.46, 94% HDI [-1.17, 0.22]
1221+
C(group)[T.1] 1.89, 94% HDI [1.70, 2.07]
1222+
pre 1.05, 94% HDI [0.98, 1.12]
1223+
sigma 0.51, 94% HDI [0.46, 0.56]
1224+
1225+
"""
10141226

10151227
print(f"{self.expt_type:=^80}")
10161228
print(f"Formula: {self.formula}")

0 commit comments

Comments
 (0)