Skip to content

Commit d0eb8c8

Browse files
committed
add integration tests for the pymc examples in the docs
1 parent 37efbd8 commit d0eb8c8

File tree

1 file changed

+155
-0
lines changed

1 file changed

+155
-0
lines changed
Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
import pandas as pd
2+
import pytest
3+
4+
import causalpy as cp
5+
6+
7+
@pytest.mark.integration
def test_did():
    """Build a difference-in-differences experiment from the "did" dataset.

    Only object types are asserted: the loaded data must be a pandas
    DataFrame and the constructed object must be a
    ``DifferenceInDifferences`` instance. No numerical output is inspected.
    """
    experiment_cls = cp.pymc_experiments.DifferenceInDifferences
    data = cp.load_data("did")
    experiment = experiment_cls(
        data,
        formula="y ~ 1 + group + t + treated:group",
        time_variable_name="t",
        group_variable_name="group",
        treated=1,
        untreated=0,
        prediction_model=cp.pymc_models.LinearRegression(),
    )
    assert isinstance(data, pd.DataFrame)
    assert isinstance(experiment, experiment_cls)
21+
22+
23+
@pytest.mark.integration
def test_did_banks():
    """Build a difference-in-differences experiment from the "banks" dataset.

    The raw data is reduced to per-year medians of the two district columns,
    reshaped to long format, and restricted to the years 1930 and 1931 before
    being handed to ``DifferenceInDifferences``. Only object types are
    asserted.
    """
    experiment_cls = cp.pymc_experiments.DifferenceInDifferences

    # Per-year median of the two district series, with "year" as a column.
    banks = cp.load_data("banks")
    selected = banks.filter(items=["bib6", "bib8", "year"])
    renamed = selected.rename(
        columns={"bib6": "Sixth District", "bib8": "Eighth District"}
    )
    yearly_medians = renamed.groupby("year").median().reset_index(level=0)

    # Wide -> long: one row per (year, district) pair.
    melted = pd.melt(
        yearly_medians,
        id_vars=["year"],
        value_vars=["Sixth District", "Eighth District"],
        var_name="district",
        value_name="bib",
    )
    long_form = melted.sort_values("year")
    long_form["district"] = long_form["district"].astype("category")
    long_form["unit"] = long_form["district"]

    # A row counts as treated when it is the Sixth District from 1931 onwards.
    from_1931 = long_form["year"] >= 1931
    is_sixth = long_form["district"] == "Sixth District"
    long_form["treated"] = from_1931 & is_sixth

    two_years = long_form[long_form["year"].isin([1930, 1931])]
    experiment = experiment_cls(
        two_years,
        formula="bib ~ 1 + district + year + district:treated",
        time_variable_name="year",
        group_variable_name="district",
        treated="Sixth District",
        untreated="Eighth District",
        prediction_model=cp.pymc_models.LinearRegression(),
    )
    assert isinstance(yearly_medians, pd.DataFrame)
    assert isinstance(experiment, experiment_cls)
54+
55+
56+
@pytest.mark.integration
def test_rd():
    """Build a regression-discontinuity experiment from the "rd" dataset.

    Uses a spline term in the formula and a treatment threshold of 0.5.
    Only object types are asserted; no numerical output is inspected.
    """
    experiment_cls = cp.pymc_experiments.RegressionDiscontinuity
    data = cp.load_data("rd")
    experiment = experiment_cls(
        data,
        formula="y ~ 1 + bs(x, df=6) + treated",
        prediction_model=cp.pymc_models.LinearRegression(),
        treatment_threshold=0.5,
    )
    assert isinstance(data, pd.DataFrame)
    assert isinstance(experiment, experiment_cls)
67+
68+
69+
@pytest.mark.integration
def test_rd_drinking():
    """Build a regression-discontinuity experiment from the "drinking" dataset.

    The ``agecell`` column is renamed to ``age``, a boolean ``treated`` column
    marks rows with ``age > 21``, and rows containing missing values are
    dropped. Only object types are asserted.
    """
    experiment_cls = cp.pymc_experiments.RegressionDiscontinuity

    raw = cp.load_data("drinking")
    with_age = raw.rename(columns={"agecell": "age"})
    flagged = with_age.assign(treated=with_age["age"] > 21)
    data = flagged.dropna(axis=0)

    experiment = experiment_cls(
        data,
        formula="all ~ 1 + age + treated",
        running_variable_name="age",
        prediction_model=cp.pymc_models.LinearRegression(),
        treatment_threshold=21,
    )
    assert isinstance(data, pd.DataFrame)
    assert isinstance(experiment, experiment_cls)
