Skip to content

Commit 98f83b6

Browse files
committed
add ANCOVA analysis for pre/posttest NEGD + example notebook
1 parent 022f8e8 commit 98f83b6

File tree

3 files changed

+478
-0
lines changed

3 files changed

+478
-0
lines changed

causalpy/pymc_experiments.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -600,3 +600,84 @@ def summary(self):
600600
f"Discontinuity at threshold = {self.discontinuity_at_threshold.mean():.2f}"
601601
)
602602
self.print_coefficients()
603+
604+
605+
# =============================================================================
606+
# =============================================================================
607+
608+
609+
class PrePostNEGD(ExperimentalDesign):
    """Analyse data from a pretest/posttest nonequivalent group design (NEGD).

    The design matrices are built from ``formula`` and the prediction model is
    fitted to every row of ``data`` during construction.

    Parameters
    ----------
    data
        Data frame holding the outcome, the pretest measure and the grouping
        variable referenced by ``formula``.
    formula
        Patsy formula describing the model, e.g. ``"post ~ 1 + C(group) + pre"``.
    group_variable_name
        Name of the column in ``data`` that identifies the treated/untreated
        groups. It must have exactly two levels; the level names are free.
    prediction_model
        Model forwarded to ``ExperimentalDesign.__init__``. It is used here via
        ``fit(X=..., y=..., coords=...)`` and ``idata.posterior["beta"]``.
    pretreatment_variable_name
        Name of the column in ``data`` plotted on the x-axis (the pretest
        measure). Defaults to ``"pre"``.
    **kwargs
        Forwarded unchanged to ``ExperimentalDesign.__init__``.

    Raises
    ------
    AssertionError
        If ``group_variable_name`` does not have exactly two levels.
    """

    def __init__(
        self,
        data: pd.DataFrame,
        formula: str,
        group_variable_name: str,
        prediction_model=None,
        pretreatment_variable_name: str = "pre",
        **kwargs,
    ):
        super().__init__(prediction_model=prediction_model, **kwargs)
        self.data = data
        self.expt_type = "Pretest/posttest Nonequivalent Group Design"
        self.formula = formula
        self.group_variable_name = group_variable_name
        self.pretreatment_variable_name = pretreatment_variable_name

        # Build the design matrices and keep the design info so the column
        # labels (coefficient names) can be looked up later.
        y, X = dmatrices(formula, self.data)
        self._y_design_info = y.design_info
        self._x_design_info = X.design_info
        self.labels = X.design_info.column_names
        self.y, self.X = np.asarray(y), np.asarray(X)
        self.outcome_variable_name = y.design_info.column_names[0]

        # Input validation ----------------------------------------------------
        # `group_variable_name` must have TWO levels, representing the
        # treated/untreated; the actual names of the levels do not matter.
        # An explicit raise is used (rather than an `assert` statement) so the
        # check still runs under `python -O`; the exception type is kept as
        # AssertionError for backward compatibility.
        n_levels = len(pd.Categorical(self.data[self.group_variable_name]).categories)
        if n_levels != 2:
            raise AssertionError(
                "There must be 2 levels of the grouping variable "
                f"{self.group_variable_name}, i.e. the treated and untreated."
            )

        # Fit the model to all of the observed data.
        COORDS = {"coeffs": self.labels, "obs_indx": np.arange(self.X.shape[0])}
        self.prediction_model.fit(X=self.X, y=self.y, coords=COORDS)

    def _get_treatment_effect_coeff(self) -> str:
        """Return the design-matrix label of the group (treatment) coefficient.

        For example, with ``group_variable_name == "group"`` and labels
        ``["Intercept", "C(group)[T.1]", "pre"]`` this returns
        ``"C(group)[T.1]"``. Interaction terms (labels containing ``":"``) are
        ignored. If no single main-effect label can be identified, the legacy
        label ``C(<group>)[T.1]`` is returned so that a failing lookup behaves
        as it did before.
        """
        name = self.group_variable_name
        legacy_label = f"C({name})[T.1]"
        if legacy_label in self.labels:
            return legacy_label
        prefixes = (f"{name}[", f"C({name})", f"C({name},")
        candidates = [
            label
            for label in self.labels
            if ":" not in label and (label == name or label.startswith(prefixes))
        ]
        return candidates[0] if len(candidates) == 1 else legacy_label

    def plot(self):
        """Plot the raw data and the posterior of the treatment effect.

        Returns
        -------
        tuple
            ``(fig, ax)`` where ``ax`` holds two axes: the pretest/posttest
            scatter plot coloured by group, and the posterior distribution of
            the group coefficient with a reference value of 0.
        """
        fig, ax = plt.subplots(
            2, 1, figsize=(7, 9), gridspec_kw={"height_ratios": [3, 1]}
        )

        # The outcome label from the formula may be a transformed term that is
        # not a column of `data`; fall back to the historical "post" column.
        if self.outcome_variable_name in self.data.columns:
            outcome_column = self.outcome_variable_name
        else:
            outcome_column = "post"

        # Plot raw data
        sns.scatterplot(
            x=self.pretreatment_variable_name,
            y=outcome_column,
            hue=self.group_variable_name,
            alpha=0.5,
            palette="muted",
            data=self.data,
            ax=ax[0],
        )
        ax[0].set(xlabel="Pretest", ylabel="Posttest")
        ax[0].legend(fontsize=LEGEND_FONT_SIZE)

        # Posterior of the estimated treatment effect
        treatment_effect_parameter = self._get_treatment_effect_coeff()
        az.plot_posterior(
            self.prediction_model.idata.posterior["beta"].sel(
                {"coeffs": treatment_effect_parameter}
            ),
            ref_val=0,
            ax=ax[1],
        )
        ax[1].set(title="Estimated treatment effect")
        return fig, ax

    def summary(self):
        """Print a text summary of the results (not implemented yet)."""
        raise NotImplementedError

docs/examples.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,12 @@
1+
Nonequivalent group designs
2+
===========================
3+
4+
.. toctree::
5+
:titlesonly:
6+
7+
notebooks/ancova_pymc.ipynb
8+
9+
110
Synthetic Control
211
=================
312

docs/notebooks/ancova_pymc.ipynb

Lines changed: 388 additions & 0 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)