Skip to content

Commit b1681da

Browse files
committed
updating example notebook
1 parent 08c520c commit b1681da

File tree

2 files changed

+20
-7
lines changed

2 files changed

+20
-7
lines changed

causalpy/custom_exceptions.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,10 @@ class DataException(Exception):
3737

3838
def __init__(self, message: str):
3939
self.message = message
40+
41+
42+
class ModelException(Exception):
43+
"""Exception raised given when there is some error in user-provided model"""
44+
45+
def __init__(self, message: str):
46+
self.message = message

causalpy/experiments/interrupted_time_series.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from patsy import build_design_matrices, dmatrices
2525
from sklearn.base import RegressorMixin
2626

27-
from causalpy.custom_exceptions import BadIndexException
27+
from causalpy.custom_exceptions import BadIndexException, ModelException
2828
from causalpy.experiments.base import BaseExperiment
2929
from causalpy.plot_utils import get_hdi_to_df, plot_xY
3030
from causalpy.pymc_models import PyMCModel
@@ -83,9 +83,7 @@ def __init__(
8383
**kwargs,
8484
) -> None:
8585
super().__init__(model=model)
86-
# input validation TODO : for the moment only valid for given treatment time
87-
if treatment_time is not None or not isinstance(treatment_time, tuple):
88-
self.input_validation(data, treatment_time)
86+
self.input_validation(data, treatment_time, model)
8987

9088
self.treatment_time = treatment_time
9189
# set experiment type - usually done in subclasses
@@ -155,15 +153,23 @@ def __init__(
155153
self.post_impact
156154
)
157155

158-
def input_validation(self, data, treatment_time):
156+
def input_validation(self, data, treatment_time, model):
159157
"""Validate the input data and model formula for correctness"""
160-
if isinstance(data.index, pd.DatetimeIndex) and not isinstance(
158+
if treatment_time is None and not hasattr(model, "set_time_range"):
159+
raise ModelException(
160+
"If treatment_time is None, provided model must have a 'set_time_range' method"
161+
)
162+
elif isinstance(treatment_time, tuple) and not hasattr(model, "set_time_range"):
163+
raise ModelException(
164+
"If treatment_time is a tuple, provided model must have a 'set_time_range' method"
165+
)
166+
elif isinstance(data.index, pd.DatetimeIndex) and not isinstance(
161167
treatment_time, pd.Timestamp
162168
):
163169
raise BadIndexException(
164170
"If data.index is DatetimeIndex, treatment_time must be pd.Timestamp."
165171
)
166-
if not isinstance(data.index, pd.DatetimeIndex) and isinstance(
172+
elif not isinstance(data.index, pd.DatetimeIndex) and isinstance(
167173
treatment_time, pd.Timestamp
168174
):
169175
raise BadIndexException(

0 commit comments

Comments
 (0)