86+
87+
88+
@pytest.mark.integration
def test_its():
    """Build a time-series experiment from the "its" dataset.

    The ``date`` column is parsed to datetimes and used as the index; the
    treatment time is 2017-01-01. Only object types are asserted.

    NOTE(review): the test name says ITS but the experiment class used is
    ``SyntheticControl`` -- confirm this is intended.
    """
    experiment_cls = cp.pymc_experiments.SyntheticControl

    series = cp.load_data("its")
    series["date"] = pd.to_datetime(series["date"])
    series.set_index("date", inplace=True)

    intervention = pd.to_datetime("2017-01-01")
    experiment = experiment_cls(
        series,
        intervention,
        formula="y ~ 1 + t + C(month)",
        prediction_model=cp.pymc_models.LinearRegression(),
    )
    assert isinstance(series, pd.DataFrame)
    assert isinstance(experiment, experiment_cls)
102+
103+
104+
@pytest.mark.integration
def test_its_covid():
    """Build a time-series experiment from the "covid" dataset.

    The ``date`` column is parsed to datetimes and used as the index; the
    treatment time is 2020-01-01. Only object types are asserted.

    NOTE(review): the test name says ITS but the experiment class used is
    ``SyntheticControl`` -- confirm this is intended.
    """
    experiment_cls = cp.pymc_experiments.SyntheticControl

    raw = cp.load_data("covid")
    raw["date"] = pd.to_datetime(raw["date"])
    indexed = raw.set_index("date")

    intervention = pd.to_datetime("2020-01-01")
    experiment = experiment_cls(
        indexed,
        intervention,
        formula="standardize(deaths) ~ 0 + standardize(t) + C(month) + standardize(temp)",  # noqa E501
        prediction_model=cp.pymc_models.LinearRegression(),
    )
    assert isinstance(indexed, pd.DataFrame)
    assert isinstance(experiment, experiment_cls)
118+
119+
120+
@pytest.mark.integration
def test_sc():
    """Build a synthetic-control experiment from the "sc" dataset.

    Uses a treatment time of 70 and a ``WeightedSumFitter`` prediction model.
    Only object types are asserted; no numerical output is inspected.
    """
    experiment_cls = cp.pymc_experiments.SyntheticControl
    data = cp.load_data("sc")
    intervention = 70
    experiment = experiment_cls(
        data,
        intervention,
        formula="actual ~ 0 + a + b + c + d + e + f + g",
        prediction_model=cp.pymc_models.WeightedSumFitter(),
    )
    assert isinstance(data, pd.DataFrame)
    assert isinstance(experiment, experiment_cls)
132+
133+
134+
@pytest.mark.integration
def test_sc_brexit():
    """Build a synthetic-control experiment from the "brexit" dataset.

    The ``Time`` column is parsed to datetimes and used as the index, rows
    whose index is not greater than "2009" are discarded, and the Japan,
    Italy, US and Spain columns are dropped. The formula regresses the UK
    column on every remaining column with no intercept. Only object types
    are asserted; no numerical output is inspected.
    """
    df = cp.load_data("brexit")
    df["Time"] = pd.to_datetime(df["Time"])
    df.set_index("Time", inplace=True)
    # The mask is derived from index labels, so select with .loc, not .iloc.
    df = df.loc[df.index > "2009", :]
    treatment_time = pd.to_datetime("2016 June 24")
    df = df.drop(["Japan", "Italy", "US", "Spain"], axis=1)

    target_country = "UK"
    # Every remaining column except the target acts as a predictor.
    other_countries = list(df.columns.difference([target_country]))
    formula = f"{target_country} ~ 0 + {' + '.join(other_countries)}"

    result = cp.pymc_experiments.SyntheticControl(
        df,
        treatment_time,
        formula=formula,
        prediction_model=cp.pymc_models.WeightedSumFitter(),
    )
    assert isinstance(df, pd.DataFrame)
    assert isinstance(result, cp.pymc_experiments.SyntheticControl)

0 commit comments

Comments
 (0)