Skip to content

Commit e882ad2

Browse files
committed
fix incorrect calculation of counterfactual for DiD
1 parent 55247c9 commit e882ad2

File tree

3 files changed

+177
-99
lines changed

3 files changed

+177
-99
lines changed

causalpy/pymc_experiments.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,8 @@ def __init__(
340340
(new_x,) = build_design_matrices([self._x_design_info], self.x_pred_treatment)
341341
self.y_pred_treatment = self.model.predict(np.asarray(new_x))
342342

343-
# predicted outcome for counterfactual
343+
# predicted outcome for counterfactual. This is given by removing the influence
344+
# of the interaction term between the group and the post_treatment variable
344345
self.x_pred_counterfactual = (
345346
self.data
346347
# just the treated group
@@ -349,24 +350,28 @@ def __init__(
349350
.query("post_treatment == True")
350351
# drop the outcome variable
351352
.drop(self.outcome_variable_name, axis=1)
352-
# DO AN INTERVENTION. Set the post_treatment variable to False
353-
.assign(post_treatment=False)
354353
# We may have multiple units per time point, we only want one time point
355354
.groupby(self.time_variable_name)
356355
.first()
357356
.reset_index()
358357
)
359358
assert not self.x_pred_counterfactual.empty
360359
(new_x,) = build_design_matrices(
361-
[self._x_design_info], self.x_pred_counterfactual
360+
[self._x_design_info], self.x_pred_counterfactual, return_type="dataframe"
362361
)
362+
# INTERVENTION: set the interaction term between the group and the
363+
# post_treatment variable to zero. This is the counterfactual.
364+
for i, label in enumerate(self.labels):
365+
if "post_treatment" in label and self.group_variable_name in label:
366+
new_x.iloc[:, i] = 0
363367
self.y_pred_counterfactual = self.model.predict(np.asarray(new_x))
364368

365-
# calculate causal impact
366-
self.causal_impact = (
367-
self.y_pred_treatment["posterior_predictive"].mu.isel({"obs_ind": 1})
368-
- self.y_pred_counterfactual["posterior_predictive"].mu.squeeze()
369-
)
369+
# calculate causal impact.
370+
# This is the coefficient on the interaction term
371+
coeff_names = self.idata.posterior.coords["coeffs"].data
372+
for i, label in enumerate(coeff_names):
373+
if "post_treatment" in label and self.group_variable_name in label:
374+
self.causal_impact = self.idata.posterior["beta"].isel({"coeffs": i})
370375

371376
def plot(self):
372377
"""Plot the results.

docs/notebooks/did_pymc.ipynb

Lines changed: 12 additions & 17 deletions
Large diffs are not rendered by default.

docs/notebooks/did_pymc_banks.ipynb

Lines changed: 151 additions & 73 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)