Skip to content

Commit 0ab2fcf

Browse files
committed
#76 DiD tests now pass
1 parent a0faff9 commit 0ab2fcf

File tree

7 files changed

+294
-153
lines changed

7 files changed

+294
-153
lines changed

causalpy/data/did.csv

Lines changed: 41 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,41 @@
1-
group,t,unit,treated,y
2-
0,0.0,0,0,1.037235444367556
3-
0,1.0,0,0,2.1803326054240513
4-
1,0.0,1,0,1.1815211596102946
5-
1,1.0,1,1,2.5731948057471734
6-
0,0.0,2,0,1.237781412485492
7-
0,1.0,2,0,2.064583683807223
8-
1,0.0,3,0,1.186896528144606
9-
1,1.0,3,1,2.7215532618312173
10-
0,0.0,4,0,1.0649519874697355
11-
0,1.0,4,0,1.9612022680643093
12-
1,0.0,5,0,1.2657299075634194
13-
1,1.0,5,1,2.5508204631468674
14-
0,0.0,6,0,0.8947560664459198
15-
0,1.0,6,0,2.227724135358723
16-
1,0.0,7,0,1.3074586207263057
17-
1,1.0,7,1,2.6021177943564844
18-
0,0.0,8,0,1.1845042721745236
19-
0,1.0,8,0,2.1371357945762255
20-
1,0.0,9,0,1.277659512523703
21-
1,1.0,9,1,2.7971363729134455
22-
0,0.0,10,0,0.948046520978673
23-
0,1.0,10,0,1.9911586181231065
24-
1,0.0,11,0,1.2956793345692803
25-
1,1.0,11,1,2.714212580309264
26-
0,0.0,12,0,1.0840699944593897
27-
0,1.0,12,0,1.9949161598698812
28-
1,0.0,13,0,1.279213688044527
29-
1,1.0,13,1,2.781563007268219
30-
0,0.0,14,0,0.9987011891791635
31-
0,1.0,14,0,1.8914366349764102
32-
1,0.0,15,0,1.2112578927664674
33-
1,1.0,15,1,2.7420363802422196
34-
0,0.0,16,0,0.993752136853551
35-
0,1.0,16,0,2.272692180324228
36-
1,0.0,17,0,1.1786513493076058
37-
1,1.0,17,1,2.69965381847017
38-
0,0.0,18,0,1.0980883419399656
39-
0,1.0,18,0,1.9685015295514094
40-
1,0.0,19,0,1.3616585803269048
41-
1,1.0,19,1,2.591156615919988
1+
group,t,unit,post_treatment,y
2+
0,0.0,0,False,0.897122432901507
3+
0,1.0,0,True,1.9612135788421983
4+
1,0.0,1,False,1.2335249009813691
5+
1,1.0,1,True,2.7527941327437286
6+
0,0.0,2,False,1.149207391077308
7+
0,1.0,2,True,1.9107194958946412
8+
1,0.0,3,False,1.2096028435304764
9+
1,1.0,3,True,2.7870530562317772
10+
0,0.0,4,False,1.0182211686591378
11+
0,1.0,4,True,2.1355782741951903
12+
1,0.0,5,False,1.2566023467285772
13+
1,1.0,5,True,2.6352164140993417
14+
0,0.0,6,False,1.1206312917156163
15+
0,1.0,6,True,2.0293786635661104
16+
1,0.0,7,False,1.2253914316635341
17+
1,1.0,7,True,2.836234979171606
18+
0,0.0,8,False,1.0937901142584816
19+
0,1.0,8,True,2.0046646527573992
20+
1,0.0,9,False,1.1311676279399658
21+
1,1.0,9,True,2.597416938762001
22+
0,0.0,10,False,1.1338268148431594
23+
0,1.0,10,True,2.0396150424632604
24+
1,0.0,11,False,1.2769574784336464
25+
1,1.0,11,True,2.7237901979669057
26+
0,0.0,12,False,1.0548219817786735
27+
0,1.0,12,True,2.0966644540989554
28+
1,0.0,13,False,1.2941834769826859
29+
1,1.0,13,True,2.828746461772019
30+
0,0.0,14,False,1.0011352011986534
31+
0,1.0,14,True,2.2367233120727237
32+
1,0.0,15,False,1.2621457689408864
33+
1,1.0,15,True,2.737756363134591
34+
0,0.0,16,False,1.0613566957247114
35+
0,1.0,16,True,2.105012700050028
36+
1,0.0,17,False,1.228130146156384
37+
1,1.0,17,True,2.6887857541638813
38+
0,0.0,18,False,1.2259823349004162
39+
0,1.0,18,True,2.097530059810398
40+
1,0.0,19,False,1.263074342393256
41+
1,1.0,19,True,2.697326984058356

causalpy/data/simulate_data.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -154,17 +154,16 @@ def generate_did():
154154
intervention_time = 0.5
155155

