Skip to content

Commit c232d89

Browse files
Rojan ShresthaRojan Shrestha
authored and committed
added validations for interactions, test coverage expanded to test interaction terms, more generic messages
1 parent 7fbb27a commit c232d89

File tree

3 files changed

+186
-35
lines changed

3 files changed

+186
-35
lines changed

causalpy/experiments/diff_in_diff.py

Lines changed: 66 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
Difference in differences
1616
"""
1717

18+
import re
19+
1820
import arviz as az
1921
import numpy as np
2022
import pandas as pd
@@ -233,42 +235,21 @@ def __init__(
233235
return
234236

235237
def input_validation(self):
238+
# Validate formula structure and interaction terms
239+
self._validate_formula_interaction_terms()
240+
236241
"""Validate the input data and model formula for correctness"""
237242
# Check if post_treatment_variable_name is in formula
238243
if self.post_treatment_variable_name not in self.formula:
239-
if self.post_treatment_variable_name == "post_treatment":
240-
# Default case - user didn't specify custom name, so guide them to use "post_treatment"
241-
raise FormulaException(
242-
"Missing 'post_treatment' in formula.\n"
243-
"Note: post_treatment_variable_name might have been set to 'post_treatment' by default.\n"
244-
"Add 'post_treatment' to formula (e.g., 'y ~ 1 + group*post_treatment').\n"
245-
"Or to use custom name, provide additional argument post_treatment_variable_name='your_post_treatment_variable_name'."
246-
)
247-
else:
248-
# Custom case - user specified custom name, so remind them what they specified
249-
raise FormulaException(
250-
f"Missing required variable '{self.post_treatment_variable_name}' in formula.\n\n"
251-
f"Since you specified post_treatment_variable_name='{self.post_treatment_variable_name}', "
252-
f"please ensure formula includes '{self.post_treatment_variable_name}'"
253-
)
244+
raise FormulaException(
245+
f"Missing required variable '{self.post_treatment_variable_name}' in formula"
246+
)
254247

255248
# Check if post_treatment_variable_name is in data columns
256249
if self.post_treatment_variable_name not in self.data.columns:
257-
if self.post_treatment_variable_name == "post_treatment":
258-
# Default case - user didn't specify custom name, so guide them to use "post_treatment"
259-
raise DataException(
260-
"Missing 'post_treatment' column in dataset.\n"
261-
"Note: post_treatment_variable_name might have been set to 'post_treatment' by default.\n"
262-
"Ensure dataset has boolean column 'post_treatment'.\n"
263-
"or to use custom name, provide additional argument post_treatment_variable_name='your_post_treatment_variable_name'."
264-
)
265-
else:
266-
# Custom case - user specified custom name, so remind them what they specified
267-
raise DataException(
268-
f"Missing required column '{self.post_treatment_variable_name}' in dataset.\n\n"
269-
f"Since you specified post_treatment_variable_name='{self.post_treatment_variable_name}', "
270-
f"please ensure dataset has boolean column named '{self.post_treatment_variable_name}'"
271-
)
250+
raise DataException(
251+
f"Missing required column '{self.post_treatment_variable_name}' in dataset"
252+
)
272253

273254
if "unit" not in self.data.columns:
274255
raise DataException(
@@ -281,6 +262,61 @@ def input_validation(self):
281262
coded. Consisting of 0's and 1's only."""
282263
)
283264

