20
20
import arviz as az
21
21
import numpy as np
22
22
import pandas as pd
23
+ import xarray as xr
23
24
from matplotlib import pyplot as plt
24
- from patsy import build_design_matrices , dmatrices
25
25
from sklearn .base import RegressorMixin
26
26
27
27
from causalpy .custom_exceptions import BadIndexException
@@ -41,8 +41,10 @@ class SyntheticControl(BaseExperiment):
41
41
A pandas dataframe
42
42
:param treatment_time:
43
43
The time when treatment occurred, should be in reference to the data index
44
- :param formula:
45
- A statistical model formula
44
+ :param control_units:
45
+ A list of control units to be used in the experiment
46
+ :param treated_units:
47
+ A list of treated units to be used in the experiment
46
48
:param model:
47
49
A PyMC model
48
50
@@ -55,7 +57,8 @@ class SyntheticControl(BaseExperiment):
55
57
>>> result = cp.SyntheticControl(
56
58
... df,
57
59
... treatment_time,
58
- ... formula="actual ~ 0 + a + b + c + d + e + f + g",
60
+ ... control_units=["a", "b", "c", "d", "e", "f", "g"],
61
+ ... treated_units=["actual"],
59
62
... model=cp.pymc_models.WeightedSumFitter(
60
63
... sample_kwargs={
61
64
... "target_accept": 0.95,
@@ -66,63 +69,111 @@ class SyntheticControl(BaseExperiment):
66
69
... )
67
70
"""
68
71
69
- expt_type = "SyntheticControl"
70
72
supports_ols = True
71
73
supports_bayes = True
72
74
73
75
def __init__ (
74
76
self ,
75
77
data : pd .DataFrame ,
76
78
treatment_time : Union [int , float , pd .Timestamp ],
77
- formula : str ,
79
+ control_units : list [str ],
80
+ treated_units : list [str ],
78
81
model = None ,
79
82
** kwargs ,
80
83
) -> None :
81
84
super ().__init__ (model = model )
82
85
self .input_validation (data , treatment_time )
83
86
self .treatment_time = treatment_time
84
- # set experiment type - usually done in subclasses
85
- self .expt_type = "Pre-Post Fit"
87
+ self .control_units = control_units
88
+ self .treated_units = treated_units
89
+ self .expt_type = "SyntheticControl"
86
90
# split data into pre- and post-intervention periods
87
91
self .datapre = data [data .index < self .treatment_time ]
88
92
self .datapost = data [data .index >= self .treatment_time ]
89
93
90
- self .formula = formula
91
-
92
- # set things up with pre-intervention data
93
- y , X = dmatrices (formula , self .datapre )
94
- self .outcome_variable_name = y .design_info .column_names [0 ]
95
- self ._y_design_info = y .design_info
96
- self ._x_design_info = X .design_info
97
- self .labels = X .design_info .column_names
98
- self .pre_y , self .pre_X = np .asarray (y ), np .asarray (X )
99
- # process post-intervention data
100
- (new_y , new_x ) = build_design_matrices (
101
- [self ._y_design_info , self ._x_design_info ], self .datapost
94
+ # split data into the 4 quadrants (pre/post, control/treated) and store each as an xarray DataArray
95
+ # self.datapre_control = self.datapre[self.control_units]
96
+ # self.datapre_treated = self.datapre[self.treated_units]
97
+ # self.datapost_control = self.datapost[self.control_units]
98
+ # self.datapost_treated = self.datapost[self.treated_units]
99
+ self .datapre_control = xr .DataArray (
100
+ self .datapre [self .control_units ],
101
+ dims = ["obs_ind" , "control_units" ],
102
+ coords = {
103
+ "obs_ind" : self .datapre [self .control_units ].index ,
104
+ "control_units" : self .control_units ,
105
+ },
106
+ )
107
+ self .datapre_treated = xr .DataArray (
108
+ self .datapre [self .treated_units ],
109
+ dims = ["obs_ind" , "treated_units" ],
110
+ coords = {
111
+ "obs_ind" : self .datapre [self .treated_units ].index ,
112
+ "treated_units" : self .treated_units ,
113
+ },
114
+ )
115
+ self .datapost_control = xr .DataArray (
116
+ self .datapost [self .control_units ],
117
+ dims = ["obs_ind" , "control_units" ],
118
+ coords = {
119
+ "obs_ind" : self .datapost [self .control_units ].index ,
120
+ "control_units" : self .control_units ,
121
+ },
122
+ )
123
+ self .datapost_treated = xr .DataArray (
124
+ self .datapost [self .treated_units ],
125
+ dims = ["obs_ind" , "treated_units" ],
126
+ coords = {
127
+ "obs_ind" : self .datapost [self .treated_units ].index ,
128
+ "treated_units" : self .treated_units ,
129
+ },
102
130
)
103
- self .post_X = np .asarray (new_x )
104
- self .post_y = np .asarray (new_y )
105
131
106
132
# fit the model to the observed (pre-intervention) data
107
133
if isinstance (self .model , PyMCModel ):
108
- COORDS = {"coeffs" : self .labels , "obs_indx" : np .arange (self .pre_X .shape [0 ])}
109
- self .model .fit (X = self .pre_X , y = self .pre_y , coords = COORDS )
134
+ COORDS = {
135
+ "control_units" : self .control_units ,
136
+ "treated_units" : self .treated_units ,
137
+ "obs_indx" : np .arange (self .datapre .shape [0 ]),
138
+ }
139
+ self .model .fit (
140
+ X = self .datapre_control .to_numpy (),
141
+ y = self .datapre_treated .to_numpy (),
142
+ coords = COORDS ,
143
+ )
110
144
elif isinstance (self .model , RegressorMixin ):
111
- self .model .fit (X = self .pre_X , y = self .pre_y )
145
+ self .model .fit (
146
+ X = self .datapre_control .to_numpy (), y = self .datapre_treated .to_numpy ()
147
+ )
112
148
else :
113
149
raise ValueError ("Model type not recognized" )
114
150
115
151
# score the goodness of fit to the pre-intervention data
116
- self .score = self .model .score (X = self .pre_X , y = self .pre_y )
152
+ self .score = self .model .score (
153
+ X = self .datapre_control .to_numpy (), y = self .datapre_treated .to_numpy ()
154
+ )
117
155
118
156
# get the model predictions of the observed (pre-intervention) data
119
- self .pre_pred = self .model .predict (X = self .pre_X )
157
+ self .pre_pred = self .model .predict (X = self .datapre_control . to_numpy () )
120
158
121
159
# calculate the counterfactual
122
- self .post_pred = self .model .predict (X = self .post_X )
123
- self .pre_impact = self .model .calculate_impact (self .pre_y [:, 0 ], self .pre_pred )
160
+ self .post_pred = self .model .predict (X = self .datapost_control .to_numpy ())
161
+ # TODO: Remove the need for this 'hack' by properly updating the coords when we
162
+ # run model.predict
163
+ # TEMPORARY HACK: --------------------------------------------------------------
164
+ # Set the coords (obs_ind) for self.post_pred to be the same as the datapost
165
+ # index. This is needed for xarray to properly do the comparison (-) between
166
+ # datapost_treated and self.post_pred
167
+ # self.post_pred["posterior_predictive"] = self.post_pred[
168
+ # "posterior_predictive"
169
+ # ].assign_coords(obs_ind=self.datapost.index)
170
+ # ------------------------------------------------------------------------------
171
+ self .pre_impact = self .model .calculate_impact (
172
+ self .datapre_treated , self .pre_pred
173
+ )
174
+
124
175
self .post_impact = self .model .calculate_impact (
125
- self .post_y [:, 0 ] , self .post_pred
176
+ self .datapost_treated , self .post_pred
126
177
)
127
178
self .post_impact_cumulative = self .model .calculate_cumulative_impact (
128
179
self .post_impact
@@ -150,7 +201,11 @@ def summary(self, round_to=None) -> None:
150
201
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers
151
202
"""
152
203
print (f"{ self .expt_type :=^80} " )
153
- print (f"Formula: { self .formula } " )
204
+ print (f"Control units: { self .control_units } " )
205
+ if len (self .treated_units ) > 1 :
206
+ print (f"Treated units: { self .treated_units } " )
207
+ else :
208
+ print (f"Treated unit: { self .treated_units [0 ]} " )
154
209
self .print_coefficients (round_to )
155
210
156
211
def _bayesian_plot (
@@ -176,7 +231,9 @@ def _bayesian_plot(
176
231
handles = [(h_line , h_patch )]
177
232
labels = ["Pre-intervention period" ]
178
233
179
- (h ,) = ax [0 ].plot (self .datapre .index , self .pre_y , "k." , label = "Observations" )
234
+ (h ,) = ax [0 ].plot (
235
+ self .datapre .index , self .datapre_treated , "k." , label = "Observations"
236
+ )
180
237
handles .append (h )
181
238
labels .append ("Observations" )
182
239
@@ -190,14 +247,14 @@ def _bayesian_plot(
190
247
handles .append ((h_line , h_patch ))
191
248
labels .append (counterfactual_label )
192
249
193
- ax [0 ].plot (self .datapost .index , self .post_y , "k." )
250
+ ax [0 ].plot (self .datapost .index , self .datapost_treated , "k." )
194
251
# Shaded causal effect
195
252
h = ax [0 ].fill_between (
196
253
self .datapost .index ,
197
254
y1 = az .extract (
198
255
self .post_pred , group = "posterior_predictive" , var_names = "mu"
199
256
).mean ("sample" ),
200
- y2 = np .squeeze (self .post_y ),
257
+ y2 = np .squeeze (self .datapost_treated ),
201
258
color = "C0" ,
202
259
alpha = 0.25 ,
203
260
)
@@ -214,20 +271,20 @@ def _bayesian_plot(
214
271
# MIDDLE PLOT -----------------------------------------------
215
272
plot_xY (
216
273
self .datapre .index ,
217
- self .pre_impact ,
274
+ self .pre_impact . sel ( treated_units = "actual" ) ,
218
275
ax = ax [1 ],
219
276
plot_hdi_kwargs = {"color" : "C0" },
220
277
)
221
278
plot_xY (
222
279
self .datapost .index ,
223
- self .post_impact ,
280
+ self .post_impact . sel ( treated_units = "actual" ) ,
224
281
ax = ax [1 ],
225
282
plot_hdi_kwargs = {"color" : "C1" },
226
283
)
227
284
ax [1 ].axhline (y = 0 , c = "k" )
228
285
ax [1 ].fill_between (
229
286
self .datapost .index ,
230
- y1 = self .post_impact .mean (["chain" , "draw" ]),
287
+ y1 = self .post_impact .mean (["chain" , "draw" ]). sel ( treated_units = "actual" ) ,
231
288
color = "C0" ,
232
289
alpha = 0.25 ,
233
290
label = "Causal impact" ,
@@ -238,7 +295,7 @@ def _bayesian_plot(
238
295
ax [2 ].set (title = "Cumulative Causal Impact" )
239
296
plot_xY (
240
297
self .datapost .index ,
241
- self .post_impact_cumulative ,
298
+ self .post_impact_cumulative . sel ( treated_units = "actual" ) ,
242
299
ax = ax [2 ],
243
300
plot_hdi_kwargs = {"color" : "C1" },
244
301
)
@@ -259,15 +316,22 @@ def _bayesian_plot(
259
316
fontsize = LEGEND_FONT_SIZE ,
260
317
)
261
318
262
- # code above: same as `PrePostFit._bayesian_plot` -------------------------------
263
- # code below: additional for the synthetic control experiment ------------------
264
-
265
319
plot_predictors = kwargs .get ("plot_predictors" , False )
266
320
if plot_predictors :
267
321
# plot control units as well
268
- ax [0 ].plot (self .datapre .index , self .pre_X , "-" , c = [0.8 , 0.8 , 0.8 ], zorder = 1 )
269
322
ax [0 ].plot (
270
- self .datapost .index , self .post_X , "-" , c = [0.8 , 0.8 , 0.8 ], zorder = 1
323
+ self .datapre .index ,
324
+ self .datapre_control ,
325
+ "-" ,
326
+ c = [0.8 , 0.8 , 0.8 ],
327
+ zorder = 1 ,
328
+ )
329
+ ax [0 ].plot (
330
+ self .datapost .index ,
331
+ self .datapost_control ,
332
+ "-" ,
333
+ c = [0.8 , 0.8 , 0.8 ],
334
+ zorder = 1 ,
271
335
)
272
336
273
337
return fig , ax
0 commit comments