Skip to content

Commit c84edf4

Browse files
committed
Adding a test ton ensure an exception is raised when len(time_range) != 2
1 parent 76ef685 commit c84edf4

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

causalpy/experiments/change_point_detection.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@
9696
from matplotlib import pyplot as plt
9797
from patsy import dmatrices
9898

99-
from causalpy.custom_exceptions import BadIndexException, ModelException
99+
from causalpy.custom_exceptions import BadIndexException, DataException, ModelException
100100
from causalpy.experiments.base import BaseExperiment
101101
from causalpy.plot_utils import get_hdi_to_df, plot_xY
102102
from causalpy.pymc_models import PyMCModel
@@ -259,7 +259,7 @@ def input_validation(self, data, time_range, model):
259259
if not hasattr(model, "set_time_range"):
260260
raise ModelException("Provided model must have a 'set_time_range' method")
261261
if time_range is not None and len(time_range) != 2:
262-
raise BadIndexException(
262+
raise DataException(
263263
"Provided time_range must be of length 2 : (start, end)"
264264
)
265265
if isinstance(data.index, pd.DatetimeIndex) and not (

causalpy/tests/test_integration_pymc_examples.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,20 @@ def test_cp_covid():
433433
)
434434
assert "Provided model must have a 'set_time_range' method" in str(exc_info.value)
435435

436+
# Assert that we correctfully raise a DataException if
437+
# - time_range is not None
438+
# - and len(time_range) is not 2
439+
with pytest.raises(cp.custom_exceptions.DataExveption) as exc_info:
440+
cp.ChangePointDetection(
441+
df,
442+
time_range=[0, 0, 0],
443+
formula="standardize(deaths) ~ 0 + t + C(month) + standardize(temp)", # noqa E501
444+
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
445+
)
446+
assert "Provided time_range must be of length 2 : (start, end)" in str(
447+
exc_info.value
448+
)
449+
436450
result = cp.ChangePointDetection(
437451
df,
438452
time_range=time_range,

0 commit comments

Comments
 (0)