|
8 | 8 | import xarray as xr
|
9 | 9 | from patsy import build_design_matrices, dmatrices
|
10 | 10 |
|
11 |
| -from causalpy.custom_exceptions import BadIndexException |
| 11 | +from causalpy.custom_exceptions import ( |
| 12 | + BadIndexException, |
| 13 | + DataException, |
| 14 | + FormulaException, |
| 15 | +) |
12 | 16 | from causalpy.plot_utils import plot_xY
|
| 17 | +from causalpy.utils import _is_variable_dummy_coded, _series_has_2_levels |
13 | 18 |
|
14 | 19 | LEGEND_FONT_SIZE = 12
|
15 | 20 | az.style.use("arviz-darkgrid")
|
@@ -117,7 +122,7 @@ def __init__(
|
117 | 122 | self.post_impact_cumulative = self.post_impact.cumsum(dim="obs_ind")
|
118 | 123 |
|
119 | 124 | def _input_validation(self, data, treatment_time):
|
120 |
| - """Validate the input data for correctness""" |
| 125 | + """Validate the input data and model formula for correctness""" |
121 | 126 | if isinstance(data.index, pd.DatetimeIndex) and not isinstance(
|
122 | 127 | treatment_time, pd.Timestamp
|
123 | 128 | ):
|
@@ -361,28 +366,27 @@ def __init__(
|
361 | 366 | self.causal_impact = self.idata.posterior["beta"].isel({"coeffs": i})
|
362 | 367 |
|
363 | 368 | def _input_validation(self):
|
364 |
| - """Validate the input data for correctness""" |
365 |
| - assert ( |
366 |
| - "post_treatment" in self.formula |
367 |
| - ), "A predictor called `post_treatment` should be in the dataframe" |
368 |
| - assert ( |
369 |
| - "post_treatment" in self.data.columns |
370 |
| - ), "Require a boolean column labelling observations which are `treated`" |
371 |
| - # Check for `unit` in the incoming dataframe. |
372 |
| - # *This is only used for plotting purposes* |
373 |
| - assert ( |
374 |
| - "unit" in self.data.columns |
375 |
| - ), """ |
376 |
| - Require a `unit` column to label unique units. |
377 |
| - This is used for plotting purposes |
378 |
| - """ |
379 |
| - # Check that `group_variable_name` is dummy coded. It should be 0 or 1 |
380 |
| - assert not set(self.data[self.group_variable_name]).difference( |
381 |
| - set([0, 1]) |
382 |
| - ), f""" |
383 |
| - The grouping variable {self.group_variable_name} should be dummy coded. |
384 |
| - Consisting of 0's and 1's only. |
385 |
| - """ |
| 369 | + """Validate the input data and model formula for correctness""" |
| 370 | + if "post_treatment" not in self.formula: |
| 371 | + raise FormulaException( |
| 372 | + "A predictor called `post_treatment` should be in the formula" |
| 373 | + ) |
| 374 | + |
| 375 | + if "post_treatment" not in self.data.columns: |
| 376 | + raise DataException( |
| 377 | + "Require a boolean column labelling observations which are `treated`" |
| 378 | + ) |
| 379 | + |
| 380 | + if "unit" not in self.data.columns: |
| 381 | + raise DataException( |
| 382 | + "Require a `unit` column to label unique units. This is used for plotting purposes" # noqa: E501 |
| 383 | + ) |
| 384 | + |
| 385 | + if _is_variable_dummy_coded(self.data[self.group_variable_name]) is False: |
| 386 | + raise DataException( |
| 387 | + f"""The grouping variable {self.group_variable_name} should be dummy |
| 388 | + coded. Consisting of 0's and 1's only.""" |
| 389 | + ) |
386 | 390 |
|
387 | 391 | def plot(self):
|
388 | 392 | """Plot the results.
|
@@ -744,16 +748,17 @@ def __init__(
|
744 | 748 | # ================================================================
|
745 | 749 |
|
746 | 750 | def _input_validation(self):
|
747 |
| - """Validate the input data for correctness""" |
| 751 | + """Validate the input data and model formula for correctness""" |
748 | 752 | # Check that `group_variable_name` has TWO levels, representing the
|
749 | 753 | # treated/untreated. But it does not matter what the actual names of
|
750 | 754 | # the levels are.
|
751 |
| - assert ( |
752 |
| - len(pd.Categorical(self.data[self.group_variable_name]).categories) == 2 |
753 |
| - ), f""" |
754 |
| - There must be 2 levels of the grouping variable {self.group_variable_name} |
755 |
| - .I.e. the treated and untreated. |
756 |
| - """ |
| 755 | + if not _series_has_2_levels(self.data[self.group_variable_name]): |
| 756 | + raise ValueError( |
| 757 | + f""" |
| 758 | + There must be 2 levels of the grouping variable |
| 759 | + {self.group_variable_name}. I.e. the treated and untreated. |
| 760 | + """ |
| 761 | + ) |
757 | 762 |
|
758 | 763 | def plot(self):
|
759 | 764 | """Plot the results"""
|
|
0 commit comments