Skip to content

Commit cd759e3

Browse files
committed
update the integration tests
1 parent d656e28 commit cd759e3

File tree

2 files changed

+27
-16
lines changed

2 files changed

+27
-16
lines changed

causalpy/tests/test_integration_pymc_examples.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -141,9 +141,11 @@ def test_rd_drinking():
141141

142142
@pytest.mark.integration
143143
def test_its():
144-
df = cp.load_data("its")
145-
df["date"] = pd.to_datetime(df["date"])
146-
df.set_index("date", inplace=True)
144+
df = (
145+
cp.load_data("its")
146+
.assign(date=lambda x: pd.to_datetime(x["date"]))
147+
.set_index("date")
148+
)
147149
treatment_time = pd.to_datetime("2017-01-01")
148150
result = cp.pymc_experiments.SyntheticControl(
149151
df,
@@ -159,8 +161,11 @@ def test_its():
159161

160162
@pytest.mark.integration
161163
def test_its_covid():
162-
df = cp.load_data("covid")
163-
df["date"] = pd.to_datetime(df["date"])
164+
df = (
165+
cp.load_data("covid")
166+
.assign(date=lambda x: pd.to_datetime(x["date"]))
167+
.set_index("date")
168+
)
164169
df = df.set_index("date")
165170
treatment_time = pd.to_datetime("2020-01-01")
166171
result = cp.pymc_experiments.SyntheticControl(
@@ -193,12 +198,14 @@ def test_sc():
193198

194199
@pytest.mark.integration
195200
def test_sc_brexit():
196-
df = cp.load_data("brexit")
197-
df["Time"] = pd.to_datetime(df["Time"])
198-
df.set_index("Time", inplace=True)
199-
df = df.iloc[df.index > "2009", :]
201+
df = (
202+
cp.load_data("brexit")
203+
.assign(Time=lambda x: pd.to_datetime(x["Time"]))
204+
.set_index("Time")
205+
.loc[lambda x: x.index >= "2009-01-01"]
206+
.drop(["Japan", "Italy", "US", "Spain"], axis=1)
207+
)
200208
treatment_time = pd.to_datetime("2016 June 24")
201-
df = df.drop(["Japan", "Italy", "US", "Spain"], axis=1)
202209
target_country = "UK"
203210
all_countries = df.columns
204211
other_countries = all_countries.difference({target_country})
@@ -235,9 +242,11 @@ def test_ancova():
235242

236243
@pytest.mark.integration
237244
def test_geolift1():
238-
df = cp.load_data("geolift1")
239-
df["time"] = pd.to_datetime(df["time"])
240-
df.set_index("time", inplace=True)
245+
df = (
246+
cp.load_data("geolift1")
247+
.assign(time=lambda x: pd.to_datetime(x["time"]))
248+
.set_index("time")
249+
)
241250
treatment_time = pd.to_datetime("2022-01-01")
242251
result = cp.pymc_experiments.SyntheticControl(
243252
df,

causalpy/tests/test_integration_skl_examples.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,11 @@ def test_rd_drinking():
4343

4444
@pytest.mark.integration
4545
def test_its():
46-
df = cp.load_data("its")
47-
df["date"] = pd.to_datetime(df["date"])
48-
df.set_index("date", inplace=True)
46+
df = (
47+
cp.load_data("its")
48+
.assign(date=lambda x: pd.to_datetime(x["date"]))
49+
.set_index("date")
50+
)
4951
treatment_time = pd.to_datetime("2017-01-01")
5052
result = cp.skl_experiments.SyntheticControl(
5153
df,

0 commit comments

Comments
 (0)