Skip to content

Commit 16f925b

Browse files
authored
Merge pull request #151 from pymc-labs/remove-unecessary-kwargs
remove `treated` and `untreated` kwargs for DifferenceInDifferences
2 parents 505cfe0 + 4912871 commit 16f925b

File tree

4 files changed

+8
-32
lines changed

4 files changed

+8
-32
lines changed

causalpy/pymc_experiments.py

Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -254,8 +254,6 @@ def __init__(
254254
formula: str,
255255
time_variable_name: str,
256256
group_variable_name: str,
257-
treated: str,
258-
untreated: str,
259257
model=None,
260258
**kwargs,
261259
):
@@ -265,10 +263,6 @@ def __init__(
265263
self.formula = formula
266264
self.time_variable_name = time_variable_name
267265
self.group_variable_name = group_variable_name
268-
self.treated = treated # level of the group_variable_name that was treated
269-
self.untreated = (
270-
untreated # level of the group_variable_name that was untreated
271-
)
272266
y, X = dmatrices(formula, self.data)
273267
self._y_design_info = y.design_info
274268
self._x_design_info = X.design_info
@@ -277,11 +271,9 @@ def __init__(
277271
self.outcome_variable_name = y.design_info.column_names[0]
278272

279273
# Input validation ----------------------------------------------------
280-
# Check that `treated` appears in the module formula
281274
assert (
282275
"post_treatment" in formula
283276
), "A predictor called `post_treatment` should be in the dataframe"
284-
# Check that we have `treated` in the incoming dataframe
285277
assert (
286278
"post_treatment" in self.data.columns
287279
), "Require a boolean column labelling observations which are `treated`"
@@ -293,26 +285,22 @@ def __init__(
293285
Require a `unit` column to label unique units.
294286
This is used for plotting purposes
295287
"""
296-
# Check that `group_variable_name` has TWO levels, representing the
297-
# treated/untreated. But it does not matter what the actual names of
298-
# the levels are.
299-
assert (
300-
len(pd.Categorical(self.data[self.group_variable_name]).categories) == 2
288+
# Check that `group_variable_name` is dummy coded. It should be 0 or 1
289+
assert not set(self.data[self.group_variable_name]).difference(
290+
set([0, 1])
301291
), f"""
302-
There must be 2 levels of the grouping variable {self.group_variable_name}
303-
.I.e. the treated and untreated.
292+
The grouping variable {self.group_variable_name} should be dummy coded.
293+
Consisting of 0's and 1's only.
304294
"""
305295

306-
# DEVIATION FROM SKL EXPERIMENT CODE =============================
307296
COORDS = {"coeffs": self.labels, "obs_indx": np.arange(self.X.shape[0])}
308297
self.model.fit(X=self.X, y=self.y, coords=COORDS)
309-
# ================================================================
310298

311299
# predicted outcome for control group
312300
self.x_pred_control = (
313301
self.data
314302
# just the untreated group
315-
.query(f"{self.group_variable_name} == @self.untreated")
303+
.query(f"{self.group_variable_name} == 0")
316304
# drop the outcome variable
317305
.drop(self.outcome_variable_name, axis=1)
318306
# We may have multiple units per time point, we only want one time point
@@ -328,7 +316,7 @@ def __init__(
328316
self.x_pred_treatment = (
329317
self.data
330318
# just the treated group
331-
.query(f"{self.group_variable_name} == @self.treated")
319+
.query(f"{self.group_variable_name} == 1")
332320
# drop the outcome variable
333321
.drop(self.outcome_variable_name, axis=1)
334322
# We may have multiple units per time point, we only want one time point
@@ -345,7 +333,7 @@ def __init__(
345333
self.x_pred_counterfactual = (
346334
self.data
347335
# just the treated group
348-
.query(f"{self.group_variable_name} == @self.treated")
336+
.query(f"{self.group_variable_name} == 1")
349337
# just the treatment period(s)
350338
.query("post_treatment == True")
351339
# drop the outcome variable

causalpy/tests/test_integration_pymc_examples.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@ def test_did():
1414
formula="y ~ 1 + group*post_treatment",
1515
time_variable_name="t",
1616
group_variable_name="group",
17-
treated=1,
18-
untreated=0,
1917
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
2018
)
2119
assert isinstance(df, pd.DataFrame)
@@ -59,8 +57,6 @@ def test_did_banks_simple():
5957
formula="bib ~ 1 + district * post_treatment",
6058
time_variable_name="year",
6159
group_variable_name="district",
62-
treated=1,
63-
untreated=0,
6460
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
6561
)
6662
assert isinstance(df, pd.DataFrame)
@@ -100,8 +96,6 @@ def test_did_banks_multi():
10096
formula="bib ~ 1 + year + district + post_treatment + district:post_treatment",
10197
time_variable_name="year",
10298
group_variable_name="district",
103-
treated=1,
104-
untreated=0,
10599
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
106100
)
107101
assert isinstance(df, pd.DataFrame)

docs/notebooks/did_pymc.ipynb

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,8 +239,6 @@
239239
" formula=\"y ~ 1 + group*post_treatment\",\n",
240240
" time_variable_name=\"t\",\n",
241241
" group_variable_name=\"group\",\n",
242-
" treated=1,\n",
243-
" untreated=0,\n",
244242
" model=cp.pymc_models.LinearRegression(sample_kwargs={\"random_seed\": seed}),\n",
245243
")"
246244
]

docs/notebooks/did_pymc_banks.ipynb

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -482,8 +482,6 @@
482482
" formula=\"bib ~ 1 + district * post_treatment\",\n",
483483
" time_variable_name=\"year\",\n",
484484
" group_variable_name=\"district\",\n",
485-
" treated=1,\n",
486-
" untreated=0,\n",
487485
" model=cp.pymc_models.LinearRegression(\n",
488486
" sample_kwargs={\"target_accept\": 0.95, \"random_seed\": seed}\n",
489487
" ),\n",
@@ -647,8 +645,6 @@
647645
" formula=\"bib ~ 1 + year + district + post_treatment + district:post_treatment\",\n",
648646
" time_variable_name=\"year\",\n",
649647
" group_variable_name=\"district\",\n",
650-
" treated=1,\n",
651-
" untreated=0,\n",
652648
" model=cp.pymc_models.LinearRegression(sample_kwargs={\"random_seed\": seed}),\n",
653649
")"
654650
]

0 commit comments

Comments
 (0)