1
1
"""
2
2
Experiment routines for PyMC models.
3
3
4
- Includes:
5
- 1. ExperimentalDesign base class
6
- 2. Pre-Post Fit
7
- 3. Synthetic Control
8
- 4. Difference in differences
9
- 5. Regression Discontinuity
4
+ - ExperimentalDesign base class
5
+ - Pre-Post Fit
6
+ - Interrupted Time Series
7
+ - Synthetic Control
8
+ - Difference in differences
9
+ - Regression Discontinuity
10
+ - Pretest/Posttest Nonequivalent Group Design
11
+
10
12
"""
13
+
11
14
import warnings
12
15
from typing import Optional , Union
13
16
30
33
31
34
32
35
class ExperimentalDesign :
33
- """Base class for other experiment types"""
36
+ """
37
+ Base class for other experiment types
38
+
39
+ See subclasses for examples of most methods
40
+ """
34
41
35
42
model = None
36
43
expt_type = None
@@ -43,11 +50,63 @@ def __init__(self, model=None, **kwargs):
43
50
44
51
@property
45
52
def idata (self ):
46
- """Access to the models InferenceData object"""
53
+ """
54
+ Access to the models InferenceData object
55
+
56
+ Example
57
+ --------
58
+ If `result` is the result of the Difference in Differences experiment example
59
+
60
+ >>> result.idata
61
+ Inference data with groups:
62
+ > posterior
63
+ > posterior_predictive
64
+ > sample_stats
65
+ > prior
66
+ > prior_predictive
67
+ > observed_data
68
+ > constant_data
69
+ >>> result.idata.posterior
70
+ <xarray.Dataset>
71
+ Dimensions: (chain: 4, draw: 1000, coeffs: 4, obs_ind: 40)
72
+ Coordinates:
73
+ * chain (chain) int64 0 1 2 3
74
+ * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998
75
+ 999
76
+ * coeffs (coeffs) <U28 'Intercept' ... 'group:post_treatment[T.True]'
77
+ * obs_ind (obs_ind) int64 0 1 2 3 4 5 6 7 8 9 ... 31 32 33 34 35 36 37
78
+ 38 39
79
+ Data variables:
80
+ beta (chain, draw, coeffs) float64 1.04 1.013 0.173 ... 0.1873 0.5225
81
+ sigma (chain, draw) float64 0.09331 0.1031 0.1024 ... 0.0824 0.06907
82
+ mu (chain, draw, obs_ind) float64 1.04 2.053 1.213 ... 1.265 2.747
83
+ Attributes:
84
+ created_at: 2023-08-23T20:03:45.709265
85
+ arviz_version: 0.16.1
86
+ inference_library: pymc
87
+ inference_library_version: 5.7.2
88
+ sampling_time: 0.8851289749145508
89
+ tuning_steps: 1000
90
+ """
91
+
47
92
return self .model .idata
48
93
49
94
def print_coefficients (self ) -> None :
50
- """Prints the model coefficients"""
95
+ """
96
+ Prints the model coefficients
97
+
98
+ Example
99
+ --------
100
+ If `result` is from the Difference in Differences experiment example
101
+
102
+ >>> result.print_coefficients()
103
+ Model coefficients:
104
+ Intercept 1.08, 94% HDI [1.03, 1.13]
105
+ post_treatment[T.True] 0.98, 94% HDI [0.91, 1.06]
106
+ group 0.16, 94% HDI [0.09, 0.23]
107
+ group:post_treatment[T.True] 0.51, 94% HDI [0.41, 0.61]
108
+ sigma 0.08, 94% HDI [0.07, 0.10]
109
+ """
51
110
print ("Model coefficients:" )
52
111
coeffs = az .extract (self .idata .posterior , var_names = "beta" )
53
112
# Note: f"{name: <30}" pads the name with spaces so that we have alignment of
@@ -82,6 +141,7 @@ class PrePostFit(ExperimentalDesign):
82
141
Example
83
142
--------
84
143
>>> sc = cp.load_data("sc")
144
+ >>> treatment_time = 70
85
145
>>> seed = 42
86
146
>>> result = cp.pymc_experiments.PrePostFit(
87
147
... sc,
@@ -91,6 +151,17 @@ class PrePostFit(ExperimentalDesign):
91
151
... sample_kwargs={"target_accept": 0.95, "random_seed": seed}
92
152
... ),
93
153
... )
154
+ Auto-assigning NUTS sampler...
155
+ Initializing NUTS using jitter+adapt_diag...
156
+ Multiprocess sampling (4 chains in 4 jobs)
157
+ NUTS: [beta, sigma]
158
+ Sampling 4 chains for 1_000 tune and 1_000 draw iterations
159
+ (4_000 + 4_000 draws total) took 11 seconds.
160
+ Sampling: [beta, sigma, y_hat]
161
+ Sampling: [y_hat]
162
+ Sampling: [y_hat]
163
+ Sampling: [y_hat]
164
+ Sampling: [y_hat]
94
165
"""
95
166
96
167
def __init__ (
@@ -105,6 +176,8 @@ def __init__(
105
176
self ._input_validation (data , treatment_time )
106
177
107
178
self .treatment_time = treatment_time
179
+ # set experiment type - usually done in subclasses
180
+ self .expt_type = "Pre-Post Fit"
108
181
# split data in to pre and post intervention
109
182
self .datapre = data [data .index <= self .treatment_time ]
110
183
self .datapost = data [data .index > self .treatment_time ]
@@ -171,7 +244,14 @@ def _input_validation(self, data, treatment_time):
171
244
)
172
245
173
246
def plot (self , counterfactual_label = "Counterfactual" , ** kwargs ):
174
- """Plot the results"""
247
+ """
248
+ Plot the results
249
+
250
+ Example
251
+ --------
252
+ >>> result.plot()
253
+
254
+ """
175
255
fig , ax = plt .subplots (3 , 1 , sharex = True , figsize = (7 , 8 ))
176
256
177
257
# TOP PLOT --------------------------------------------------
@@ -271,7 +351,24 @@ def plot(self, counterfactual_label="Counterfactual", **kwargs):
271
351
return (fig , ax )
272
352
273
353
def summary (self ) -> None :
274
- """Print text output summarising the results"""
354
+ """
355
+ Print text output summarising the results
356
+
357
+ Example
358
+ ---------
359
+ >>> result.summary()
360
+ ===============================Synthetic Control===============================
361
+ Formula: actual ~ 0 + a + b + c + d + e + f + g
362
+ Model coefficients:
363
+ a 0.33, 94% HDI [0.30, 0.38]
364
+ b 0.05, 94% HDI [0.01, 0.09]
365
+ c 0.31, 94% HDI [0.26, 0.35]
366
+ d 0.06, 94% HDI [0.01, 0.10]
367
+ e 0.02, 94% HDI [0.00, 0.06]
368
+ f 0.20, 94% HDI [0.12, 0.26]
369
+ g 0.04, 94% HDI [0.00, 0.08]
370
+ sigma 0.26, 94% HDI [0.22, 0.30]
371
+ """
275
372
276
373
print (f"{ self .expt_type :=^80} " )
277
374
print (f"Formula: { self .formula } " )
@@ -307,6 +404,17 @@ class InterruptedTimeSeries(PrePostFit):
307
404
... formula="y ~ 1 + t + C(month)",
308
405
... model=cp.pymc_models.LinearRegression(sample_kwargs={"random_seed": seed}),
309
406
... )
407
+ Auto-assigning NUTS sampler...
408
+ Initializing NUTS using jitter+adapt_diag...
409
+ Multiprocess sampling (4 chains in 4 jobs)
410
+ NUTS: [beta, sigma]
411
+ Sampling 4 chains for 1_000 tune and 1_000 draw iterations
412
+ (4_000 + 4_000 draws total) took 3 seconds.
413
+ Sampling: [beta, sigma, y_hat]
414
+ Sampling: [y_hat]
415
+ Sampling: [y_hat]
416
+ Sampling: [y_hat]
417
+ Sampling: [y_hat]
310
418
"""
311
419
312
420
expt_type = "Interrupted Time Series"
@@ -337,6 +445,17 @@ class SyntheticControl(PrePostFit):
337
445
... sample_kwargs={"target_accept": 0.95, "random_seed": seed}
338
446
... ),
339
447
... )
448
+ Auto-assigning NUTS sampler...
449
+ Initializing NUTS using jitter+adapt_diag...
450
+ Multiprocess sampling (4 chains in 4 jobs)
451
+ NUTS: [beta, sigma]
452
+ Sampling 4 chains for 1_000 tune and 1_000 draw iterations
453
+ (4_000 + 4_000 draws total) took 11 seconds.
454
+ Sampling: [beta, sigma, y_hat]
455
+ Sampling: [y_hat]
456
+ Sampling: [y_hat]
457
+ Sampling: [y_hat]
458
+ Sampling: [y_hat]
340
459
"""
341
460
342
461
expt_type = "Synthetic Control"
@@ -382,7 +501,17 @@ class DifferenceInDifferences(ExperimentalDesign):
382
501
... group_variable_name="group",
383
502
... model=cp.pymc_models.LinearRegression(sample_kwargs={"random_seed": seed}),
384
503
... )
385
-
504
+ Auto-assigning NUTS sampler...
505
+ Initializing NUTS using jitter+adapt_diag...
506
+ Multiprocess sampling (4 chains in 4 jobs)
507
+ NUTS: [beta, sigma]
508
+ Sampling 4 chains for 1_000 tune and 1_000 draw iterations
509
+ (4_000 + 4_000 draws total) took 1 seconds.
510
+ Sampling: [beta, sigma, y_hat]
511
+ Sampling: [y_hat]
512
+ Sampling: [y_hat]
513
+ Sampling: [y_hat]
514
+ Sampling: [y_hat]
386
515
"""
387
516
388
517
def __init__ (
@@ -503,6 +632,12 @@ def _input_validation(self):
503
632
def plot (self ):
504
633
"""Plot the results.
505
634
Creating the combined mean + HDI legend entries is a bit involved.
635
+
636
+ Example
637
+ --------
638
+ Assuming `result` is the result of a DiD experiment:
639
+
640
+ >>> result.plot()
506
641
"""
507
642
fig , ax = plt .subplots ()
508
643
@@ -639,7 +774,25 @@ def _causal_impact_summary_stat(self) -> str:
639
774
return f"Causal impact = { causal_impact + ci } "
640
775
641
776
def summary (self ) -> None :
642
- """Print text output summarising the results"""
777
+ """
778
+ Print text output summarising the results
779
+
780
+ Example
781
+ --------
782
+ Assuming `result` is a DiD experiment
783
+
784
+ >>> result.summary()
785
+ ==========================Difference in Differences=========================
786
+ Formula: y ~ 1 + group*post_treatment
787
+ Results:
788
+ Causal impact = 0.51, $CI_{94%}$[0.41, 0.61]
789
+ Model coefficients:
790
+ Intercept 1.08, 94% HDI [1.03, 1.13]
791
+ post_treatment[T.True] 0.98, 94% HDI [0.91, 1.06]
792
+ group 0.16, 94% HDI [0.09, 0.23]
793
+ group:post_treatment[T.True] 0.51, 94% HDI [0.41, 0.61]
794
+ sigma 0.08, 94% HDI [0.07, 0.10]
795
+ """
643
796
644
797
print (f"{ self .expt_type :=^80} " )
645
798
print (f"Formula: { self .formula } " )
@@ -680,7 +833,17 @@ class RegressionDiscontinuity(ExperimentalDesign):
680
833
... model=cp.pymc_models.LinearRegression(sample_kwargs={"random_seed": seed}),
681
834
... treatment_threshold=0.5,
682
835
... )
683
-
836
+ Auto-assigning NUTS sampler...
837
+ Initializing NUTS using jitter+adapt_diag...
838
+ Multiprocess sampling (4 chains in 4 jobs)
839
+ NUTS: [beta, sigma]
840
+ Sampling 4 chains for 1_000 tune and 1_000 draw iterations
841
+ (4_000 + 4_000 draws total) took 2 seconds.
842
+ Sampling: [beta, sigma, y_hat]
843
+ Sampling: [y_hat]
844
+ Sampling: [y_hat]
845
+ Sampling: [y_hat]
846
+ Sampling: [y_hat]
684
847
"""
685
848
686
849
def __init__ (
@@ -791,7 +954,13 @@ def _is_treated(self, x):
791
954
return np .greater_equal (x , self .treatment_threshold )
792
955
793
956
def plot (self ):
794
- """Plot the results"""
957
+ """
958
+ Plot the results
959
+
960
+ Example
961
+ --------
962
+ >>> result.plot()
963
+ """
795
964
fig , ax = plt .subplots ()
796
965
# Plot raw data
797
966
sns .scatterplot (
@@ -837,7 +1006,25 @@ def plot(self):
837
1006
return (fig , ax )
838
1007
839
1008
def summary (self ) -> None :
840
- """Print text output summarising the results"""
1009
+ """
1010
+ Print text output summarising the results
1011
+
1012
+ Example
1013
+ --------
1014
+ >>> result.summary()
1015
+ ============================Regression Discontinuity==========================
1016
+ Formula: y ~ 1 + x + treated + x:treated
1017
+ Running variable: x
1018
+ Threshold on running variable: 0.5
1019
+ Results:
1020
+ Discontinuity at threshold = 0.92
1021
+ Model coefficients:
1022
+ Intercept 0.09, 94% HDI [0.00, 0.17]
1023
+ treated[T.True] 2.48, 94% HDI [1.66, 3.27]
1024
+ x 1.32, 94% HDI [1.14, 1.50]
1025
+ x:treated[T.True] -3.12, 94% HDI [-4.17, -2.05]
1026
+ sigma 0.35, 94% HDI [0.31, 0.41]
1027
+ """
841
1028
842
1029
print (f"{ self .expt_type :=^80} " )
843
1030
print (f"Formula: { self .formula } " )
@@ -876,7 +1063,16 @@ class PrePostNEGD(ExperimentalDesign):
876
1063
... pretreatment_variable_name="pre",
877
1064
... model=cp.pymc_models.LinearRegression(sample_kwargs={"random_seed": seed}),
878
1065
... )
879
-
1066
+ Auto-assigning NUTS sampler...
1067
+ Initializing NUTS using jitter+adapt_diag...
1068
+ Multiprocess sampling (4 chains in 4 jobs)
1069
+ NUTS: [beta, sigma]
1070
+ Sampling 4 chains for 1_000 tune and 1_000 draw iterations
1071
+ (4_000 + 4_000 draws total) took 3 seconds.
1072
+ Sampling: [beta, sigma, y_hat]
1073
+ Sampling: [y_hat]
1074
+ Sampling: [y_hat]
1075
+ Sampling: [y_hat]
880
1076
"""
881
1077
882
1078
def __init__ (
@@ -1010,7 +1206,23 @@ def _causal_impact_summary_stat(self) -> str:
1010
1206
return f"Causal impact = { causal_impact + ci } "
1011
1207
1012
1208
def summary (self ) -> None :
1013
- """Print text output summarising the results"""
1209
+ """
1210
+ Print text output summarising the results
1211
+
1212
+ Example
1213
+ --------
1214
+ >>> result.summary()
1215
+ =================Pretest/posttest Nonequivalent Group Design================
1216
+ Formula: post ~ 1 + C(group) + pre
1217
+ Results:
1218
+ Causal impact = 1.89, $CI_{94%}$[1.70, 2.07]
1219
+ Model coefficients:
1220
+ Intercept -0.46, 94% HDI [-1.17, 0.22]
1221
+ C(group)[T.1] 1.89, 94% HDI [1.70, 2.07]
1222
+ pre 1.05, 94% HDI [0.98, 1.12]
1223
+ sigma 0.51, 94% HDI [0.46, 0.56]
1224
+
1225
+ """
1014
1226
1015
1227
print (f"{ self .expt_type :=^80} " )
1016
1228
print (f"Formula: { self .formula } " )
0 commit comments