9
9
from causalpy .plot_utils import plot_xY
10
10
11
11
LEGEND_FONT_SIZE = 12
12
+ az .style .use ("arviz-darkgrid" )
12
13
13
14
14
15
class ExperimentalDesign :
15
16
"""Base class"""
16
17
17
18
prediction_model = None
19
+ expt_type = None
18
20
19
21
def __init__ (self , prediction_model = None , ** kwargs ):
20
22
if prediction_model is not None :
21
23
self .prediction_model = prediction_model
22
24
if self .prediction_model is None :
23
25
raise ValueError ("fitting_model not set or passed." )
24
26
27
+ def print_coefficients (self ):
28
+ """Prints the model coefficients"""
29
+ print ("Model coefficients:" )
30
+ coeffs = az .extract (self .prediction_model .idata .posterior , var_names = "beta" )
31
+ # Note: f"{name: <30}" pads the name with spaces so that we have alignment of the stats despite variable names of different lengths
32
+ for name in self .labels :
33
+ coeff_samples = coeffs .sel (coeffs = name )
34
+ print (
35
+ f" { name : <30} { coeff_samples .mean ().data :.2f} , 94% HDI [{ coeff_samples .quantile (0.03 ).data :.2f} , { coeff_samples .quantile (1 - 0.03 ).data :.2f} ]"
36
+ )
37
+ # add coeff for measurement std
38
+ coeff_samples = az .extract (
39
+ self .prediction_model .idata .posterior , var_names = "sigma"
40
+ )
41
+ name = "sigma"
42
+ print (
43
+ f" { name : <30} { coeff_samples .mean ().data :.2f} , 94% HDI [{ coeff_samples .quantile (0.03 ).data :.2f} , { coeff_samples .quantile (1 - 0.03 ).data :.2f} ]"
44
+ )
45
+
25
46
26
47
class TimeSeriesExperiment (ExperimentalDesign ):
27
48
"""A class to analyse time series quasi-experiments"""
@@ -44,6 +65,7 @@ def __init__(
44
65
45
66
# set things up with pre-intervention data
46
67
y , X = dmatrices (formula , self .datapre )
68
+ self .outcome_variable_name = y .design_info .column_names [0 ]
47
69
self ._y_design_info = y .design_info
48
70
self ._x_design_info = X .design_info
49
71
self .labels = X .design_info .column_names
@@ -144,10 +166,20 @@ def plot(self):
144
166
145
167
return (fig , ax )
146
168
169
+ def summary (self ):
170
+ """Print text output summarising the results"""
171
+
172
+ print (f"{ self .expt_type :=^80} " )
173
+ print (f"Formula: { self .formula } " )
174
+ # TODO: extra experiment specific outputs here
175
+ self .print_coefficients ()
176
+
147
177
148
178
class SyntheticControl (TimeSeriesExperiment ):
149
179
"""A wrapper around the TimeSeriesExperiment class"""
150
180
181
+ expt_type = "Synthetic Control"
182
+
151
183
def plot (self ):
152
184
"""Plot the results"""
153
185
fig , ax = super ().plot ()
@@ -160,7 +192,7 @@ def plot(self):
160
192
class InterruptedTimeSeries (TimeSeriesExperiment ):
161
193
"""A wrapper around the TimeSeriesExperiment class"""
162
194
163
- pass
195
+ expt_type = "Interrupted Time Series"
164
196
165
197
166
198
class DifferenceInDifferences (ExperimentalDesign ):
@@ -177,20 +209,20 @@ def __init__(
177
209
data ,
178
210
formula ,
179
211
time_variable_name = "t" ,
180
- outcome_variable_name = "y" ,
181
212
prediction_model = None ,
182
213
** kwargs ,
183
214
):
184
215
super ().__init__ (prediction_model = prediction_model , ** kwargs )
185
216
self .data = data
217
+ self .expt_type = "Difference in Differences"
186
218
self .formula = formula
187
219
self .time_variable_name = time_variable_name
188
- self .outcome_variable_name = outcome_variable_name
189
220
y , X = dmatrices (formula , self .data )
190
221
self ._y_design_info = y .design_info
191
222
self ._x_design_info = X .design_info
192
223
self .labels = X .design_info .column_names
193
224
self .y , self .X = np .asarray (y ), np .asarray (X )
225
+ self .outcome_variable_name = y .design_info .column_names [0 ]
194
226
195
227
# TODO: `treated` is a deterministic function of group and time, so this should be a function rather than supplied data
196
228
@@ -224,14 +256,18 @@ def __init__(
224
256
self .y_pred_counterfactual = self .prediction_model .predict (np .asarray (new_x ))
225
257
226
258
# calculate causal impact
227
- # TODO: This should most likely be posterior estimate, not posterior predictive
228
259
self .causal_impact = (
229
- self .y_pred_treatment ["posterior_predictive" ]
230
- .y_hat .isel ({"obs_ind" : 1 })
231
- .mean ()
232
- .data
233
- - self .y_pred_counterfactual ["posterior_predictive" ].y_hat .mean ().data
260
+ self .y_pred_treatment ["posterior_predictive" ].mu .isel ({"obs_ind" : 1 })
261
+ - self .y_pred_counterfactual ["posterior_predictive" ].mu .squeeze ()
234
262
)
263
+ # self.causal_impact = (
264
+ # self.y_pred_treatment["posterior_predictive"]
265
+ # .mu.isel({"obs_ind": 1})
266
+ # .stack(samples=["chain", "draw"])
267
+ # - self.y_pred_counterfactual["posterior_predictive"]
268
+ # .mu.stack(samples=["chain", "draw"])
269
+ # .squeeze()
270
+ # )
235
271
236
272
def plot (self ):
237
273
"""Plot the results"""
@@ -251,7 +287,7 @@ def plot(self):
251
287
# Plot model fit to control group
252
288
parts = ax .violinplot (
253
289
az .extract (
254
- self .y_pred_control , group = "posterior_predictive" , var_names = "y_hat "
290
+ self .y_pred_control , group = "posterior_predictive" , var_names = "mu "
255
291
).values .T ,
256
292
positions = self .x_pred_control [self .time_variable_name ].values ,
257
293
showmeans = False ,
@@ -266,7 +302,7 @@ def plot(self):
266
302
# Plot model fit to treatment group
267
303
parts = ax .violinplot (
268
304
az .extract (
269
- self .y_pred_treatment , group = "posterior_predictive" , var_names = "y_hat "
305
+ self .y_pred_treatment , group = "posterior_predictive" , var_names = "mu "
270
306
).values .T ,
271
307
positions = self .x_pred_treatment [self .time_variable_name ].values ,
272
308
showmeans = False ,
@@ -278,7 +314,7 @@ def plot(self):
278
314
az .extract (
279
315
self .y_pred_counterfactual ,
280
316
group = "posterior_predictive" ,
281
- var_names = "y_hat " ,
317
+ var_names = "mu " ,
282
318
).values .T ,
283
319
positions = self .x_pred_counterfactual [self .time_variable_name ].values ,
284
320
showmeans = False ,
@@ -288,12 +324,12 @@ def plot(self):
288
324
# arrow to label the causal impact
289
325
y_pred_treatment = (
290
326
self .y_pred_treatment ["posterior_predictive" ]
291
- .y_hat .isel ({"obs_ind" : 1 })
327
+ .mu .isel ({"obs_ind" : 1 })
292
328
.mean ()
293
329
.data
294
330
)
295
331
y_pred_counterfactual = (
296
- self .y_pred_counterfactual ["posterior_predictive" ].y_hat .mean ().data
332
+ self .y_pred_counterfactual ["posterior_predictive" ].mu .mean ().data
297
333
)
298
334
ax .annotate (
299
335
"" ,
@@ -317,11 +353,27 @@ def plot(self):
317
353
xlim = [- 0.15 , 1.25 ],
318
354
xticks = [0 , 1 ],
319
355
xticklabels = ["pre" , "post" ],
320
- title = f"Causal impact = { self .causal_impact :.2f } " ,
356
+ title = self ._causal_impact_summary_stat () ,
321
357
)
322
358
ax .legend (fontsize = LEGEND_FONT_SIZE )
323
359
return (fig , ax )
324
360
361
+ def _causal_impact_summary_stat (self ):
362
+ percentiles = self .causal_impact .quantile ([0.03 , 1 - 0.03 ]).values
363
+ ci = r"$CI_{94\%}$" + f"[{ percentiles [0 ]:.2f} , { percentiles [1 ]:.2f} ]"
364
+ causal_impact = f"{ self .causal_impact .mean ():.2f} , "
365
+ return f"Causal impact = { causal_impact + ci } "
366
+
367
+ def summary (self ):
368
+ """Print text output summarising the results"""
369
+
370
+ print (f"{ self .expt_type :=^80} " )
371
+ print (f"Formula: { self .formula } " )
372
+ print ("\n Results:" )
373
+ # TODO: extra experiment specific outputs here
374
+ print (self ._causal_impact_summary_stat ())
375
+ self .print_coefficients ()
376
+
325
377
326
378
class RegressionDiscontinuity (ExperimentalDesign ):
327
379
"""
@@ -345,20 +397,20 @@ def __init__(
345
397
treatment_threshold : float ,
346
398
prediction_model = None ,
347
399
running_variable_name : str = "x" ,
348
- outcome_variable_name = "y" ,
349
400
** kwargs ,
350
401
):
351
402
super ().__init__ (prediction_model = prediction_model , ** kwargs )
403
+ self .expt_type = "Regression Discontinuity"
352
404
self .data = data
353
405
self .formula = formula
354
406
self .running_variable_name = running_variable_name
355
- self .outcome_variable_name = outcome_variable_name
356
407
self .treatment_threshold = treatment_threshold
357
408
y , X = dmatrices (formula , self .data )
358
409
self ._y_design_info = y .design_info
359
410
self ._x_design_info = X .design_info
360
411
self .labels = X .design_info .column_names
361
412
self .y , self .X = np .asarray (y ), np .asarray (X )
413
+ self .outcome_variable_name = y .design_info .column_names [0 ]
362
414
363
415
# TODO: `treated` is a deterministic function of x and treatment_threshold, so this could be a function rather than supplied data
364
416
@@ -445,25 +497,13 @@ def plot(self):
445
497
446
498
def summary (self ):
447
499
"""Print text output summarising the results"""
448
- print ("Difference in Differences experiment" )
500
+
501
+ print (f"{ self .expt_type :=^80} " )
449
502
print (f"Formula: { self .formula } " )
450
503
print (f"Running variable: { self .running_variable_name } " )
451
504
print (f"Threshold on running variable: { self .treatment_threshold } " )
452
505
print (f"\n Results:" )
453
506
print (
454
507
f"Discontinuity at threshold = { self .discontinuity_at_threshold .mean ():.2f} "
455
508
)
456
- print ("Model coefficients:" )
457
- coeffs = az .extract (self .prediction_model .idata .posterior , var_names = "beta" )
458
- for name in self .labels :
459
- coeff_samples = coeffs .sel (coeffs = name )
460
- print (
461
- f"\t { name } \t \t { coeff_samples .mean ().data :.2f} , 94% HDI [{ coeff_samples .quantile (0.03 ).data :.2f} , { coeff_samples .quantile (1 - 0.03 ).data :.2f} ]"
462
- )
463
- # add coeff for measurement std
464
- coeff_samples = az .extract (
465
- self .prediction_model .idata .posterior , var_names = "sigma"
466
- )
467
- print (
468
- f"\t sigma\t \t { coeff_samples .mean ().data :.2f} , 94% HDI [{ coeff_samples .quantile (0.03 ).data :.2f} , { coeff_samples .quantile (1 - 0.03 ).data :.2f} ]"
469
- )
509
+ self .print_coefficients ()
0 commit comments