156156
# local functions
157-
def outcome(t, control_intercept, treat_intercept_delta, trend, Δ, group, treated):
157+
def outcome(
158+
t, control_intercept, treat_intercept_delta, trend, Δ, group, post_treatment
159+
):
158160
return (
159161
control_intercept
160162
+ (treat_intercept_delta * group)
161163
+ (t * trend)
162-
+ (Δ * treated * group)
164+
+ (Δ * post_treatment * group)
163165
)
164166

165-
def _is_treated(t, intervention_time, group):
166-
return (t > intervention_time) * group
167-
168167
df = pd.DataFrame(
169168
{
170169
"group": [0, 0, 1, 1] * 10,
@@ -173,7 +172,7 @@ def _is_treated(t, intervention_time, group):
173172
}
174173
)
175174

176-
df["treated"] = _is_treated(df["t"], intervention_time, df["group"])
175+
df["post_treatment"] = df["t"] > intervention_time
177176

178177
df["y"] = outcome(
179178
df["t"],
@@ -182,7 +181,7 @@ def _is_treated(t, intervention_time, group):
182181
trend,
183182
Δ,
184183
df["group"],
185-
df["treated"],
184+
df["post_treatment"],
186185
)
187186
df["y"] += rng.normal(0, 0.1, df.shape[0])
188187
return df

causalpy/pymc_experiments.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -298,29 +298,31 @@ def __init__(
298298
self.x_pred_control = (
299299
self.data
300300
# just the untreated group
301-
.query(f"district == '{self.untreated}'")
301+
.query(f"{self.group_variable_name} == @self.untreated") # 🔥
302302
# drop the outcome variable
303303
.drop(self.outcome_variable_name, axis=1)
304304
)
305+
assert not self.x_pred_control.empty
305306
(new_x,) = build_design_matrices([self._x_design_info], self.x_pred_control)
306307
self.y_pred_control = self.prediction_model.predict(np.asarray(new_x))
307308

308309
# predicted outcome for treatment group
309310
self.x_pred_treatment = (
310311
self.data
311312
# just the treated group
312-
.query(f"district == '{self.treated}'")
313+
.query(f"{self.group_variable_name} == @self.treated") # 🔥
313314
# drop the outcome variable
314315
.drop(self.outcome_variable_name, axis=1)
315316
)
317+
assert not self.x_pred_treatment.empty
316318
(new_x,) = build_design_matrices([self._x_design_info], self.x_pred_treatment)
317319
self.y_pred_treatment = self.prediction_model.predict(np.asarray(new_x))
318320

319321
# predicted outcome for counterfactual
320322
self.x_pred_counterfactual = (
321323
self.data
322324
# just the treated group
323-
.query(f"district == '{self.treated}'")
325+
.query(f"{self.group_variable_name} == @self.treated") # 🔥
324326
# just the treatment period(s)
325327
# TODO: the line below might need some work to be more robust
326328
.query("post_treatment == True")
@@ -329,6 +331,7 @@ def __init__(
329331
# DO AN INTERVENTION. Set the post_treatment variable to False
330332
.assign(post_treatment=False)
331333
)
334+
assert not self.x_pred_counterfactual.empty
332335
(new_x,) = build_design_matrices(
333336
[self._x_design_info], self.x_pred_counterfactual
334337
)

causalpy/tests/test_integration_pymc_examples.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def test_did():
1111
df = cp.load_data("did")
1212
result = cp.pymc_experiments.DifferenceInDifferences(
1313
df,
14-
formula="y ~ 1 + group + t + treated:group",
14+
formula="y ~ 1 + group + t + group:post_treatment",
1515
time_variable_name="t",
1616
group_variable_name="group",
1717
treated=1,
@@ -26,6 +26,7 @@ def test_did():
2626

2727
@pytest.mark.integration
2828
def test_did_banks():
29+
treatment_time = 1930.5
2930
df = (
3031
cp.load_data("banks")
3132
.filter(items=["bib6", "bib8", "year"])
@@ -43,10 +44,10 @@ def test_did_banks():
4344
).sort_values("year")
4445
df_long["district"] = df_long["district"].astype("category")
4546
df_long["unit"] = df_long["district"]
46-
df_long["treated"] = (df_long.year >= 1931) & (df_long.district == "Sixth District")
47+
df_long["post_treatment"] = df_long.year >= treatment_time
4748
result = cp.pymc_experiments.DifferenceInDifferences(
4849
df_long[df_long.year.isin([1930, 1931])],
49-
formula="bib ~ 1 + district + year + district:treated",
50+
formula="bib ~ 1 + district + year + district:post_treatment",
5051
time_variable_name="year",
5152
group_variable_name="district",
5253
treated="Sixth District",

docs/notebooks/did_pymc.ipynb

Lines changed: 46 additions & 35 deletions
Large diffs are not rendered by default.

docs/notebooks/did_pymc_banks.ipynb

Lines changed: 188 additions & 61 deletions
Large diffs are not rendered by default.

img/interrogate_badge.svg

Lines changed: 3 additions & 3 deletions
Loading

0 commit comments

Comments
 (0)