Skip to content

Commit 5aead6b

Browse files
authored
Merge pull request #140 from pymc-labs/did_multiple_observations
Did multiple observations
2 parents 15fe0f7 + 36b4511 commit 5aead6b

10 files changed

+565
-258
lines changed

causalpy/data/did.csv

Lines changed: 41 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,41 @@
1-
group,t,unit,treated,y
2-
0,0.0,0,0,1.037235444367556
3-
0,1.0,0,0,2.1803326054240513
4-
1,0.0,1,0,1.1815211596102946
5-
1,1.0,1,1,2.5731948057471734
6-
0,0.0,2,0,1.237781412485492
7-
0,1.0,2,0,2.064583683807223
8-
1,0.0,3,0,1.186896528144606
9-
1,1.0,3,1,2.7215532618312173
10-
0,0.0,4,0,1.0649519874697355
11-
0,1.0,4,0,1.9612022680643093
12-
1,0.0,5,0,1.2657299075634194
13-
1,1.0,5,1,2.5508204631468674
14-
0,0.0,6,0,0.8947560664459198
15-
0,1.0,6,0,2.227724135358723
16-
1,0.0,7,0,1.3074586207263057
17-
1,1.0,7,1,2.6021177943564844
18-
0,0.0,8,0,1.1845042721745236
19-
0,1.0,8,0,2.1371357945762255
20-
1,0.0,9,0,1.277659512523703
21-
1,1.0,9,1,2.7971363729134455
22-
0,0.0,10,0,0.948046520978673
23-
0,1.0,10,0,1.9911586181231065
24-
1,0.0,11,0,1.2956793345692803
25-
1,1.0,11,1,2.714212580309264
26-
0,0.0,12,0,1.0840699944593897
27-
0,1.0,12,0,1.9949161598698812
28-
1,0.0,13,0,1.279213688044527
29-
1,1.0,13,1,2.781563007268219
30-
0,0.0,14,0,0.9987011891791635
31-
0,1.0,14,0,1.8914366349764102
32-
1,0.0,15,0,1.2112578927664674
33-
1,1.0,15,1,2.7420363802422196
34-
0,0.0,16,0,0.993752136853551
35-
0,1.0,16,0,2.272692180324228
36-
1,0.0,17,0,1.1786513493076058
37-
1,1.0,17,1,2.69965381847017
38-
0,0.0,18,0,1.0980883419399656
39-
0,1.0,18,0,1.9685015295514094
40-
1,0.0,19,0,1.3616585803269048
41-
1,1.0,19,1,2.591156615919988
1+
group,t,unit,post_treatment,y
2+
0,0.0,0,False,0.897122432901507
3+
0,1.0,0,True,1.9612135788421983
4+
1,0.0,1,False,1.2335249009813691
5+
1,1.0,1,True,2.7527941327437286
6+
0,0.0,2,False,1.149207391077308
7+
0,1.0,2,True,1.9107194958946412
8+
1,0.0,3,False,1.2096028435304764
9+
1,1.0,3,True,2.7870530562317772
10+
0,0.0,4,False,1.0182211686591378
11+
0,1.0,4,True,2.1355782741951903
12+
1,0.0,5,False,1.2566023467285772
13+
1,1.0,5,True,2.6352164140993417
14+
0,0.0,6,False,1.1206312917156163
15+
0,1.0,6,True,2.0293786635661104
16+
1,0.0,7,False,1.2253914316635341
17+
1,1.0,7,True,2.836234979171606
18+
0,0.0,8,False,1.0937901142584816
19+
0,1.0,8,True,2.0046646527573992
20+
1,0.0,9,False,1.1311676279399658
21+
1,1.0,9,True,2.597416938762001
22+
0,0.0,10,False,1.1338268148431594
23+
0,1.0,10,True,2.0396150424632604
24+
1,0.0,11,False,1.2769574784336464
25+
1,1.0,11,True,2.7237901979669057
26+
0,0.0,12,False,1.0548219817786735
27+
0,1.0,12,True,2.0966644540989554
28+
1,0.0,13,False,1.2941834769826859
29+
1,1.0,13,True,2.828746461772019
30+
0,0.0,14,False,1.0011352011986534
31+
0,1.0,14,True,2.2367233120727237
32+
1,0.0,15,False,1.2621457689408864
33+
1,1.0,15,True,2.737756363134591
34+
0,0.0,16,False,1.0613566957247114
35+
0,1.0,16,True,2.105012700050028
36+
1,0.0,17,False,1.228130146156384
37+
1,1.0,17,True,2.6887857541638813
38+
0,0.0,18,False,1.2259823349004162
39+
0,1.0,18,True,2.097530059810398
40+
1,0.0,19,False,1.263074342393256
41+
1,1.0,19,True,2.697326984058356

