Skip to content

Commit e3bc8cd

Browse files
committed
#76 stop evaluating for multiple units per time point
1 parent 0ed7240 commit e3bc8cd

File tree

3 files changed

+106
-79
lines changed

3 files changed

+106
-79
lines changed

causalpy/pymc_experiments.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -298,9 +298,13 @@ def __init__(
298298
self.x_pred_control = (
299299
self.data
300300
# just the untreated group
301-
.query(f"{self.group_variable_name} == @self.untreated") # 🔥
301+
.query(f"{self.group_variable_name} == @self.untreated")
302302
# drop the outcome variable
303303
.drop(self.outcome_variable_name, axis=1)
304+
# We may have multiple units per time point; keep only one row per time point
305+
.groupby(self.time_variable_name)
306+
.first()
307+
.reset_index()
304308
)
305309
assert not self.x_pred_control.empty
306310
(new_x,) = build_design_matrices([self._x_design_info], self.x_pred_control)
@@ -310,9 +314,13 @@ def __init__(
310314
self.x_pred_treatment = (
311315
self.data
312316
# just the treated group
313-
.query(f"{self.group_variable_name} == @self.treated") # 🔥
317+
.query(f"{self.group_variable_name} == @self.treated")
314318
# drop the outcome variable
315319
.drop(self.outcome_variable_name, axis=1)
320+
# We may have multiple units per time point; keep only one row per time point
321+
.groupby(self.time_variable_name)
322+
.first()
323+
.reset_index()
316324
)
317325
assert not self.x_pred_treatment.empty
318326
(new_x,) = build_design_matrices([self._x_design_info], self.x_pred_treatment)
@@ -322,14 +330,17 @@ def __init__(
322330
self.x_pred_counterfactual = (
323331
self.data
324332
# just the treated group
325-
.query(f"{self.group_variable_name} == @self.treated") # 🔥
333+
.query(f"{self.group_variable_name} == @self.treated")
326334
# just the treatment period(s)
327-
# TODO: the line below might need some work to be more robust
328335
.query("post_treatment == True")
329336
# drop the outcome variable
330337
.drop(self.outcome_variable_name, axis=1)
331338
# DO AN INTERVENTION. Set the post_treatment variable to False
332339
.assign(post_treatment=False)
340+
# We may have multiple units per time point; keep only one row per time point
341+
.groupby(self.time_variable_name)
342+
.first()
343+
.reset_index()
333344
)
334345
assert not self.x_pred_counterfactual.empty
335346
(new_x,) = build_design_matrices(

docs/notebooks/did_pymc.ipynb

Lines changed: 21 additions & 23 deletions
Large diffs are not rendered by default.

docs/notebooks/did_pymc_banks.ipynb

Lines changed: 70 additions & 52 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)