1
1
"""
2
2
Experiments for Scikit-Learn models
3
+
4
+ - ExperimentalDesign: base class for skl experiments
5
+ - PrePostFit: base class for synthetic control and interrupted time series
6
+ - SyntheticControl
7
+ - InterruptedTimeSeries
8
+ - DifferenceInDifferences
9
+ - RegressionDiscontinuity
3
10
"""
4
11
import warnings
5
12
from typing import Optional
@@ -27,8 +34,33 @@ def __init__(self, model=None, **kwargs):
27
34
28
35
29
36
class PrePostFit (ExperimentalDesign ):
30
- """A class to analyse quasi-experiments where parameter estimation is based on just
31
- the pre-intervention data."""
37
+ """
38
+ A class to analyse quasi-experiments where parameter estimation is based on just
39
+ the pre-intervention data.
40
+
41
+ :param data:
42
+ A pandas data frame
43
+ :param treatment_time:
44
+ The index or time value of when treatment begins
45
+ :param formula:
46
+ A statistical model formula
47
+ :param model:
48
+ An sklearn model object
49
+
50
+ Example
51
+ --------
52
+ >>> from sklearn.linear_model import LinearRegression
53
+ >>> import causalpy as cp
54
+ >>> df = cp.load_data("sc")
55
+ >>> treatment_time = 70
56
+ >>> result = cp.skl_experiments.PrePostFit(
57
+ ... df,
58
+ ... treatment_time,
59
+ ... formula="actual ~ 0 + a + b + c + d + e + f + g",
60
+ ... model = cp.skl_models.WeightedProportion()
61
+ ... )
62
+
63
+ """
32
64
33
65
def __init__ (
34
66
self ,
@@ -144,7 +176,16 @@ def plot(self, counterfactual_label="Counterfactual", **kwargs):
144
176
return (fig , ax )
145
177
146
178
def get_coeffs (self ):
147
- """Returns model coefficients"""
179
+ """
180
+ Returns model coefficients
181
+
182
+ Example
183
+ --------
184
+ >>> result.get_coeffs()
185
+ array([3.97370896e-01, 1.53881980e-01, 4.48747123e-01, 1.04639857e-16,
186
+ 0.00000000e+00, 0.00000000e+00, 2.92931287e-16])
187
+
188
+ """
148
189
return np .squeeze (self .model .coef_ )
149
190
150
191
def plot_coeffs (self ):
@@ -161,13 +202,68 @@ def plot_coeffs(self):
161
202
162
203
163
204
class InterruptedTimeSeries (PrePostFit ):
164
- """Interrupted time series analysis"""
205
+ """
206
+ Interrupted time series analysis, a wrapper around the PrePostFit class
207
+
208
+ :param data:
209
+ A pandas data frame
210
+ :param treatment_time:
211
+ The index or time value of when treatment begins
212
+ :param formula:
213
+ A statistical model formula
214
+ :param model:
215
+ An sklearn model object
216
+
217
+ Example
218
+ --------
219
+ >>> from sklearn.linear_model import LinearRegression
220
+ >>> import pandas as pd
221
+ >>> import causalpy as cp
222
+ >>> df = (
223
+ ... cp.load_data("its")
224
+ ... .assign(date=lambda x: pd.to_datetime(x["date"]))
225
+ ... .set_index("date")
226
+ ... )
227
+ >>> treatment_time = pd.to_datetime("2017-01-01")
228
+ >>> result = cp.skl_experiments.InterruptedTimeSeries(
229
+ ... df,
230
+ ... treatment_time,
231
+ ... formula="y ~ 1 + t + C(month)",
232
+ ... model = LinearRegression()
233
+ ... )
234
+
235
+ """
165
236
166
237
expt_type = "Interrupted Time Series"
167
238
168
239
169
240
class SyntheticControl (PrePostFit ):
170
- """A wrapper around the PrePostFit class"""
241
+ """
242
+ A wrapper around the PrePostFit class
243
+
244
+ :param data:
245
+ A pandas data frame
246
+ :param treatment_time:
247
+ The index or time value of when treatment begins
248
+ :param formula:
249
+ A statistical model formula
250
+ :param model:
251
+ An sklearn model object
252
+
253
+ Example
254
+ --------
255
+ >>> from sklearn.linear_model import LinearRegression
256
+ >>> import causalpy as cp
257
+ >>> df = cp.load_data("sc")
258
+ >>> treatment_time = 70
259
+ >>> result = cp.skl_experiments.SyntheticControl(
260
+ ... df,
261
+ ... treatment_time,
262
+ ... formula="actual ~ 0 + a + b + c + d + e + f + g",
263
+ ... model = cp.skl_models.WeightedProportion()
264
+ ... )
265
+
266
+ """
171
267
172
268
def plot (self , plot_predictors = False , ** kwargs ):
173
269
"""Plot the results"""
@@ -187,6 +283,32 @@ class DifferenceInDifferences(ExperimentalDesign):
187
283
188
284
There is no pre/post intervention data distinction for DiD, we fit all the data
189
285
available.
286
+
287
+ :param data:
288
+ A pandas data frame
289
+ :param formula:
290
+ A statistical model formula
291
+ :param time_variable_name:
292
+ Name of the data column for the time variable
293
+ :param group_variable_name:
294
+ Name of the data column for the group variable
295
+ :param model:
296
+ A PyMC model for difference in differences
297
+
298
+ Example
299
+ --------
300
+ >>> df = cp.load_data("did")
301
+ >>> seed = 42
302
+ >>> result = cp.skl_experiments.DifferenceInDifferences(
303
+ ... data,
304
+ ... formula="y ~ 1 + group*post_treatment",
305
+ ... time_variable_name="t",
306
+ ... group_variable_name="group",
307
+ ... treated=1,
308
+ ... untreated=0,
309
+ ... model=LinearRegression(),
310
+ ... )
311
+
190
312
"""
191
313
192
314
def __init__ (
@@ -373,6 +495,17 @@ class RegressionDiscontinuity(ExperimentalDesign):
373
495
:param bandwidth:
374
496
Data outside of the bandwidth (relative to the discontinuity) is not used to fit
375
497
the model.
498
+
499
+ Example
500
+ --------
501
+ >>> data = cp.load_data("rd")
502
+ >>> result = cp.skl_experiments.RegressionDiscontinuity(
503
+ ... data,
504
+ ... formula="y ~ 1 + x + treated",
505
+ ... model=LinearRegression(),
506
+ ... treatment_threshold=0.5,
507
+ ... )
508
+
376
509
"""
377
510
378
511
def __init__ (
@@ -503,7 +636,24 @@ def plot(self):
503
636
return (fig , ax )
504
637
505
638
def summary (self ):
506
- """Print text output summarising the results"""
639
+ """
640
+ Print text output summarising the results
641
+
642
+ Example
643
+ --------
644
+ >>> result.summary()
645
+ Difference in Differences experiment
646
+ Formula: y ~ 1 + x + treated
647
+ Running variable: x
648
+ Threshold on running variable: 0.5
649
+ Results:
650
+ Discontinuity at threshold = 0.19
651
+ Model coefficients:
652
+ Intercept 0.0
653
+ treated[T.True] 0.19034196317793994
654
+ x 1.229600855360073
655
+
656
+ """
507
657
print ("Difference in Differences experiment" )
508
658
print (f"Formula: { self .formula } " )
509
659
print (f"Running variable: { self .running_variable_name } " )
0 commit comments