Skip to content

Commit b60cdee

Browse files
committed
swap assert for a custom BadIndexException defined in a custom_exceptions module
1 parent 1ba64fb commit b60cdee

File tree

4 files changed

+29
-17
lines changed

4 files changed

+29
-17
lines changed

causalpy/custom_exceptions.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
class BadIndexException(Exception):
2+
"""Custom exception used when we have a mismatch in types between the dataframe
3+
index and an event, typically a treatment or intervention."""
4+
5+
def __init__(self, message):
6+
self.message = message

causalpy/pymc_experiments.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import xarray as xr
99
from patsy import build_design_matrices, dmatrices
1010

11+
from causalpy.custom_exceptions import BadIndexException
1112
from causalpy.plot_utils import plot_xY
1213

1314
LEGEND_FONT_SIZE = 12
@@ -117,14 +118,18 @@ def __init__(
117118

118119
def _input_validation(self, data, treatment_time):
119120
"""Validate the input data for correctness"""
120-
if isinstance(data.index, pd.DatetimeIndex):
121-
assert isinstance(
122-
treatment_time, pd.Timestamp
123-
), "If data.index is DatetimeIndex, treatment_time must be pd.Timestamp."
124-
else:
125-
assert (
126-
isinstance(treatment_time, pd.Timestamp) is False
127-
), "If treatment_time is pd.Timestamp, this only makese sense if data.index is DatetimeIndex." # noqa: E501
121+
if isinstance(data.index, pd.DatetimeIndex) and not isinstance(
122+
treatment_time, pd.Timestamp
123+
):
124+
raise BadIndexException(
125+
"If data.index is DatetimeIndex, treatment_time must be pd.Timestamp."
126+
)
127+
if not isinstance(data.index, pd.DatetimeIndex) and isinstance(
128+
treatment_time, pd.Timestamp
129+
):
130+
raise BadIndexException(
131+
"If data.index is not DatetimeIndex, treatment_time must be pd.Timestamp." # noqa: E501
132+
)
128133

129134
def plot(self):
130135

causalpy/tests/test_integration_pymc_examples.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import pytest
33

44
import causalpy as cp
5+
from causalpy.custom_exceptions import BadIndexException
56

67
sample_kwargs = {"tune": 20, "draws": 20, "chains": 2, "cores": 2}
78

@@ -193,9 +194,9 @@ def test_sc():
193194

194195
@pytest.mark.integration
195196
def test_sc_input_error():
196-
"""Test that an error is raised if the data index is not datetime and the
197-
treatment time is pd.Timestamp."""
198-
with pytest.raises(AssertionError):
197+
"""Confirm that a BadIndexException is raised treatment_time is pd.Timestamp
198+
and df.index is not pd.DatetimeIndex."""
199+
with pytest.raises(BadIndexException):
199200
df = cp.load_data("sc")
200201
treatment_time = pd.to_datetime("2016 June 24")
201202
_ = cp.pymc_experiments.SyntheticControl(
@@ -234,9 +235,9 @@ def test_sc_brexit():
234235

235236
@pytest.mark.integration
236237
def test_sc_brexit_input_error():
237-
"""Test that an error is raised if the data index is datetime and the treatment time
238-
is not pd.Timestamp."""
239-
with pytest.raises(AssertionError):
238+
"""Confirm a BadIndexException is raised if the data index is datetime and the
239+
treatment time is not pd.Timestamp."""
240+
with pytest.raises(BadIndexException):
240241
df = cp.load_data("brexit")
241242
df["Time"] = pd.to_datetime(df["Time"])
242243
df.set_index("Time", inplace=True)

img/interrogate_badge.svg

Lines changed: 3 additions & 3 deletions
Loading

0 commit comments

Comments
 (0)