Skip to content

Commit 15ef4ec

Browse files
committed
first stab at regression kink design
1 parent 0ef986a commit 15ef4ec

File tree

2 files changed

+179
-3
lines changed

2 files changed

+179
-3
lines changed

causalpy/pymc_experiments.py

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -957,6 +957,182 @@ def summary(self) -> None:
957957
self.print_coefficients()
958958

959959

960+
class RegressionKink(ExperimentalDesign):
961+
def __init__(
962+
self,
963+
data: pd.DataFrame,
964+
formula: str,
965+
kink_point: float,
966+
model=None,
967+
running_variable_name: str = "x",
968+
epsilon: float = 0.001,
969+
bandwidth: Optional[float] = None,
970+
**kwargs,
971+
):
972+
super().__init__(model=model, **kwargs)
973+
self.expt_type = "Regression Discontinuity"
974+
self.data = data
975+
self.formula = formula
976+
self.running_variable_name = running_variable_name
977+
self.kink_point = kink_point
978+
self.epsilon = epsilon
979+
self.bandwidth = bandwidth
980+
self._input_validation()
981+
982+
if self.bandwidth is not None:
983+
fmin = self.kink_point - self.bandwidth
984+
fmax = self.kink_point + self.bandwidth
985+
filtered_data = self.data.query(f"{fmin} <= x <= {fmax}")
986+
if len(filtered_data) <= 10:
987+
warnings.warn(
988+
f"Choice of bandwidth parameter has lead to only {len(filtered_data)} remaining datapoints. Consider increasing the bandwidth parameter.", # noqa: E501
989+
UserWarning,
990+
)
991+
y, X = dmatrices(formula, filtered_data)
992+
else:
993+
y, X = dmatrices(formula, self.data)
994+
995+
self._y_design_info = y.design_info
996+
self._x_design_info = X.design_info
997+
self.labels = X.design_info.column_names
998+
self.y, self.X = np.asarray(y), np.asarray(X)
999+
self.outcome_variable_name = y.design_info.column_names[0]
1000+
1001+
COORDS = {"coeffs": self.labels, "obs_indx": np.arange(self.X.shape[0])}
1002+
self.model.fit(X=self.X, y=self.y, coords=COORDS)
1003+
1004+
# score the goodness of fit to all data
1005+
self.score = self.model.score(X=self.X, y=self.y)
1006+
1007+
# get the model predictions of the observed data
1008+
if self.bandwidth is not None:
1009+
xi = np.linspace(fmin, fmax, 200)
1010+
else:
1011+
xi = np.linspace(
1012+
np.min(self.data[self.running_variable_name]),
1013+
np.max(self.data[self.running_variable_name]),
1014+
200,
1015+
)
1016+
self.x_pred = pd.DataFrame(
1017+
{self.running_variable_name: xi, "treated": self._is_treated(xi)}
1018+
)
1019+
(new_x,) = build_design_matrices([self._x_design_info], self.x_pred)
1020+
self.pred = self.model.predict(X=np.asarray(new_x))
1021+
1022+
# Calculate the change in gradient by evaluating the function below the kink
1023+
# point, at the kink point, and above the kink point.
1024+
# NOTE: `"treated": np.array([0, 1])`` assumes treatment is applied above
1025+
# (not below) the threshold
1026+
self.x_discon = pd.DataFrame(
1027+
{
1028+
self.running_variable_name: np.array(
1029+
[
1030+
self.kink_point - self.epsilon,
1031+
self.kink_point,
1032+
self.kink_point + self.epsilon,
1033+
]
1034+
),
1035+
"treated": np.array([0, 1, 1]),
1036+
}
1037+
)
1038+
(new_x,) = build_design_matrices([self._x_design_info], self.x_discon)
1039+
self.pred_discon = self.model.predict(X=np.asarray(new_x))
1040+
1041+
self.gradient_left = (
1042+
self.pred_discon["posterior_predictive"].sel(obs_ind=1)["mu"]
1043+
- self.pred_discon["posterior_predictive"].sel(obs_ind=0)["mu"]
1044+
)
1045+
self.gradient_right = (
1046+
self.pred_discon["posterior_predictive"].sel(obs_ind=2)["mu"]
1047+
- self.pred_discon["posterior_predictive"].sel(obs_ind=1)["mu"]
1048+
)
1049+
self.gradient_change = self.gradient_right - self.gradient_left
1050+
1051+
def _input_validation(self):
1052+
"""Validate the input data and model formula for correctness"""
1053+
# if "treated" not in self.formula:
1054+
# raise FormulaException(
1055+
# "A predictor called `treated` should be in the formula"
1056+
# )
1057+
1058+
if _is_variable_dummy_coded(self.data["treated"]) is False:
1059+
raise DataException(
1060+
"""The treated variable should be dummy coded. Consisting of 0's and 1's only.""" # noqa: E501
1061+
)
1062+
1063+
def _is_treated(self, x):
1064+
"""Returns ``True`` if `x` is greater than or equal to the treatment threshold.
1065+
1066+
.. warning::
1067+
1068+
Assumes treatment is given to those ABOVE the treatment threshold.
1069+
"""
1070+
return np.greater_equal(x, self.kink_point)
1071+
1072+
def plot(self):
1073+
"""
1074+
Plot the results
1075+
"""
1076+
fig, ax = plt.subplots()
1077+
# Plot raw data
1078+
sns.scatterplot(
1079+
self.data,
1080+
x=self.running_variable_name,
1081+
y=self.outcome_variable_name,
1082+
c="k", # hue="treated",
1083+
ax=ax,
1084+
)
1085+
1086+
# Plot model fit to data
1087+
h_line, h_patch = plot_xY(
1088+
self.x_pred[self.running_variable_name],
1089+
self.pred["posterior_predictive"].mu,
1090+
ax=ax,
1091+
plot_hdi_kwargs={"color": "C1"},
1092+
)
1093+
handles = [(h_line, h_patch)]
1094+
labels = ["Posterior mean"]
1095+
1096+
# create strings to compose title
1097+
title_info = f"{self.score.r2:.3f} (std = {self.score.r2_std:.3f})"
1098+
r2 = f"Bayesian $R^2$ on all data = {title_info}"
1099+
percentiles = self.gradient_change.quantile([0.03, 1 - 0.03]).values
1100+
ci = r"$CI_{94\%}$" + f"[{percentiles[0]:.2f}, {percentiles[1]:.2f}]"
1101+
grad_change = f"""
1102+
Change in gradient = {self.gradient_change.mean():.2f},
1103+
"""
1104+
ax.set(title=r2 + "\n" + grad_change + ci)
1105+
# Intervention line
1106+
ax.axvline(
1107+
x=self.kink_point,
1108+
ls="-",
1109+
lw=3,
1110+
color="r",
1111+
label="treatment threshold",
1112+
)
1113+
ax.legend(
1114+
handles=(h_tuple for h_tuple in handles),
1115+
labels=labels,
1116+
fontsize=LEGEND_FONT_SIZE,
1117+
)
1118+
return (fig, ax)
1119+
1120+
def summary(self) -> None:
1121+
"""
1122+
Print text output summarising the results
1123+
"""
1124+
1125+
print(f"{self.expt_type:=^80}")
1126+
print(f"Formula: {self.formula}")
1127+
print(f"Running variable: {self.running_variable_name}")
1128+
print(f"Threshold on running variable: {self.kink_point}")
1129+
print("\nResults:")
1130+
print(
1131+
f"Discontinuity at threshold = {self.discontinuity_at_threshold.mean():.2f}"
1132+
)
1133+
self.print_coefficients()
1134+
1135+
9601136
class PrePostNEGD(ExperimentalDesign):
9611137
"""
9621138
A class to analyse data from pretest/posttest designs

docs/source/_static/interrogate_badge.svg

Lines changed: 3 additions & 3 deletions
Loading

0 commit comments

Comments
 (0)