Skip to content

Commit 63d8466

Browse files
committed
#24 DiD plot posterior expectation instead of posterior predictive
1 parent 3a11f25 commit 63d8466

File tree

1 file changed

+7
-8
lines changed

1 file changed

+7
-8
lines changed

causalpy/pymc_experiments.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -256,13 +256,12 @@ def __init__(
256256
self.y_pred_counterfactual = self.prediction_model.predict(np.asarray(new_x))
257257

258258
# calculate causal impact
259-
# TODO: This should most likely be posterior estimate, not posterior predictive
260259
self.causal_impact = (
261260
self.y_pred_treatment["posterior_predictive"]
262-
.y_hat.isel({"obs_ind": 1})
261+
.mu.isel({"obs_ind": 1})
263262
.mean()
264263
.data
265-
- self.y_pred_counterfactual["posterior_predictive"].y_hat.mean().data
264+
- self.y_pred_counterfactual["posterior_predictive"].mu.mean().data
266265
)
267266

268267
def plot(self):
@@ -283,7 +282,7 @@ def plot(self):
283282
# Plot model fit to control group
284283
parts = ax.violinplot(
285284
az.extract(
286-
self.y_pred_control, group="posterior_predictive", var_names="y_hat"
285+
self.y_pred_control, group="posterior_predictive", var_names="mu"
287286
).values.T,
288287
positions=self.x_pred_control[self.time_variable_name].values,
289288
showmeans=False,
@@ -298,7 +297,7 @@ def plot(self):
298297
# Plot model fit to treatment group
299298
parts = ax.violinplot(
300299
az.extract(
301-
self.y_pred_treatment, group="posterior_predictive", var_names="y_hat"
300+
self.y_pred_treatment, group="posterior_predictive", var_names="mu"
302301
).values.T,
303302
positions=self.x_pred_treatment[self.time_variable_name].values,
304303
showmeans=False,
@@ -310,7 +309,7 @@ def plot(self):
310309
az.extract(
311310
self.y_pred_counterfactual,
312311
group="posterior_predictive",
313-
var_names="y_hat",
312+
var_names="mu",
314313
).values.T,
315314
positions=self.x_pred_counterfactual[self.time_variable_name].values,
316315
showmeans=False,
@@ -320,12 +319,12 @@ def plot(self):
320319
# arrow to label the causal impact
321320
y_pred_treatment = (
322321
self.y_pred_treatment["posterior_predictive"]
323-
.y_hat.isel({"obs_ind": 1})
322+
.mu.isel({"obs_ind": 1})
324323
.mean()
325324
.data
326325
)
327326
y_pred_counterfactual = (
328-
self.y_pred_counterfactual["posterior_predictive"].y_hat.mean().data
327+
self.y_pred_counterfactual["posterior_predictive"].mu.mean().data
329328
)
330329
ax.annotate(
331330
"",

0 commit comments

Comments
 (0)