Skip to content

Commit b5a1181

Browse files
committed
add posterior predictions to the plot + improve robustness of finding treatment coeff
1 parent 47974a8 commit b5a1181

File tree

2 files changed

+206
-61
lines changed

2 files changed

+206
-61
lines changed

causalpy/pymc_experiments.py

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -614,6 +614,7 @@ def __init__(
614614
data: pd.DataFrame,
615615
formula: str,
616616
group_variable_name: str,
617+
pretreatment_variable_name: str,
617618
prediction_model=None,
618619
**kwargs,
619620
):
@@ -622,6 +623,7 @@ def __init__(
622623
self.expt_type = "Pretest/posttest Nonequivalent Group Design"
623624
self.formula = formula
624625
self.group_variable_name = group_variable_name
626+
self.pretreatment_variable_name = pretreatment_variable_name
625627

626628
y, X = dmatrices(formula, self.data)
627629
self._y_design_info = y.design_info
@@ -645,6 +647,33 @@ def __init__(
645647
COORDS = {"coeffs": self.labels, "obs_indx": np.arange(self.X.shape[0])}
646648
self.prediction_model.fit(X=self.X, y=self.y, coords=COORDS)
647649

650+
# Calculate the posterior predictive for the treatment and control for an
651+
# interpolated set of pretest values
652+
# get the model predictions of the observed data
653+
self.pred_xi = np.linspace(
654+
np.min(self.data[self.pretreatment_variable_name]),
655+
np.max(self.data[self.pretreatment_variable_name]),
656+
200,
657+
)
658+
# untreated
659+
x_pred_untreated = pd.DataFrame(
660+
{
661+
self.pretreatment_variable_name: self.pred_xi,
662+
self.group_variable_name: np.zeros(self.pred_xi.shape),
663+
}
664+
)
665+
(new_x,) = build_design_matrices([self._x_design_info], x_pred_untreated)
666+
self.pred_untreated = self.prediction_model.predict(X=np.asarray(new_x))
667+
# treated
668+
x_pred_untreated = pd.DataFrame(
669+
{
670+
self.pretreatment_variable_name: self.pred_xi,
671+
self.group_variable_name: np.ones(self.pred_xi.shape),
672+
}
673+
)
674+
(new_x,) = build_design_matrices([self._x_design_info], x_pred_untreated)
675+
self.pred_treated = self.prediction_model.predict(X=np.asarray(new_x))
676+
648677
# Evaluate causal impact as equal to the trestment effect
649678
self.causal_impact = self.prediction_model.idata.posterior["beta"].sel(
650679
{"coeffs": self._get_treatment_effect_coeff()}
@@ -668,6 +697,23 @@ def plot(self):
668697
ax=ax[0],
669698
)
670699
ax[0].set(xlabel="Pretest", ylabel="Posttest")
700+
701+
# plot posterior predictive of untreated
702+
plot_xY(
703+
self.pred_xi,
704+
self.pred_untreated["posterior_predictive"].y_hat,
705+
ax=ax[0],
706+
plot_hdi_kwargs={"color": "C0"},
707+
)
708+
709+
# plot posterior predictive of treated
710+
plot_xY(
711+
self.pred_xi,
712+
self.pred_treated["posterior_predictive"].y_hat,
713+
ax=ax[0],
714+
plot_hdi_kwargs={"color": "C1"},
715+
)
716+
671717
ax[0].legend(fontsize=LEGEND_FONT_SIZE)
672718

673719
# Plot estimated caual impact / treatment effect
@@ -698,5 +744,7 @@ def _get_treatment_effect_coeff(self) -> str:
698744
the labels are `['Intercept', 'C(group)[T.1]', 'pre']`
699745
then we want `C(group)[T.1]`.
700746
"""
701-
mask = [self.group_variable_name in label for label in self.labels]
702-
return self.labels[np.argmax(mask)]
747+
for label in self.labels:
748+
if ("group" in label) & (":" not in label):
749+
return label
750+
# TODO: raise an exception if we fail to find the coefficient we want

0 commit comments

Comments
 (0)