Skip to content

Commit a0faff9

Browse files
committed
#76 DiD now 'works' for multiple pre/post treatment observations
1 parent 15fe0f7 commit a0faff9

File tree

2 files changed

+197
-80
lines changed

2 files changed

+197
-80
lines changed

causalpy/pymc_experiments.py

Lines changed: 26 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -265,11 +265,11 @@ def __init__(
265265
# Input validation ----------------------------------------------------
266266
# Check that `treated` appears in the module formula
267267
assert (
268-
"treated" in formula
269-
), "A predictor column called `treated` should be in the provided dataframe"
268+
"post_treatment" in formula
269+
), "A predictor called `post_treatment` should be in the dataframe"
270270
# Check that we have `treated` in the incoming dataframe
271271
assert (
272-
"treated" in self.data.columns
272+
"post_treatment" in self.data.columns
273273
), "Require a boolean column labelling observations which are `treated`"
274274
# Check for `unit` in the incoming dataframe.
275275
# *This is only used for plotting purposes*
@@ -289,46 +289,45 @@ def __init__(
289289
.I.e. the treated and untreated.
290290
"""
291291

292-
# TODO: `treated` is a deterministic function of group and time, so this could
293-
# be a function rather than supplied data
294-
295292
# DEVIATION FROM SKL EXPERIMENT CODE =============================
296-
# fit the model to the observed (pre-intervention) data
297293
COORDS = {"coeffs": self.labels, "obs_indx": np.arange(self.X.shape[0])}
298294
self.prediction_model.fit(X=self.X, y=self.y, coords=COORDS)
299295
# ================================================================
300296

301-
time_levels = self.data[self.time_variable_name].unique()
302-
303297
# predicted outcome for control group
304-
self.x_pred_control = pd.DataFrame(
305-
{
306-
self.group_variable_name: [self.untreated, self.untreated],
307-
self.time_variable_name: time_levels,
308-
"treated": [0, 0],
309-
}
298+
self.x_pred_control = (
299+
self.data
300+
# just the untreated group
301+
.query(f"district == '{self.untreated}'")
302+
# drop the outcome variable
303+
.drop(self.outcome_variable_name, axis=1)
310304
)
311305
(new_x,) = build_design_matrices([self._x_design_info], self.x_pred_control)
312306
self.y_pred_control = self.prediction_model.predict(np.asarray(new_x))
313307

314308
# predicted outcome for treatment group
315-
self.x_pred_treatment = pd.DataFrame(
316-
{
317-
self.group_variable_name: [self.treated, self.treated],
318-
self.time_variable_name: time_levels,
319-
"treated": [0, 1],
320-
}
309+
self.x_pred_treatment = (
310+
self.data
311+
# just the treated group
312+
.query(f"district == '{self.treated}'")
313+
# drop the outcome variable
314+
.drop(self.outcome_variable_name, axis=1)
321315
)
322316
(new_x,) = build_design_matrices([self._x_design_info], self.x_pred_treatment)
323317
self.y_pred_treatment = self.prediction_model.predict(np.asarray(new_x))
324318

325319
# predicted outcome for counterfactual
326-
self.x_pred_counterfactual = pd.DataFrame(
327-
{
328-
self.group_variable_name: [self.treated],
329-
self.time_variable_name: time_levels[1],
330-
"treated": [0],
331-
}
320+
self.x_pred_counterfactual = (
321+
self.data
322+
# just the treated group
323+
.query(f"district == '{self.treated}'")
324+
# just the treatment period(s)
325+
# TODO: the line below might need some work to be more robust
326+
.query("post_treatment == True")
327+
# drop the outcome variable
328+
.drop(self.outcome_variable_name, axis=1)
329+
# DO AN INTERVENTION. Set the post_treatment variable to False
330+
.assign(post_treatment=False)
332331
)
333332
(new_x,) = build_design_matrices(
334333
[self._x_design_info], self.x_pred_counterfactual
@@ -340,14 +339,6 @@ def __init__(
340339
self.y_pred_treatment["posterior_predictive"].mu.isel({"obs_ind": 1})
341340
- self.y_pred_counterfactual["posterior_predictive"].mu.squeeze()
342341
)
343-
# self.causal_impact = (
344-
# self.y_pred_treatment["posterior_predictive"]
345-
# .mu.isel({"obs_ind": 1})
346-
# .stack(samples=["chain", "draw"])
347-
# - self.y_pred_counterfactual["posterior_predictive"]
348-
# .mu.stack(samples=["chain", "draw"])
349-
# .squeeze()
350-
# )
351342

352343
def plot(self):
353344
"""Plot the results"""

docs/notebooks/did_pymc_banks.ipynb

Lines changed: 171 additions & 45 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)