|
2 | 2 | import pytest
|
3 | 3 |
|
4 | 4 | import causalpy as cp
|
5 |
| -from causalpy.custom_exceptions import BadIndexException |
6 | 5 |
|
7 | 6 | sample_kwargs = {"tune": 20, "draws": 20, "chains": 2, "cores": 2}
|
8 | 7 |
|
@@ -196,21 +195,6 @@ def test_sc():
|
196 | 195 | assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
|
197 | 196 |
|
198 | 197 |
|
199 |
| -@pytest.mark.integration |
200 |
| -def test_sc_input_error(): |
201 |
| - """Confirm that a BadIndexException is raised treatment_time is pd.Timestamp |
202 |
| - and df.index is not pd.DatetimeIndex.""" |
203 |
| - with pytest.raises(BadIndexException): |
204 |
| - df = cp.load_data("sc") |
205 |
| - treatment_time = pd.to_datetime("2016 June 24") |
206 |
| - _ = cp.pymc_experiments.SyntheticControl( |
207 |
| - df, |
208 |
| - treatment_time, |
209 |
| - formula="actual ~ 0 + a + b + c + d + e + f + g", |
210 |
| - model=cp.pymc_models.WeightedSumFitter(sample_kwargs=sample_kwargs), |
211 |
| - ) |
212 |
| - |
213 |
| - |
214 | 198 | @pytest.mark.integration
|
215 | 199 | def test_sc_brexit():
|
216 | 200 | df = (
|
@@ -239,31 +223,6 @@ def test_sc_brexit():
|
239 | 223 | assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
|
240 | 224 |
|
241 | 225 |
|
242 |
| -@pytest.mark.integration |
243 |
| -def test_sc_brexit_input_error(): |
244 |
| - """Confirm a BadIndexException is raised if the data index is datetime and the |
245 |
| - treatment time is not pd.Timestamp.""" |
246 |
| - with pytest.raises(BadIndexException): |
247 |
| - df = cp.load_data("brexit") |
248 |
| - df["Time"] = pd.to_datetime(df["Time"]) |
249 |
| - df.set_index("Time", inplace=True) |
250 |
| - df = df.iloc[df.index > "2009", :] |
251 |
| - treatment_time = "2016 June 24" # NOTE This is not of type pd.Timestamp |
252 |
| - df = df.drop(["Japan", "Italy", "US", "Spain"], axis=1) |
253 |
| - target_country = "UK" |
254 |
| - all_countries = df.columns |
255 |
| - other_countries = all_countries.difference({target_country}) |
256 |
| - all_countries = list(all_countries) |
257 |
| - other_countries = list(other_countries) |
258 |
| - formula = target_country + " ~ " + "0 + " + " + ".join(other_countries) |
259 |
| - _ = cp.pymc_experiments.SyntheticControl( |
260 |
| - df, |
261 |
| - treatment_time, |
262 |
| - formula=formula, |
263 |
| - model=cp.pymc_models.WeightedSumFitter(sample_kwargs=sample_kwargs), |
264 |
| - ) |
265 |
| - |
266 |
| - |
267 | 226 | @pytest.mark.integration
|
268 | 227 | def test_ancova():
|
269 | 228 | df = cp.load_data("anova1")
|
|
0 commit comments