Skip to content

Commit e885cfe

Browse files
committed
improve plots + add summary method + centralise calculation of causal impact
1 parent 98f83b6 commit e885cfe

File tree

2 files changed

+96
-56
lines changed

2 files changed

+96
-56
lines changed

causalpy/pymc_experiments.py

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -619,7 +619,7 @@ def __init__(
619619
):
620620
super().__init__(prediction_model=prediction_model, **kwargs)
621621
self.data = data
622-
self.expt_type = "Difference in Differences"
622+
self.expt_type = "Pretest/posttest Nonequivalent Group Design"
623623
self.formula = formula
624624
self.group_variable_name = group_variable_name
625625

@@ -644,6 +644,15 @@ def __init__(
644644
# fit the model to the observed (pre-intervention) data
645645
COORDS = {"coeffs": self.labels, "obs_indx": np.arange(self.X.shape[0])}
646646
self.prediction_model.fit(X=self.X, y=self.y, coords=COORDS)
647+
648+
# Evaluate causal impact as equal to the trestment effect
649+
# TODO: Grabbing the appropriate name for the group parameter (to describe
650+
# the treatment effect) is brittle and will likely break.
651+
treatment_effect_parameter = f"C({self.group_variable_name})[T.1]"
652+
self.causal_impact = self.prediction_model.idata.posterior["beta"].sel(
653+
{"coeffs": treatment_effect_parameter}
654+
)
655+
647656
# ================================================================
648657

649658
def plot(self):
@@ -658,26 +667,29 @@ def plot(self):
658667
y="post",
659668
hue="group",
660669
alpha=0.5,
661-
palette="muted",
662670
data=self.data,
663671
ax=ax[0],
664672
)
665673
ax[0].set(xlabel="Pretest", ylabel="Posttest")
666674
ax[0].legend(fontsize=LEGEND_FONT_SIZE)
667675

668-
# Post estimated treatment effect
669-
# TODO: Grabbing the appropriate name for the group parameter (to describe
670-
# the treatment effect) is brittle and will likely break.
671-
treatment_effect_parameter = f"C({self.group_variable_name})[T.1]"
672-
az.plot_posterior(
673-
self.prediction_model.idata.posterior["beta"].sel(
674-
{"coeffs": treatment_effect_parameter}
675-
),
676-
ref_val=0,
677-
ax=ax[1],
678-
)
676+
# Plot estimated caual impact / treatment effect
677+
az.plot_posterior(self.causal_impact, ref_val=0, ax=ax[1])
679678
ax[1].set(title="Estimated treatment effect")
680679
return fig, ax
681680

681+
def _causal_impact_summary_stat(self):
682+
percentiles = self.causal_impact.quantile([0.03, 1 - 0.03]).values
683+
ci = r"$CI_{94\%}$" + f"[{percentiles[0]:.2f}, {percentiles[1]:.2f}]"
684+
causal_impact = f"{self.causal_impact.mean():.2f}, "
685+
return f"Causal impact = {causal_impact + ci}"
686+
682687
def summary(self):
683-
raise NotImplementedError
688+
"""Print text output summarising the results"""
689+
690+
print(f"{self.expt_type:=^80}")
691+
print(f"Formula: {self.formula}")
692+
print("\nResults:")
693+
# TODO: extra experiment specific outputs here
694+
print(self._causal_impact_summary_stat())
695+
self.print_coefficients()

docs/notebooks/ancova_pymc.ipynb

Lines changed: 70 additions & 42 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)