Skip to content

Commit afa2e59

Browse files
committed
add input validation tests for PrePostNEGD
1 parent bcb2512 commit afa2e59

File tree

3 files changed

+27
-4
lines changed

3 files changed

+27
-4
lines changed

causalpy/pymc_experiments.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -747,7 +747,7 @@ def __init__(
747747
def _input_validation(self):
748748
"""Validate the input data and model formula for correctness"""
749749
if not _series_has_2_levels(self.data[self.group_variable_name]):
750-
raise ValueError(
750+
raise DataException(
751751
f"""
752752
There must be 2 levels of the grouping variable
753753
{self.group_variable_name}. I.e. the treated and untreated.

causalpy/tests/test_input_validation.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,3 +141,26 @@ def test_sc_brexit_input_error():
141141
formula=formula,
142142
model=cp.pymc_models.WeightedSumFitter(sample_kwargs=sample_kwargs),
143143
)
144+
145+
146+
# Pre-post NEGD
147+
148+
149+
def test_ancova_validation_2_levels():
150+
"""Test that we get a DataException group variable is not dummy coded"""
151+
df = pd.DataFrame(
152+
{
153+
"group": [0, 0, 1, 2],
154+
"pre": [1, 1, 3, 4],
155+
"post": [1, 2, 3, 4],
156+
}
157+
)
158+
159+
with pytest.raises(DataException):
160+
_ = cp.pymc_experiments.PrePostNEGD(
161+
df,
162+
formula="post ~ 1 + C(group) + pre",
163+
group_variable_name="group",
164+
pretreatment_variable_name="pre",
165+
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
166+
)

img/interrogate_badge.svg

Lines changed: 3 additions & 3 deletions
Loading

0 commit comments

Comments
 (0)