|
24 | 24 | from patsy import build_design_matrices, dmatrices |
25 | 25 | from sklearn.base import RegressorMixin |
26 | 26 |
|
27 | | -from causalpy.custom_exceptions import BadIndexException |
| 27 | +from causalpy.custom_exceptions import BadIndexException, ModelException |
28 | 28 | from causalpy.experiments.base import BaseExperiment |
29 | 29 | from causalpy.plot_utils import get_hdi_to_df, plot_xY |
30 | 30 | from causalpy.pymc_models import PyMCModel |
@@ -83,9 +83,7 @@ def __init__( |
83 | 83 | **kwargs, |
84 | 84 | ) -> None: |
85 | 85 | super().__init__(model=model) |
86 | | - # input validation TODO : for the moment only valid for given treatment time |
87 | | - if treatment_time is not None or not isinstance(treatment_time, tuple): |
88 | | - self.input_validation(data, treatment_time) |
| 86 | + self.input_validation(data, treatment_time, model) |
89 | 87 |
|
90 | 88 | self.treatment_time = treatment_time |
91 | 89 | # set experiment type - usually done in subclasses |
@@ -155,15 +153,23 @@ def __init__( |
155 | 153 | self.post_impact |
156 | 154 | ) |
157 | 155 |
|
158 | | - def input_validation(self, data, treatment_time): |
| 156 | + def input_validation(self, data, treatment_time, model): |
159 | 157 | """Validate the input data and model formula for correctness""" |
160 | | - if isinstance(data.index, pd.DatetimeIndex) and not isinstance( |
| 158 | + if treatment_time is None and not hasattr(model, "set_time_range"): |
| 159 | + raise ModelException( |
| 160 | + "If treatment_time is None, provided model must have a 'set_time_range' method" |
| 161 | + ) |
| 162 | + elif isinstance(treatment_time, tuple) and not hasattr(model, "set_time_range"): |
| 163 | + raise ModelException( |
| 164 | + "If treatment_time is a tuple, provided model must have a 'set_time_range' method" |
| 165 | + ) |
| 166 | + elif isinstance(data.index, pd.DatetimeIndex) and not isinstance( |
161 | 167 | treatment_time, pd.Timestamp |
162 | 168 | ): |
163 | 169 | raise BadIndexException( |
164 | 170 | "If data.index is DatetimeIndex, treatment_time must be pd.Timestamp." |
165 | 171 | ) |
166 | | - if not isinstance(data.index, pd.DatetimeIndex) and isinstance( |
| 172 | + elif not isinstance(data.index, pd.DatetimeIndex) and isinstance( |
167 | 173 | treatment_time, pd.Timestamp |
168 | 174 | ): |
169 | 175 | raise BadIndexException( |
|
0 commit comments