265+
def _get_interaction_terms(self):
    """
    Extract interaction terms from the right-hand side of the formula.

    All whitespace is removed from ``self.formula``, the part after the first
    '~' is taken (or the whole string when the formula has no '~'), and that
    part is split into additive terms at every '+' or '-' that is not enclosed
    in brackets. A term is reported as an interaction when it contains '*'
    or ':'.

    Returns
    -------
    list of str
        Interaction terms in formula order, without whitespace or a leading
        sign. A term that is repeated in the formula is returned once per
        occurrence.
    """
    interaction_indicators = ("*", ":")

    # Drop every whitespace character (not only spaces) so that multi-line or
    # tab-indented formulas yield the same terms as single-line ones.
    formula = "".join(self.formula.split())

    # partition() never raises: a formula without '~' is treated as having a
    # right-hand side only instead of failing with an IndexError.
    lhs, tilde, rhs = formula.partition("~")
    if not tilde:
        rhs = lhs

    # Split on '+' / '-' at bracket depth 0 only, so grouped or function-call
    # sub-expressions such as "(a+b)*c" or "np.log(x+1)" stay in one piece.
    terms = []
    start = 0
    depth = 0
    for position, char in enumerate(rhs):
        if char in "([{":
            depth += 1
        elif char in ")]}":
            # Clamp at zero so an unbalanced formula cannot disable splitting.
            depth = max(depth - 1, 0)
        elif char in "+-" and depth == 0:
            terms.append(rhs[start:position])
            start = position + 1
    terms.append(rhs[start:])

    # Empty pieces (e.g. from a leading sign) contain no indicator and are
    # filtered out together with the plain main-effect terms.
    return [
        term
        for term in terms
        if any(indicator in term for indicator in interaction_indicators)
    ]
291+
292+
def _validate_formula_interaction_terms(self):
    """
    Check the interaction structure of the model formula.

    The interaction terms are obtained from ``self._get_interaction_terms()``.
    Each term is inspected first: a term holding two or more interaction
    symbols ('*' or ':') involves three or more variables and is rejected.
    After that, the number of interaction terms is checked.

    Raises
    ------
    FormulaException
        If any interaction term has more than 2 variables, or if the formula
        contains more than one interaction term.
    """
    interaction_terms = self._get_interaction_terms()

    def symbol_count(candidate):
        # Combined occurrences of both interaction symbols within one term.
        return candidate.count("*") + candidate.count(":")

    # Per-term check runs before the overall count, so a higher-order
    # interaction is reported in preference to "too many terms".
    for term in interaction_terms:
        if symbol_count(term) < 2:
            continue
        # Two symbols imply at least three variables (e.g. a*b*c or a:b:c).
        raise FormulaException(
            f"Formula contains interaction term with more than 2 variables: {term}. Only two-way interactions are allowed."
        )

    if len(interaction_terms) <= 1:
        return

    raise FormulaException(
        f"Formula contains more than 1 interaction term: {interaction_terms}. Maximum of 1 allowed."
    )