causalpy/data/simulate_data.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -154,17 +154,16 @@ def generate_did():
154154
intervention_time = 0.5
155155

156156
# local functions
157-
def outcome(t, control_intercept, treat_intercept_delta, trend, Δ, group, treated):
157+
def outcome(
158+
t, control_intercept, treat_intercept_delta, trend, Δ, group, post_treatment
159+
):
158160
return (
159161
control_intercept
160162
+ (treat_intercept_delta * group)
161163
+ (t * trend)
162-
+ (Δ * treated * group)
164+
+ (Δ * post_treatment * group)
163165
)
164166

165-
def _is_treated(t, intervention_time, group):
166-
return (t > intervention_time) * group
167-
168167
df = pd.DataFrame(
169168
{
170169
"group": [0, 0, 1, 1] * 10,
@@ -173,7 +172,7 @@ def _is_treated(t, intervention_time, group):
173172
}
174173
)
175174

176-
df["treated"] = _is_treated(df["t"], intervention_time, df["group"])
175+
df["post_treatment"] = df["t"] > intervention_time
177176

178177
df["y"] = outcome(
179178
df["t"],
@@ -182,7 +181,7 @@ def _is_treated(t, intervention_time, group):
182181
trend,
183182
Δ,
184183
df["group"],
185-
df["treated"],
184+
df["post_treatment"],
186185
)
187186
df["y"] += rng.normal(0, 0.1, df.shape[0])
188187
return df

causalpy/pymc_experiments.py

