Skip to content

Commit 7840611

Browse files
committed
#76 #44 DID now works generalises to custom varnames + level values
1 parent d348ee8 commit 7840611

File tree

3 files changed

+331
-34
lines changed

3 files changed

+331
-34
lines changed

causalpy/pymc_experiments.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,9 @@ def __init__(
209209
data: pd.DataFrame,
210210
formula: str,
211211
time_variable_name: str,
212+
group_variable_name: str,
213+
treated: str,
214+
untreated: str,
212215
prediction_model=None,
213216
**kwargs,
214217
):
@@ -217,13 +220,24 @@ def __init__(
217220
self.expt_type = "Difference in Differences"
218221
self.formula = formula
219222
self.time_variable_name = time_variable_name
223+
self.group_variable_name = group_variable_name
224+
self.treated = treated # level of the group_variable_name that was treated
225+
self.untreated = (
226+
untreated # level of the group_variable_name that was untreated
227+
)
220228
y, X = dmatrices(formula, self.data)
221229
self._y_design_info = y.design_info
222230
self._x_design_info = X.design_info
223231
self.labels = X.design_info.column_names
224232
self.y, self.X = np.asarray(y), np.asarray(X)
225233
self.outcome_variable_name = y.design_info.column_names[0]
226234

235+
assert (
236+
"treated" in formula
237+
), "A predictor column called `treated` should be in the provided dataframe"
238+
239+
# TODO: check that data in column self.group_variable_name has TWO levels
240+
227241
# TODO: `treated` is a deterministic function of group and time, so this should be a function rather than supplied data
228242

229243
# DEVIATION FROM SKL EXPERIMENT CODE =============================
@@ -232,23 +246,37 @@ def __init__(
232246
self.prediction_model.fit(X=self.X, y=self.y, coords=COORDS)
233247
# ================================================================
234248

249+
time_levels = self.data[self.time_variable_name].unique()
250+
235251
# predicted outcome for control group
236252
self.x_pred_control = pd.DataFrame(
237-
{"group": [0, 0], "t": [0.0, 1.0], "treated": [0, 0]}
253+
{
254+
self.group_variable_name: [self.untreated, self.untreated],
255+
self.time_variable_name: time_levels,
256+
"treated": [0, 0],
257+
}
238258
)
239259
(new_x,) = build_design_matrices([self._x_design_info], self.x_pred_control)
240260
self.y_pred_control = self.prediction_model.predict(np.asarray(new_x))
241261

242262
# predicted outcome for treatment group
243263
self.x_pred_treatment = pd.DataFrame(
244-
{"group": [1, 1], "t": [0.0, 1.0], "treated": [0, 1]}
264+
{
265+
self.group_variable_name: [self.treated, self.treated],
266+
self.time_variable_name: time_levels,
267+
"treated": [0, 1],
268+
}
245269
)
246270
(new_x,) = build_design_matrices([self._x_design_info], self.x_pred_treatment)
247271
self.y_pred_treatment = self.prediction_model.predict(np.asarray(new_x))
248272

249273
# predicted outcome for counterfactual
250274
self.x_pred_counterfactual = pd.DataFrame(
251-
{"group": [1], "t": [1.0], "treated": [0]}
275+
{
276+
self.group_variable_name: [self.treated],
277+
self.time_variable_name: time_levels[1],
278+
"treated": [0],
279+
}
252280
)
253281
(new_x,) = build_design_matrices(
254282
[self._x_design_info], self.x_pred_counterfactual
@@ -278,7 +306,7 @@ def plot(self):
278306
self.data,
279307
x=self.time_variable_name,
280308
y=self.outcome_variable_name,
281-
hue="group",
309+
hue=self.group_variable_name,
282310
units="unit",
283311
estimator=None,
284312
alpha=0.25,

docs/notebooks/did_pymc.ipynb

Lines changed: 25 additions & 13 deletions
Large diffs are not rendered by default.

docs/notebooks/did_pymc_banks.ipynb

Lines changed: 274 additions & 17 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)