|
26 | 26 |
|
27 | 27 | from causalpy.custom_exceptions import ( |
28 | 28 | DataException, |
| 29 | + FormulaException, |
29 | 30 | ) |
30 | 31 | from causalpy.plot_utils import plot_xY |
31 | 32 | from causalpy.pymc_models import PyMCModel |
@@ -233,27 +234,40 @@ def __init__( |
233 | 234 |
|
234 | 235 | def input_validation(self): |
235 | 236 | """Validate the input data and model formula for correctness""" |
236 | | - if ( |
237 | | - self.post_treatment_variable_name not in self.formula |
238 | | - or self.post_treatment_variable_name not in self.data.columns |
239 | | - ): |
| 237 | + # Check if post_treatment_variable_name is in formula |
| 238 | + if self.post_treatment_variable_name not in self.formula: |
| 239 | + if self.post_treatment_variable_name == "post_treatment": |
| 240 | + # Default case - user didn't specify custom name, so guide them to use "post_treatment" |
| 241 | + raise FormulaException( |
| 242 | + "Missing 'post_treatment' in formula.\n" |
| 243 | + "Note: post_treatment_variable_name might have been set to 'post_treatment' by default.\n" |
| 244 | + "Add 'post_treatment' to formula (e.g., 'y ~ 1 + group*post_treatment').\n" |
| 245 | + "Or to use custom name, provide additional argument post_treatment_variable_name='your_post_treatment_variable_name'." |
| 246 | + ) |
| 247 | + else: |
| 248 | + # Custom case - user specified custom name, so remind them what they specified |
| 249 | + raise FormulaException( |
| 250 | + f"Missing required variable '{self.post_treatment_variable_name}' in formula.\n\n" |
| 251 | + f"Since you specified post_treatment_variable_name='{self.post_treatment_variable_name}', " |
| 252 | + f"please ensure formula includes '{self.post_treatment_variable_name}'" |
| 253 | + ) |
| 254 | + |
| 255 | + # Check if post_treatment_variable_name is in data columns |
| 256 | + if self.post_treatment_variable_name not in self.data.columns: |
240 | 257 | if self.post_treatment_variable_name == "post_treatment": |
241 | 258 | # Default case - user didn't specify custom name, so guide them to use "post_treatment" |
242 | 259 | raise DataException( |
243 | | - "Missing 'post_treatment' in formula or dataset.\n" |
| 260 | + "Missing 'post_treatment' column in dataset.\n" |
244 | 261 | "Note: post_treatment_variable_name might have been set to 'post_treatment' by default.\n" |
245 | | - "1) Add 'post_treatment' to formula (e.g., 'y ~ 1 + group*post_treatment')\n" |
246 | | - "2) and ensure dataset has boolean column 'post_treatment'.\n" |
247 | | - "To use custom name, provide additional argument post_treatment_variable_name='your_post_treatment_variable_name'." |
| 262 | + "Ensure dataset has boolean column 'post_treatment'.\n" |
| 263 | + "or to use custom name, provide additional argument post_treatment_variable_name='your_post_treatment_variable_name'." |
248 | 264 | ) |
249 | 265 | else: |
250 | 266 | # Custom case - user specified custom name, so remind them what they specified |
251 | 267 | raise DataException( |
252 | | - f"Missing required variable '{self.post_treatment_variable_name}' in formula or dataset.\n\n" |
| 268 | + f"Missing required column '{self.post_treatment_variable_name}' in dataset.\n\n" |
253 | 269 | f"Since you specified post_treatment_variable_name='{self.post_treatment_variable_name}', " |
254 | | - f"please ensure:\n" |
255 | | - f"1) formula includes '{self.post_treatment_variable_name}'\n" |
256 | | - f"2) dataset has boolean column named '{self.post_treatment_variable_name}'" |
| 270 | + f"please ensure dataset has boolean column named '{self.post_treatment_variable_name}'" |
257 | 271 | ) |
258 | 272 |
|
259 | 273 | if "unit" not in self.data.columns: |
|
0 commit comments