@@ -298,9 +298,13 @@ def __init__(
298
298
self .x_pred_control = (
299
299
self .data
300
300
# just the untreated group
301
- .query (f"{ self .group_variable_name } == @self.untreated" ) # 🔥
301
+ .query (f"{ self .group_variable_name } == @self.untreated" )
302
302
# drop the outcome variable
303
303
.drop (self .outcome_variable_name , axis = 1 )
304
+ # We may have multiple units per time point, we only want one time point
305
+ .groupby (self .time_variable_name )
306
+ .first ()
307
+ .reset_index ()
304
308
)
305
309
assert not self .x_pred_control .empty
306
310
(new_x ,) = build_design_matrices ([self ._x_design_info ], self .x_pred_control )
@@ -310,9 +314,13 @@ def __init__(
310
314
self .x_pred_treatment = (
311
315
self .data
312
316
# just the treated group
313
- .query (f"{ self .group_variable_name } == @self.treated" ) # 🔥
317
+ .query (f"{ self .group_variable_name } == @self.treated" )
314
318
# drop the outcome variable
315
319
.drop (self .outcome_variable_name , axis = 1 )
320
+ # We may have multiple units per time point, we only want one time point
321
+ .groupby (self .time_variable_name )
322
+ .first ()
323
+ .reset_index ()
316
324
)
317
325
assert not self .x_pred_treatment .empty
318
326
(new_x ,) = build_design_matrices ([self ._x_design_info ], self .x_pred_treatment )
@@ -322,14 +330,17 @@ def __init__(
322
330
self .x_pred_counterfactual = (
323
331
self .data
324
332
# just the treated group
325
- .query (f"{ self .group_variable_name } == @self.treated" ) # 🔥
333
+ .query (f"{ self .group_variable_name } == @self.treated" )
326
334
# just the treatment period(s)
327
- # TODO: the line below might need some work to be more robust
328
335
.query ("post_treatment == True" )
329
336
# drop the outcome variable
330
337
.drop (self .outcome_variable_name , axis = 1 )
331
338
# DO AN INTERVENTION. Set the post_treatment variable to False
332
339
.assign (post_treatment = False )
340
+ # We may have multiple units per time point, we only want one time point
341
+ .groupby (self .time_variable_name )
342
+ .first ()
343
+ .reset_index ()
333
344
)
334
345
assert not self .x_pred_counterfactual .empty
335
346
(new_x ,) = build_design_matrices (
0 commit comments