Skip to content

Commit 64df14a

Browse files
committed
add test to cover the opposite situation
1 parent 0247dae commit 64df14a

File tree

3 files changed

+29
-9
lines changed

3 files changed

+29
-9
lines changed

causalpy/pymc_experiments.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,17 @@ def __init__(
6060
**kwargs,
6161
) -> None:
6262
super().__init__(model=model, **kwargs)
63+
64+
# Input validation
65+
if isinstance(data.index, pd.DatetimeIndex):
66+
assert isinstance(
67+
treatment_time, pd.Timestamp
68+
), "If data.index is DatetimeIndex, treatment_time must be pd.Timestamp."
69+
else:
70+
assert (
71+
isinstance(treatment_time, pd.Timestamp) is False
72+
), "If treatment_time is pd.Timestamp, this only makese sense if data.index is DatetimeIndex." # noqa: E501
73+
6374
self.treatment_time = treatment_time
6475
# split data in to pre and post intervention
6576
self.datapre = data[data.index <= self.treatment_time]
@@ -81,12 +92,6 @@ def __init__(
8192
self.post_X = np.asarray(new_x)
8293
self.post_y = np.asarray(new_y)
8394

84-
# Input validation
85-
if isinstance(data.index, pd.DatetimeIndex):
86-
assert isinstance(
87-
treatment_time, pd.Timestamp
88-
), "If data.index is DatetimeIndex, treatment_time must be pd.Timestamp."
89-
9095
# DEVIATION FROM SKL EXPERIMENT CODE =============================
9196
# fit the model to the observed (pre-intervention) data
9297
COORDS = {"coeffs": self.labels, "obs_indx": np.arange(self.pre_X.shape[0])}

causalpy/tests/test_integration_pymc_examples.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,21 @@ def test_sc():
191191
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
192192

193193

194+
@pytest.mark.integration
195+
def test_sc_input_error():
196+
"""Test that an error is raised if the data index is not datetime and the
197+
treatment time is pd.Timestamp."""
198+
with pytest.raises(AssertionError):
199+
df = cp.load_data("sc")
200+
treatment_time = pd.to_datetime("2016 June 24")
201+
_ = cp.pymc_experiments.SyntheticControl(
202+
df,
203+
treatment_time,
204+
formula="actual ~ 0 + a + b + c + d + e + f + g",
205+
model=cp.pymc_models.WeightedSumFitter(sample_kwargs=sample_kwargs),
206+
)
207+
208+
194209
@pytest.mark.integration
195210
def test_sc_brexit():
196211
df = cp.load_data("brexit")

img/interrogate_badge.svg

Lines changed: 3 additions & 3 deletions
Loading

0 commit comments

Comments
 (0)