Skip to content

Commit 164368d

Browse files
committed
sample_kwargs: speed up pymc integration tests + add related asserts
1 parent 7c426c9 commit 164368d

File tree

1 file changed

+74
-8
lines changed

1 file changed

+74
-8
lines changed

causalpy/tests/test_integration_pymc_examples.py

Lines changed: 74 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
import causalpy as cp
55

6+
sample_kwargs = {"tune": 20, "draws": 20, "chains": 2, "cores": 2}
7+
68

79
@pytest.mark.integration
810
def test_did():
@@ -14,10 +16,18 @@ def test_did():
1416
group_variable_name="group",
1517
treated=1,
1618
untreated=0,
17-
prediction_model=cp.pymc_models.LinearRegression(),
19+
prediction_model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
1820
)
1921
assert isinstance(df, pd.DataFrame)
2022
assert isinstance(result, cp.pymc_experiments.DifferenceInDifferences)
23+
assert (
24+
len(result.prediction_model.idata.posterior.coords["chain"])
25+
== sample_kwargs["chains"]
26+
)
27+
assert (
28+
len(result.prediction_model.idata.posterior.coords["draw"])
29+
== sample_kwargs["draws"]
30+
)
2131

2232

2333
@pytest.mark.integration
@@ -47,10 +57,18 @@ def test_did_banks():
4757
group_variable_name="district",
4858
treated="Sixth District",
4959
untreated="Eighth District",
50-
prediction_model=cp.pymc_models.LinearRegression(),
60+
prediction_model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
5161
)
5262
assert isinstance(df, pd.DataFrame)
5363
assert isinstance(result, cp.pymc_experiments.DifferenceInDifferences)
64+
assert (
65+
len(result.prediction_model.idata.posterior.coords["chain"])
66+
== sample_kwargs["chains"]
67+
)
68+
assert (
69+
len(result.prediction_model.idata.posterior.coords["draw"])
70+
== sample_kwargs["draws"]
71+
)
5472

5573

5674
@pytest.mark.integration
@@ -59,11 +77,19 @@ def test_rd():
5977
result = cp.pymc_experiments.RegressionDiscontinuity(
6078
df,
6179
formula="y ~ 1 + bs(x, df=6) + treated",
62-
prediction_model=cp.pymc_models.LinearRegression(),
80+
prediction_model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
6381
treatment_threshold=0.5,
6482
)
6583
assert isinstance(df, pd.DataFrame)
6684
assert isinstance(result, cp.pymc_experiments.RegressionDiscontinuity)
85+
assert (
86+
len(result.prediction_model.idata.posterior.coords["chain"])
87+
== sample_kwargs["chains"]
88+
)
89+
assert (
90+
len(result.prediction_model.idata.posterior.coords["draw"])
91+
== sample_kwargs["draws"]
92+
)
6793

6894

6995
@pytest.mark.integration
@@ -77,11 +103,19 @@ def test_rd_drinking():
77103
df,
78104
formula="all ~ 1 + age + treated",
79105
running_variable_name="age",
80-
prediction_model=cp.pymc_models.LinearRegression(),
106+
prediction_model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
81107
treatment_threshold=21,
82108
)
83109
assert isinstance(df, pd.DataFrame)
84110
assert isinstance(result, cp.pymc_experiments.RegressionDiscontinuity)
111+
assert (
112+
len(result.prediction_model.idata.posterior.coords["chain"])
113+
== sample_kwargs["chains"]
114+
)
115+
assert (
116+
len(result.prediction_model.idata.posterior.coords["draw"])
117+
== sample_kwargs["draws"]
118+
)
85119

86120

87121
@pytest.mark.integration
@@ -94,10 +128,18 @@ def test_its():
94128
df,
95129
treatment_time,
96130
formula="y ~ 1 + t + C(month)",
97-
prediction_model=cp.pymc_models.LinearRegression(),
131+
prediction_model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
98132
)
99133
assert isinstance(df, pd.DataFrame)
100134
assert isinstance(result, cp.pymc_experiments.SyntheticControl)
135+
assert (
136+
len(result.prediction_model.idata.posterior.coords["chain"])
137+
== sample_kwargs["chains"]
138+
)
139+
assert (
140+
len(result.prediction_model.idata.posterior.coords["draw"])
141+
== sample_kwargs["draws"]
142+
)
101143

102144

103145
@pytest.mark.integration
@@ -110,10 +152,18 @@ def test_its_covid():
110152
df,
111153
treatment_time,
112154
formula="standardize(deaths) ~ 0 + standardize(t) + C(month) + standardize(temp)", # noqa E501
113-
prediction_model=cp.pymc_models.LinearRegression(),
155+
prediction_model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
114156
)
115157
assert isinstance(df, pd.DataFrame)
116158
assert isinstance(result, cp.pymc_experiments.SyntheticControl)
159+
assert (
160+
len(result.prediction_model.idata.posterior.coords["chain"])
161+
== sample_kwargs["chains"]
162+
)
163+
assert (
164+
len(result.prediction_model.idata.posterior.coords["draw"])
165+
== sample_kwargs["draws"]
166+
)
117167

118168

119169
@pytest.mark.integration
@@ -124,10 +174,18 @@ def test_sc():
124174
df,
125175
treatment_time,
126176
formula="actual ~ 0 + a + b + c + d + e + f + g",
127-
prediction_model=cp.pymc_models.WeightedSumFitter(),
177+
prediction_model=cp.pymc_models.WeightedSumFitter(sample_kwargs=sample_kwargs),
128178
)
129179
assert isinstance(df, pd.DataFrame)
130180
assert isinstance(result, cp.pymc_experiments.SyntheticControl)
181+
assert (
182+
len(result.prediction_model.idata.posterior.coords["chain"])
183+
== sample_kwargs["chains"]
184+
)
185+
assert (
186+
len(result.prediction_model.idata.posterior.coords["draw"])
187+
== sample_kwargs["draws"]
188+
)
131189

132190

133191
@pytest.mark.integration
@@ -148,7 +206,15 @@ def test_sc_brexit():
148206
df,
149207
treatment_time,
150208
formula=formula,
151-
prediction_model=cp.pymc_models.WeightedSumFitter(),
209+
prediction_model=cp.pymc_models.WeightedSumFitter(sample_kwargs=sample_kwargs),
152210
)
153211
assert isinstance(df, pd.DataFrame)
154212
assert isinstance(result, cp.pymc_experiments.SyntheticControl)
213+
assert (
214+
len(result.prediction_model.idata.posterior.coords["chain"])
215+
== sample_kwargs["chains"]
216+
)
217+
assert (
218+
len(result.prediction_model.idata.posterior.coords["draw"])
219+
== sample_kwargs["draws"]
220+
)

0 commit comments

Comments
 (0)