Lines changed: 84 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -265,11 +265,11 @@ def __init__(
265265
# Input validation ----------------------------------------------------
266266
# Check that `treated` appears in the module formula
267267
assert (
268-
"treated" in formula
269-
), "A predictor column called `treated` should be in the provided dataframe"
268+
"post_treatment" in formula
269+
), "A predictor called `post_treatment` should be in the dataframe"
270270
# Check that we have `treated` in the incoming dataframe
271271
assert (
272-
"treated" in self.data.columns
272+
"post_treatment" in self.data.columns
273273
), "Require a boolean column labelling observations which are `treated`"
274274
# Check for `unit` in the incoming dataframe.
275275
# *This is only used for plotting purposes*
@@ -289,47 +289,60 @@ def __init__(
289289
.I.e. the treated and untreated.
290290
"""
291291

292-
# TODO: `treated` is a deterministic function of group and time, so this could
293-
# be a function rather than supplied data
294-
295292
# DEVIATION FROM SKL EXPERIMENT CODE =============================
296-
# fit the model to the observed (pre-intervention) data
297293
COORDS = {"coeffs": self.labels, "obs_indx": np.arange(self.X.shape[0])}
298294
self.prediction_model.fit(X=self.X, y=self.y, coords=COORDS)
299295
# ================================================================
300296

301-
time_levels = self.data[self.time_variable_name].unique()
302-
303297
# predicted outcome for control group
304-
self.x_pred_control = pd.DataFrame(
305-
{
306-
self.group_variable_name: [self.untreated, self.untreated],
307-
self.time_variable_name: time_levels,
308-
"treated": [0, 0],
309-
}
310-
)
298+
self.x_pred_control = (
299+
self.data
300+
# just the untreated group
301+
.query(f"{self.group_variable_name} == @self.untreated")
302+
# drop the outcome variable
303+
.drop(self.outcome_variable_name, axis=1)
304+
# We may have multiple units per time point, we only want one time point
305+
.groupby(self.time_variable_name)
306+
.first()
307+
.reset_index()
308+
)
309+
assert not self.x_pred_control.empty
311310
(new_x,) = build_design_matrices([self._x_design_info], self.x_pred_control)
312311
self.y_pred_control = self.prediction_model.predict(np.asarray(new_x))
313312

314313
# predicted outcome for treatment group
315-
self.x_pred_treatment = pd.DataFrame(
316-
{
317-
self.group_variable_name: [self.treated, self.treated],
318-
self.time_variable_name: time_levels,
319-
"treated": [0, 1],
320-
}
321-
)
314+
self.x_pred_treatment = (
315+
self.data
316+
# just the treated group
317+
.query(f"{self.group_variable_name} == @self.treated")
318+
# drop the outcome variable
319+
.drop(self.outcome_variable_name, axis=1)
320+
# We may have multiple units per time point, we only want one time point
321+
.groupby(self.time_variable_name)
322+
.first()
323+
.reset_index()
324+
)
325+
assert not self.x_pred_treatment.empty
322326
(new_x,) = build_design_matrices([self._x_design_info], self.x_pred_treatment)
323327
self.y_pred_treatment = self.prediction_model.predict(np.asarray(new_x))
324328

325329
# predicted outcome for counterfactual
326-
self.x_pred_counterfactual = pd.DataFrame(
327-
{
328-
self.group_variable_name: [self.treated],
329-
self.time_variable_name: time_levels[1],
330-
"treated": [0],
331-
}
332-
)
330+
self.x_pred_counterfactual = (
331+
self.data
332+
# just the treated group
333+
.query(f"{self.group_variable_name} == @self.treated")
334+
# just the treatment period(s)
335+
.query("post_treatment == True")
336+
# drop the outcome variable
337+
.drop(self.outcome_variable_name, axis=1)
338+
# DO AN INTERVENTION. Set the post_treatment variable to False
339+
.assign(post_treatment=False)
340+
# We may have multiple units per time point, we only want one time point
341+
.groupby(self.time_variable_name)
342+
.first()
343+
.reset_index()
344+
)
345+
assert not self.x_pred_counterfactual.empty
333346
(new_x,) = build_design_matrices(
334347
[self._x_design_info], self.x_pred_counterfactual
335348
)
@@ -340,14 +353,6 @@ def __init__(
340353
self.y_pred_treatment["posterior_predictive"].mu.isel({"obs_ind": 1})
341354
- self.y_pred_counterfactual["posterior_predictive"].mu.squeeze()
342355
)
343-
# self.causal_impact = (
344-
# self.y_pred_treatment["posterior_predictive"]
345-
# .mu.isel({"obs_ind": 1})
346-
# .stack(samples=["chain", "draw"])
347-
# - self.y_pred_counterfactual["posterior_predictive"]
348-
# .mu.stack(samples=["chain", "draw"])
349-
# .squeeze()
350-
# )
351356

352357
def plot(self):
353358
"""Plot the results"""
@@ -365,53 +370,52 @@ def plot(self):
365370
alpha=0.5,
366371
ax=ax,
367372
)
373+
368374
# Plot model fit to control group
369-
parts = ax.violinplot(
370-
az.extract(
371-
self.y_pred_control, group="posterior_predictive", var_names="mu"
372-
).values.T,
373-
positions=self.x_pred_control[self.time_variable_name].values,
374-
showmeans=False,
375-
showmedians=False,
376-
widths=0.2,
377-
)
378-
for pc in parts["bodies"]:
379-
pc.set_facecolor("C0")
380-
pc.set_edgecolor("None")
381-
pc.set_alpha(0.5)
375+
time_points = self.x_pred_control[self.time_variable_name].values
376+
plot_xY(
377+
time_points,
378+
self.y_pred_control.posterior_predictive.y_hat,
379+
ax=ax,
380+
plot_hdi_kwargs={"color": "C0"},
381+
)
382382

383383
# Plot model fit to treatment group
384-
parts = ax.violinplot(
385-
az.extract(
386-
self.y_pred_treatment, group="posterior_predictive", var_names="mu"
387-
).values.T,
388-
positions=self.x_pred_treatment[self.time_variable_name].values,
389-
showmeans=False,
390-
showmedians=False,
391-
widths=0.2,
392-
)
393-
394-
for pc in parts["bodies"]:
395-
pc.set_facecolor("C1")
396-
pc.set_edgecolor("None")
397-
pc.set_alpha(0.5)
384+
time_points = self.x_pred_control[self.time_variable_name].values
385+
plot_xY(
386+
time_points,
387+
self.y_pred_treatment.posterior_predictive.y_hat,
388+
ax=ax,
389+
plot_hdi_kwargs={"color": "C1"},
390+
)
391+
398392
# Plot counterfactual - post-test for treatment group IF no treatment
399393
# had occurred.
400-
parts = ax.violinplot(
401-
az.extract(
402-
self.y_pred_counterfactual,
403-
group="posterior_predictive",
404-
var_names="mu",
405-
).values.T,
406-
positions=self.x_pred_counterfactual[self.time_variable_name].values,
407-
showmeans=False,
408-
showmedians=False,
409-
widths=0.2,
410-
)
411-
for pc in parts["bodies"]:
412-
pc.set_facecolor("C2")
413-
pc.set_edgecolor("None")
414-
pc.set_alpha(0.5)
394+
time_points = self.x_pred_counterfactual[self.time_variable_name].values
395+
if len(time_points) == 1:
396+
parts = ax.violinplot(
397+
az.extract(
398+
self.y_pred_counterfactual,
399+
group="posterior_predictive",
400+
var_names="mu",
401+
).values.T,
402+
positions=self.x_pred_counterfactual[self.time_variable_name].values,
403+
showmeans=False,
404+
showmedians=False,
405+
widths=0.2,
406+
)
407+
for pc in parts["bodies"]:
408+
pc.set_facecolor("C2")
409+
pc.set_edgecolor("None")
410+
pc.set_alpha(0.5)
411+
else:
412+
plot_xY(
413+
time_points,
414+
self.y_pred_counterfactual.posterior_predictive.y_hat,
415+
ax=ax,
416+
plot_hdi_kwargs={"color": "C2"},
417+
)
418+
415419
# arrow to label the causal impact
416420
self._plot_causal_impact_arrow(ax)
417421
# formatting

causalpy/skl_experiments.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -190,30 +190,30 @@ def __init__(
190190
self.y, self.X = np.asarray(y), np.asarray(X)
191191
self.outcome_variable_name = y.design_info.column_names[0]
192192

193-
# TODO: `treated` is a deterministic function of group and time, so this should
194-
# be a function rather than supplied data
195-
196193
# fit the model to all the data
197194
self.prediction_model.fit(X=self.X, y=self.y)
198195

199196
# predicted outcome for control group
200197
self.x_pred_control = pd.DataFrame(
201-
{"group": [0, 0], "t": [0.0, 1.0], "treated": [0, 0]}
198+
{"group": [0, 0], "t": [0.0, 1.0], "post_treatment": [0, 0]}
202199
)
200+
assert not self.x_pred_control.empty
203201
(new_x,) = build_design_matrices([self._x_design_info], self.x_pred_control)
204202
self.y_pred_control = self.prediction_model.predict(np.asarray(new_x))
205203

206204
# predicted outcome for treatment group
207205
self.x_pred_treatment = pd.DataFrame(
208-
{"group": [1, 1], "t": [0.0, 1.0], "treated": [0, 1]}
206+
{"group": [1, 1], "t": [0.0, 1.0], "post_treatment": [0, 1]}
209207
)
208+
assert not self.x_pred_treatment.empty
210209
(new_x,) = build_design_matrices([self._x_design_info], self.x_pred_treatment)
211210
self.y_pred_treatment = self.prediction_model.predict(np.asarray(new_x))
212211

213212
# predicted outcome for counterfactual
214213
self.x_pred_counterfactual = pd.DataFrame(
215-
{"group": [1], "t": [1.0], "treated": [0]}
214+
{"group": [1], "t": [1.0], "post_treatment": [0]}
216215
)
216+
assert not self.x_pred_counterfactual.empty
217217
(new_x,) = build_design_matrices(
218218
[self._x_design_info], self.x_pred_counterfactual
219219
)

0 commit comments

Comments
 (0)