Skip to content

Commit 0752be5

Browse files
committed
add input validation and associated tests for exceptions for RegressionDiscontinuity
1 parent 06a4aa1 commit 0752be5

File tree

3 files changed

+58
-6
lines changed

3 files changed

+58
-6
lines changed

causalpy/pymc_experiments.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -560,16 +560,15 @@ def __init__(
560560
self.formula = formula
561561
self.running_variable_name = running_variable_name
562562
self.treatment_threshold = treatment_threshold
563+
self._input_validation()
564+
563565
y, X = dmatrices(formula, self.data)
564566
self._y_design_info = y.design_info
565567
self._x_design_info = X.design_info
566568
self.labels = X.design_info.column_names
567569
self.y, self.X = np.asarray(y), np.asarray(X)
568570
self.outcome_variable_name = y.design_info.column_names[0]
569571

570-
# TODO: `treated` is a deterministic function of x and treatment_threshold, so
571-
# this could be a function rather than supplied data
572-
573572
# DEVIATION FROM SKL EXPERIMENT CODE =============================
574573
# fit the model to the observed (pre-intervention) data
575574
COORDS = {"coeffs": self.labels, "obs_indx": np.arange(self.X.shape[0])}
@@ -610,6 +609,18 @@ def __init__(
610609
- self.pred_discon["posterior_predictive"].sel(obs_ind=0)["mu"]
611610
)
612611

612+
def _input_validation(self):
613+
"""Validate the input data and model formula for correctness"""
614+
if "treated" not in self.formula:
615+
raise FormulaException(
616+
"A predictor called `treated` should be in the formula"
617+
)
618+
619+
if _is_variable_dummy_coded(self.data["treated"]) is False:
620+
raise DataException(
621+
"""The treated variable should be dummy coded. Consisting of 0's and 1's only.""" # noqa: E501
622+
)
623+
613624
def _is_treated(self, x):
614625
"""Returns ``True`` if `x` is greater than or equal to the treatment threshold.
615626

causalpy/tests/test_input_validation.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,3 +161,44 @@ def test_ancova_validation_2_levels():
161161
pretreatment_variable_name="pre",
162162
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
163163
)
164+
165+
166+
# Regression discontinuity
167+
168+
169+
def test_rd_validation_treated_in_formula():
170+
"""Test that we get a FormulaException if treated is not in the model formula"""
171+
df = pd.DataFrame(
172+
{
173+
"x": [0, 1, 2, 3],
174+
"treated": [0, 0, 1, 1],
175+
"y": [1, 1, 2, 2],
176+
}
177+
)
178+
179+
with pytest.raises(FormulaException):
180+
_ = cp.pymc_experiments.RegressionDiscontinuity(
181+
df,
182+
formula="y ~ 1 + x",
183+
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
184+
treatment_threshold=0.5,
185+
)
186+
187+
188+
def test_rd_validation_treated_is_dummy():
189+
"""Test that we get a DataException if treated is not dummy coded"""
190+
df = pd.DataFrame(
191+
{
192+
"x": [0, 1, 2, 3],
193+
"treated": ["control", "control", "treated", "treated"],
194+
"y": [1, 1, 2, 2],
195+
}
196+
)
197+
198+
with pytest.raises(DataException):
199+
_ = cp.pymc_experiments.RegressionDiscontinuity(
200+
df,
201+
formula="y ~ 1 + x + treated",
202+
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
203+
treatment_threshold=0.5,
204+
)

img/interrogate_badge.svg

Lines changed: 3 additions & 3 deletions
Loading

0 commit comments

Comments
 (0)