Skip to content

Commit bcb2512

Browse files
committed
move some tests to test_input_validation
1 parent ffc9b44 commit bcb2512

File tree

2 files changed

+46
-42
lines changed

2 files changed

+46
-42
lines changed

causalpy/tests/test_input_validation.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,11 @@
22
import pytest
33

44
import causalpy as cp
5-
from causalpy.custom_exceptions import DataException, FormulaException
5+
from causalpy.custom_exceptions import (
6+
BadIndexException,
7+
DataException,
8+
FormulaException,
9+
)
610

711
sample_kwargs = {"tune": 20, "draws": 20, "chains": 2, "cores": 2}
812

@@ -96,3 +100,44 @@ def test_did_validation_group_dummy_coded():
96100
group_variable_name="group",
97101
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
98102
)
103+
104+
105+
# Synthetic Control
106+
107+
108+
def test_sc_input_error():
109+
"""Confirm that a BadIndexException is raised treatment_time is pd.Timestamp
110+
and df.index is not pd.DatetimeIndex."""
111+
with pytest.raises(BadIndexException):
112+
df = cp.load_data("sc")
113+
treatment_time = pd.to_datetime("2016 June 24")
114+
_ = cp.pymc_experiments.SyntheticControl(
115+
df,
116+
treatment_time,
117+
formula="actual ~ 0 + a + b + c + d + e + f + g",
118+
model=cp.pymc_models.WeightedSumFitter(sample_kwargs=sample_kwargs),
119+
)
120+
121+
122+
def test_sc_brexit_input_error():
123+
"""Confirm a BadIndexException is raised if the data index is datetime and the
124+
treatment time is not pd.Timestamp."""
125+
with pytest.raises(BadIndexException):
126+
df = cp.load_data("brexit")
127+
df["Time"] = pd.to_datetime(df["Time"])
128+
df.set_index("Time", inplace=True)
129+
df = df.iloc[df.index > "2009", :]
130+
treatment_time = "2016 June 24" # NOTE This is not of type pd.Timestamp
131+
df = df.drop(["Japan", "Italy", "US", "Spain"], axis=1)
132+
target_country = "UK"
133+
all_countries = df.columns
134+
other_countries = all_countries.difference({target_country})
135+
all_countries = list(all_countries)
136+
other_countries = list(other_countries)
137+
formula = target_country + " ~ " + "0 + " + " + ".join(other_countries)
138+
_ = cp.pymc_experiments.SyntheticControl(
139+
df,
140+
treatment_time,
141+
formula=formula,
142+
model=cp.pymc_models.WeightedSumFitter(sample_kwargs=sample_kwargs),
143+
)

causalpy/tests/test_integration_pymc_examples.py

Lines changed: 0 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import pytest
33

44
import causalpy as cp
5-
from causalpy.custom_exceptions import BadIndexException
65

76
sample_kwargs = {"tune": 20, "draws": 20, "chains": 2, "cores": 2}
87

@@ -196,21 +195,6 @@ def test_sc():
196195
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
197196

198197

199-
@pytest.mark.integration
200-
def test_sc_input_error():
201-
"""Confirm that a BadIndexException is raised treatment_time is pd.Timestamp
202-
and df.index is not pd.DatetimeIndex."""
203-
with pytest.raises(BadIndexException):
204-
df = cp.load_data("sc")
205-
treatment_time = pd.to_datetime("2016 June 24")
206-
_ = cp.pymc_experiments.SyntheticControl(
207-
df,
208-
treatment_time,
209-
formula="actual ~ 0 + a + b + c + d + e + f + g",
210-
model=cp.pymc_models.WeightedSumFitter(sample_kwargs=sample_kwargs),
211-
)
212-
213-
214198
@pytest.mark.integration
215199
def test_sc_brexit():
216200
df = (
@@ -239,31 +223,6 @@ def test_sc_brexit():
239223
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
240224

241225

242-
@pytest.mark.integration
243-
def test_sc_brexit_input_error():
244-
"""Confirm a BadIndexException is raised if the data index is datetime and the
245-
treatment time is not pd.Timestamp."""
246-
with pytest.raises(BadIndexException):
247-
df = cp.load_data("brexit")
248-
df["Time"] = pd.to_datetime(df["Time"])
249-
df.set_index("Time", inplace=True)
250-
df = df.iloc[df.index > "2009", :]
251-
treatment_time = "2016 June 24" # NOTE This is not of type pd.Timestamp
252-
df = df.drop(["Japan", "Italy", "US", "Spain"], axis=1)
253-
target_country = "UK"
254-
all_countries = df.columns
255-
other_countries = all_countries.difference({target_country})
256-
all_countries = list(all_countries)
257-
other_countries = list(other_countries)
258-
formula = target_country + " ~ " + "0 + " + " + ".join(other_countries)
259-
_ = cp.pymc_experiments.SyntheticControl(
260-
df,
261-
treatment_time,
262-
formula=formula,
263-
model=cp.pymc_models.WeightedSumFitter(sample_kwargs=sample_kwargs),
264-
)
265-
266-
267226
@pytest.mark.integration
268227
def test_ancova():
269228
df = cp.load_data("anova1")

0 commit comments

Comments
 (0)