319+
284320
def summary(self, round_to=None) -> None:
285321
"""Print summary of main results and model coefficients.
286322

causalpy/tests/test_input_validation.py

Lines changed: 117 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,18 +30,29 @@
3030

3131

3232
def test_did_validation_post_treatment_formula():
33-
"""Test that we get a FormulaException if do not include post_treatment in the
34-
formula"""
33+
"""Test that we get a FormulaException for invalid formulas and missing post_treatment variables"""
3534
df = pd.DataFrame(
3635
{
3736
"group": [0, 0, 1, 1],
3837
"t": [0, 1, 0, 1],
3938
"unit": [0, 0, 1, 1],
4039
"post_treatment": [0, 1, 0, 1],
40+
"male": [0, 1, 0, 1], # Additional variable for testing
4141
"y": [1, 2, 3, 4],
4242
}
4343
)
4444

45+
df_with_custom = pd.DataFrame(
46+
{
47+
"group": [0, 0, 1, 1],
48+
"t": [0, 1, 0, 1],
49+
"unit": [0, 0, 1, 1],
50+
"custom_post": [0, 1, 0, 1], # Custom column name
51+
"y": [1, 2, 3, 4],
52+
}
53+
)
54+
55+
# Test 1: Missing post_treatment variable in formula
4556
with pytest.raises(FormulaException):
4657
_ = cp.DifferenceInDifferences(
4758
df,
@@ -51,6 +62,7 @@ def test_did_validation_post_treatment_formula():
5162
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
5263
)
5364

65+
# Test 2: Missing post_treatment variable in formula (duplicate test)
5466
with pytest.raises(FormulaException):
5567
_ = cp.DifferenceInDifferences(
5668
df,
@@ -60,6 +72,88 @@ def test_did_validation_post_treatment_formula():
6072
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
6173
)
6274

75+
# Test 3: Custom post_treatment_variable_name but formula uses different name
76+
with pytest.raises(FormulaException):
77+
_ = cp.DifferenceInDifferences(
78+
df_with_custom,
79+
formula="y ~ 1 + group*post_treatment", # Formula uses 'post_treatment'
80+
time_variable_name="t",
81+
group_variable_name="group",
82+
post_treatment_variable_name="custom_post", # But user specifies 'custom_post'
83+
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
84+
)
85+
86+
# Test 4: Default post_treatment_variable_name but formula uses different name
87+
with pytest.raises(FormulaException):
88+
_ = cp.DifferenceInDifferences(
89+
df,
90+
formula="y ~ 1 + group*custom_post", # Formula uses 'custom_post'
91+
time_variable_name="t",
92+
group_variable_name="group",
93+
# post_treatment_variable_name defaults to "post_treatment"
94+
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
95+
)
96+
97+
# Test 5: Repeated interaction terms (should be invalid)
98+
with pytest.raises(FormulaException):
99+
_ = cp.DifferenceInDifferences(
100+
df,
101+
formula="y ~ 1 + group + group*post_treatment + group*post_treatment",
102+
time_variable_name="t",
103+
group_variable_name="group",
104+
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
105+
)
106+
107+
# Test 6: Three-way interactions using * (should be invalid)
108+
with pytest.raises(FormulaException):
109+
_ = cp.DifferenceInDifferences(
110+
df,
111+
formula="y ~ 1 + group + group*post_treatment*male",
112+
time_variable_name="t",
113+
group_variable_name="group",
114+
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
115+
)
116+
117+
# Test 7: Three-way interactions using : (should be invalid)
118+
with pytest.raises(FormulaException):
119+
_ = cp.DifferenceInDifferences(
120+
df,
121+
formula="y ~ 1 + group + group:post_treatment:male",
122+
time_variable_name="t",
123+
group_variable_name="group",
124+
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
125+
)
126+
127+
# Test 8: Multiple different interaction terms using * (should be invalid)
128+
with pytest.raises(FormulaException):
129+
_ = cp.DifferenceInDifferences(
130+
df,
131+
formula="y ~ 1 + group + group*post_treatment + group*male",
132+
time_variable_name="t",
133+
group_variable_name="group",
134+
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
135+
)
136+
137+
# Test 9: Multiple different interaction terms using : (should be invalid)
138+
with pytest.raises(FormulaException):
139+
_ = cp.DifferenceInDifferences(
140+
df,
141+
formula="y ~ 1 + group + group:post_treatment + group:male",
142+
time_variable_name="t",
143+
group_variable_name="group",
144+
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
145+
)
146+
147+
# Test 10: Mixed issues - multiple terms + three-way interaction (should be invalid)
148+
with pytest.raises(FormulaException):
149+
_ = cp.DifferenceInDifferences(
150+
df,
151+
formula="y ~ 1 + group + group*post_treatment + group:post_treatment:male",
152+
time_variable_name="t",
153+
group_variable_name="group",
154+
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
155+
)
156+
63157

64158
def test_did_validation_post_treatment_data():
65159
"""Test that we get a DataException if do not include post_treatment in the data"""
@@ -91,6 +185,27 @@ def test_did_validation_post_treatment_data():
91185
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
92186
)
93187

188+
# Test 2: Custom post_treatment_variable_name but column doesn't exist in data
189+
df_with_post = pd.DataFrame(
190+
{
191+
"group": [0, 0, 1, 1],
192+
"t": [0, 1, 0, 1],
193+
"unit": [0, 0, 1, 1],
194+
"post_treatment": [0, 1, 0, 1], # Data has 'post_treatment'
195+
"y": [1, 2, 3, 4],
196+
}
197+
)
198+
199+
with pytest.raises(DataException):
200+
_ = cp.DifferenceInDifferences(
201+
df_with_post,
202+
formula="y ~ 1 + group*custom_post", # Formula uses 'custom_post'
203+
time_variable_name="t",
204+
group_variable_name="group",
205+
post_treatment_variable_name="custom_post", # User specifies 'custom_post'
206+
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
207+
)
208+
94209

95210
def test_did_validation_unit_data():
96211
"""Test that we get a DataException if do not include unit in the data"""

docs/source/_static/interrogate_badge.svg

Lines changed: 3 additions & 3 deletions
Loading

0 commit comments

Comments
 (0)