@@ -209,6 +209,9 @@ def __init__(
209
209
data : pd .DataFrame ,
210
210
formula : str ,
211
211
time_variable_name : str ,
212
+ group_variable_name : str ,
213
+ treated : str ,
214
+ untreated : str ,
212
215
prediction_model = None ,
213
216
** kwargs ,
214
217
):
@@ -217,13 +220,24 @@ def __init__(
217
220
self .expt_type = "Difference in Differences"
218
221
self .formula = formula
219
222
self .time_variable_name = time_variable_name
223
+ self .group_variable_name = group_variable_name
224
+ self .treated = treated # level of the group_variable_name that was treated
225
+ self .untreated = (
226
+ untreated # level of the group_variable_name that was untreated
227
+ )
220
228
y , X = dmatrices (formula , self .data )
221
229
self ._y_design_info = y .design_info
222
230
self ._x_design_info = X .design_info
223
231
self .labels = X .design_info .column_names
224
232
self .y , self .X = np .asarray (y ), np .asarray (X )
225
233
self .outcome_variable_name = y .design_info .column_names [0 ]
226
234
235
+ assert (
236
+ "treated" in formula
237
+ ), "A predictor column called `treated` should be in the provided dataframe"
238
+
239
+ # TODO: check that data in column self.group_variable_name has TWO levels
240
+
227
241
# TODO: `treated` is a deterministic function of group and time, so this should be a function rather than supplied data
228
242
229
243
# DEVIATION FROM SKL EXPERIMENT CODE =============================
@@ -232,23 +246,37 @@ def __init__(
232
246
self .prediction_model .fit (X = self .X , y = self .y , coords = COORDS )
233
247
# ================================================================
234
248
249
+ time_levels = self .data [self .time_variable_name ].unique ()
250
+
235
251
# predicted outcome for control group
236
252
self .x_pred_control = pd .DataFrame (
237
- {"group" : [0 , 0 ], "t" : [0.0 , 1.0 ], "treated" : [0 , 0 ]}
253
+ {
254
+ self .group_variable_name : [self .untreated , self .untreated ],
255
+ self .time_variable_name : time_levels ,
256
+ "treated" : [0 , 0 ],
257
+ }
238
258
)
239
259
(new_x ,) = build_design_matrices ([self ._x_design_info ], self .x_pred_control )
240
260
self .y_pred_control = self .prediction_model .predict (np .asarray (new_x ))
241
261
242
262
# predicted outcome for treatment group
243
263
self .x_pred_treatment = pd .DataFrame (
244
- {"group" : [1 , 1 ], "t" : [0.0 , 1.0 ], "treated" : [0 , 1 ]}
264
+ {
265
+ self .group_variable_name : [self .treated , self .treated ],
266
+ self .time_variable_name : time_levels ,
267
+ "treated" : [0 , 1 ],
268
+ }
245
269
)
246
270
(new_x ,) = build_design_matrices ([self ._x_design_info ], self .x_pred_treatment )
247
271
self .y_pred_treatment = self .prediction_model .predict (np .asarray (new_x ))
248
272
249
273
# predicted outcome for counterfactual
250
274
self .x_pred_counterfactual = pd .DataFrame (
251
- {"group" : [1 ], "t" : [1.0 ], "treated" : [0 ]}
275
+ {
276
+ self .group_variable_name : [self .treated ],
277
+ self .time_variable_name : time_levels [1 ],
278
+ "treated" : [0 ],
279
+ }
252
280
)
253
281
(new_x ,) = build_design_matrices (
254
282
[self ._x_design_info ], self .x_pred_counterfactual
@@ -278,7 +306,7 @@ def plot(self):
278
306
self .data ,
279
307
x = self .time_variable_name ,
280
308
y = self .outcome_variable_name ,
281
- hue = "group" ,
309
+ hue = self . group_variable_name ,
282
310
units = "unit" ,
283
311
estimator = None ,
284
312
alpha = 0.25 ,
0 commit comments