Skip to content

Commit bcf38f9

Browse files
committed
#76 add full banks example to the integration tests
1 parent 0ab2fcf commit bcf38f9

File tree

2 files changed

+43
-4
lines changed

2 files changed

+43
-4
lines changed

causalpy/tests/test_integration_pymc_examples.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,11 @@ def test_did():
2424
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
2525

2626

27+
# TODO: set up fixture for the banks dataset
28+
29+
2730
@pytest.mark.integration
28-
def test_did_banks():
31+
def test_did_banks_simple():
2932
treatment_time = 1930.5
3033
df = (
3134
cp.load_data("banks")
@@ -60,6 +63,42 @@ def test_did_banks():
6063
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
6164

6265

66+
@pytest.mark.integration
67+
def test_did_banks_multi():
68+
treatment_time = 1930.5
69+
df = (
70+
cp.load_data("banks")
71+
.filter(items=["bib6", "bib8", "year"])
72+
.rename(columns={"bib6": "Sixth District", "bib8": "Eighth District"})
73+
.groupby("year")
74+
.median()
75+
)
76+
df.reset_index(level=0, inplace=True)
77+
df_long = pd.melt(
78+
df,
79+
id_vars=["year"],
80+
value_vars=["Sixth District", "Eighth District"],
81+
var_name="district",
82+
value_name="bib",
83+
).sort_values("year")
84+
df_long["district"] = df_long["district"].astype("category")
85+
df_long["unit"] = df_long["district"]
86+
df_long["post_treatment"] = df_long.year >= treatment_time
87+
result = cp.pymc_experiments.DifferenceInDifferences(
88+
df_long,
89+
formula="bib ~ 1 + district + year + district:post_treatment",
90+
time_variable_name="year",
91+
group_variable_name="district",
92+
treated="Sixth District",
93+
untreated="Eighth District",
94+
prediction_model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
95+
)
96+
assert isinstance(df, pd.DataFrame)
97+
assert isinstance(result, cp.pymc_experiments.DifferenceInDifferences)
98+
assert len(result.idata.posterior.coords["chain"]) == sample_kwargs["chains"]
99+
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
100+
101+
63102
@pytest.mark.integration
64103
def test_rd():
65104
df = cp.load_data("rd")

img/interrogate_badge.svg

Lines changed: 3 additions & 3 deletions
Loading

0 commit comments

Comments
 (0)