Skip to content

Commit ffc9b44

Browse files
committed
add tests for exceptions in input validation + remove redundant comment
1 parent 714d13d commit ffc9b44

File tree

3 files changed

+101
-6
lines changed

3 files changed

+101
-6
lines changed

causalpy/pymc_experiments.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -746,9 +746,6 @@ def __init__(
746746

747747
def _input_validation(self):
748748
"""Validate the input data and model formula for correctness"""
749-
# Check that `group_variable_name` has TWO levels, representing the
750-
# treated/untreated. But it does not matter what the actual names of
751-
# the levels are.
752749
if not _series_has_2_levels(self.data[self.group_variable_name]):
753750
raise ValueError(
754751
f"""
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
import pandas as pd
2+
import pytest
3+
4+
import causalpy as cp
5+
from causalpy.custom_exceptions import DataException, FormulaException
6+
7+
sample_kwargs = {"tune": 20, "draws": 20, "chains": 2, "cores": 2}
8+
9+
# DiD
10+
11+
12+
def test_did_validation_post_treatment_formula():
13+
"""Test that we get a FormulaException if do not include post_treatment in the
14+
formula"""
15+
df = pd.DataFrame(
16+
{
17+
"group": [0, 0, 1, 1],
18+
"t": [0, 1, 0, 1],
19+
"unit": [0, 0, 1, 1],
20+
"post_treatment": [0, 1, 0, 1],
21+
"y": [1, 2, 3, 4],
22+
}
23+
)
24+
25+
with pytest.raises(FormulaException):
26+
_ = cp.pymc_experiments.DifferenceInDifferences(
27+
df,
28+
formula="y ~ 1 + group*post_SOMETHING",
29+
time_variable_name="t",
30+
group_variable_name="group",
31+
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
32+
)
33+
34+
35+
def test_did_validation_post_treatment_data():
36+
"""Test that we get a DataException if do not include post_treatment in the data"""
37+
df = pd.DataFrame(
38+
{
39+
"group": [0, 0, 1, 1],
40+
"t": [0, 1, 0, 1],
41+
"unit": [0, 0, 1, 1],
42+
# "post_treatment": [0, 1, 0, 1],
43+
"y": [1, 2, 3, 4],
44+
}
45+
)
46+
47+
with pytest.raises(DataException):
48+
_ = cp.pymc_experiments.DifferenceInDifferences(
49+
df,
50+
formula="y ~ 1 + group*post_treatment",
51+
time_variable_name="t",
52+
group_variable_name="group",
53+
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
54+
)
55+
56+
57+
def test_did_validation_unit_data():
58+
"""Test that we get a DataException if do not include unit in the data"""
59+
df = pd.DataFrame(
60+
{
61+
"group": [0, 0, 1, 1],
62+
"t": [0, 1, 0, 1],
63+
# "unit": [0, 0, 1, 1],
64+
"post_treatment": [0, 1, 0, 1],
65+
"y": [1, 2, 3, 4],
66+
}
67+
)
68+
69+
with pytest.raises(DataException):
70+
_ = cp.pymc_experiments.DifferenceInDifferences(
71+
df,
72+
formula="y ~ 1 + group*post_treatment",
73+
time_variable_name="t",
74+
group_variable_name="group",
75+
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
76+
)
77+
78+
79+
def test_did_validation_group_dummy_coded():
80+
"""Test that we get a DataException if the group variable is not dummy coded"""
81+
df = pd.DataFrame(
82+
{
83+
"group": ["a", "a", "b", "b"],
84+
"t": [0, 1, 0, 1],
85+
"unit": [0, 0, 1, 1],
86+
"post_treatment": [0, 1, 0, 1],
87+
"y": [1, 2, 3, 4],
88+
}
89+
)
90+
91+
with pytest.raises(DataException):
92+
_ = cp.pymc_experiments.DifferenceInDifferences(
93+
df,
94+
formula="y ~ 1 + group*post_treatment",
95+
time_variable_name="t",
96+
group_variable_name="group",
97+
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
98+
)

img/interrogate_badge.svg

Lines changed: 3 additions & 3 deletions
Loading

0 commit comments

Comments
 (0)