Skip to content

Commit 7fbb27a

Browse files
Rojan ShresthaRojan Shrestha
authored andcommitted
Refactor DiD validation: segregate FormulaException and DataException
1 parent 4ebe1a7 commit 7fbb27a

File tree

1 file changed

+26
-12
lines changed

1 file changed

+26
-12
lines changed

causalpy/experiments/diff_in_diff.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
from causalpy.custom_exceptions import (
2828
DataException,
29+
FormulaException,
2930
)
3031
from causalpy.plot_utils import plot_xY
3132
from causalpy.pymc_models import PyMCModel
@@ -233,27 +234,40 @@ def __init__(
233234

234235
def input_validation(self):
235236
"""Validate the input data and model formula for correctness"""
236-
if (
237-
self.post_treatment_variable_name not in self.formula
238-
or self.post_treatment_variable_name not in self.data.columns
239-
):
237+
# Check if post_treatment_variable_name is in formula
238+
if self.post_treatment_variable_name not in self.formula:
239+
if self.post_treatment_variable_name == "post_treatment":
240+
# Default case - user didn't specify custom name, so guide them to use "post_treatment"
241+
raise FormulaException(
242+
"Missing 'post_treatment' in formula.\n"
243+
"Note: post_treatment_variable_name might have been set to 'post_treatment' by default.\n"
244+
"Add 'post_treatment' to formula (e.g., 'y ~ 1 + group*post_treatment').\n"
245+
"Or to use custom name, provide additional argument post_treatment_variable_name='your_post_treatment_variable_name'."
246+
)
247+
else:
248+
# Custom case - user specified custom name, so remind them what they specified
249+
raise FormulaException(
250+
f"Missing required variable '{self.post_treatment_variable_name}' in formula.\n\n"
251+
f"Since you specified post_treatment_variable_name='{self.post_treatment_variable_name}', "
252+
f"please ensure formula includes '{self.post_treatment_variable_name}'"
253+
)
254+
255+
# Check if post_treatment_variable_name is in data columns
256+
if self.post_treatment_variable_name not in self.data.columns:
240257
if self.post_treatment_variable_name == "post_treatment":
241258
# Default case - user didn't specify custom name, so guide them to use "post_treatment"
242259
raise DataException(
243-
"Missing 'post_treatment' in formula or dataset.\n"
260+
"Missing 'post_treatment' column in dataset.\n"
244261
"Note: post_treatment_variable_name might have been set to 'post_treatment' by default.\n"
245-
"1) Add 'post_treatment' to formula (e.g., 'y ~ 1 + group*post_treatment')\n"
246-
"2) and ensure dataset has boolean column 'post_treatment'.\n"
247-
"To use custom name, provide additional argument post_treatment_variable_name='your_post_treatment_variable_name'."
262+
"Ensure dataset has boolean column 'post_treatment'.\n"
263+
"or to use custom name, provide additional argument post_treatment_variable_name='your_post_treatment_variable_name'."
248264
)
249265
else:
250266
# Custom case - user specified custom name, so remind them what they specified
251267
raise DataException(
252-
f"Missing required variable '{self.post_treatment_variable_name}' in formula or dataset.\n\n"
268+
f"Missing required column '{self.post_treatment_variable_name}' in dataset.\n\n"
253269
f"Since you specified post_treatment_variable_name='{self.post_treatment_variable_name}', "
254-
f"please ensure:\n"
255-
f"1) formula includes '{self.post_treatment_variable_name}'\n"
256-
f"2) dataset has boolean column named '{self.post_treatment_variable_name}'"
270+
f"please ensure dataset has boolean column named '{self.post_treatment_variable_name}'"
257271
)
258272

259273
if "unit" not in self.data.columns:

0 commit comments

Comments
 (0)