@@ -254,8 +254,6 @@ def __init__(
254
254
formula : str ,
255
255
time_variable_name : str ,
256
256
group_variable_name : str ,
257
- treated : str ,
258
- untreated : str ,
259
257
model = None ,
260
258
** kwargs ,
261
259
):
@@ -265,10 +263,6 @@ def __init__(
265
263
self .formula = formula
266
264
self .time_variable_name = time_variable_name
267
265
self .group_variable_name = group_variable_name
268
- self .treated = treated # level of the group_variable_name that was treated
269
- self .untreated = (
270
- untreated # level of the group_variable_name that was untreated
271
- )
272
266
y , X = dmatrices (formula , self .data )
273
267
self ._y_design_info = y .design_info
274
268
self ._x_design_info = X .design_info
@@ -277,11 +271,9 @@ def __init__(
277
271
self .outcome_variable_name = y .design_info .column_names [0 ]
278
272
279
273
# Input validation ----------------------------------------------------
280
- # Check that `treated` appears in the module formula
281
274
assert (
282
275
"post_treatment" in formula
283
276
), "A predictor called `post_treatment` should be in the dataframe"
284
- # Check that we have `treated` in the incoming dataframe
285
277
assert (
286
278
"post_treatment" in self .data .columns
287
279
), "Require a boolean column labelling observations which are `treated`"
@@ -293,26 +285,22 @@ def __init__(
293
285
Require a `unit` column to label unique units.
294
286
This is used for plotting purposes
295
287
"""
296
- # Check that `group_variable_name` has TWO levels, representing the
297
- # treated/untreated. But it does not matter what the actual names of
298
- # the levels are.
299
- assert (
300
- len (pd .Categorical (self .data [self .group_variable_name ]).categories ) == 2
288
+ # Check that `group_variable_name` is dummy coded. It should be 0 or 1
289
+ assert not set (self .data [self .group_variable_name ]).difference (
290
+ set ([0 , 1 ])
301
291
), f"""
302
- There must be 2 levels of the grouping variable { self .group_variable_name }
303
- .I.e. the treated and untreated .
292
+ The grouping variable { self .group_variable_name } should be dummy coded.
293
+ Consisting of 0's and 1's only .
304
294
"""
305
295
306
- # DEVIATION FROM SKL EXPERIMENT CODE =============================
307
296
COORDS = {"coeffs" : self .labels , "obs_indx" : np .arange (self .X .shape [0 ])}
308
297
self .model .fit (X = self .X , y = self .y , coords = COORDS )
309
- # ================================================================
310
298
311
299
# predicted outcome for control group
312
300
self .x_pred_control = (
313
301
self .data
314
302
# just the untreated group
315
- .query (f"{ self .group_variable_name } == @self.untreated " )
303
+ .query (f"{ self .group_variable_name } == 0 " )
316
304
# drop the outcome variable
317
305
.drop (self .outcome_variable_name , axis = 1 )
318
306
# We may have multiple units per time point, we only want one time point
@@ -328,7 +316,7 @@ def __init__(
328
316
self .x_pred_treatment = (
329
317
self .data
330
318
# just the treated group
331
- .query (f"{ self .group_variable_name } == @self.treated " )
319
+ .query (f"{ self .group_variable_name } == 1 " )
332
320
# drop the outcome variable
333
321
.drop (self .outcome_variable_name , axis = 1 )
334
322
# We may have multiple units per time point, we only want one time point
@@ -345,7 +333,7 @@ def __init__(
345
333
self .x_pred_counterfactual = (
346
334
self .data
347
335
# just the treated group
348
- .query (f"{ self .group_variable_name } == @self.treated " )
336
+ .query (f"{ self .group_variable_name } == 1 " )
349
337
# just the treatment period(s)
350
338
.query ("post_treatment == True" )
351
339
# drop the outcome variable
0 commit comments