Skip to content

Commit ae7c405

Browse files
committed
add test coverage for plot method to integration tests
1 parent 6c4e43c commit ae7c405

File tree

2 files changed

+68
-0
lines changed

2 files changed

+68
-0
lines changed

causalpy/tests/test_integration_pymc_examples.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import numpy as np
1515
import pandas as pd
1616
import pytest
17+
from matplotlib import pyplot as plt
1718

1819
import causalpy as cp
1920

@@ -44,6 +45,9 @@ def test_did():
4445
assert len(result.idata.posterior.coords["chain"]) == sample_kwargs["chains"]
4546
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
4647
result.summary()
48+
fig, ax = result.plot()
49+
assert isinstance(fig, plt.Figure)
50+
assert isinstance(ax, plt.Axes)
4751

4852

4953
# TODO: set up fixture for the banks dataset
@@ -100,6 +104,9 @@ def test_did_banks_simple():
100104
assert len(result.idata.posterior.coords["chain"]) == sample_kwargs["chains"]
101105
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
102106
result.summary()
107+
fig, ax = result.plot()
108+
assert isinstance(fig, plt.Figure)
109+
assert isinstance(ax, plt.Axes)
103110

104111

105112
@pytest.mark.integration
@@ -152,6 +159,9 @@ def test_did_banks_multi():
152159
assert len(result.idata.posterior.coords["chain"]) == sample_kwargs["chains"]
153160
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
154161
result.summary()
162+
fig, ax = result.plot()
163+
assert isinstance(fig, plt.Figure)
164+
assert isinstance(ax, plt.Axes)
155165

156166

157167
@pytest.mark.integration
@@ -178,6 +188,9 @@ def test_rd():
178188
assert len(result.idata.posterior.coords["chain"]) == sample_kwargs["chains"]
179189
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
180190
result.summary()
191+
fig, ax = result.plot()
192+
assert isinstance(fig, plt.Figure)
193+
assert isinstance(ax, plt.Axes)
181194

182195

183196
@pytest.mark.integration
@@ -205,6 +218,9 @@ def test_rd_bandwidth():
205218
assert len(result.idata.posterior.coords["chain"]) == sample_kwargs["chains"]
206219
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
207220
result.summary()
221+
fig, ax = result.plot()
222+
assert isinstance(fig, plt.Figure)
223+
assert isinstance(ax, plt.Axes)
208224

209225

210226
@pytest.mark.integration
@@ -235,6 +251,9 @@ def test_rd_drinking():
235251
assert len(result.idata.posterior.coords["chain"]) == sample_kwargs["chains"]
236252
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
237253
result.summary()
254+
fig, ax = result.plot()
255+
assert isinstance(fig, plt.Figure)
256+
assert isinstance(ax, plt.Axes)
238257

239258

240259
def setup_regression_kink_data(kink):
@@ -288,6 +307,9 @@ def test_rkink():
288307
assert len(result.idata.posterior.coords["chain"]) == sample_kwargs["chains"]
289308
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
290309
result.summary()
310+
fig, ax = result.plot()
311+
assert isinstance(fig, plt.Figure)
312+
assert isinstance(ax, plt.Axes)
291313

292314

293315
@pytest.mark.integration
@@ -315,6 +337,9 @@ def test_rkink_bandwidth():
315337
assert len(result.idata.posterior.coords["chain"]) == sample_kwargs["chains"]
316338
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
317339
result.summary()
340+
fig, ax = result.plot()
341+
assert isinstance(fig, plt.Figure)
342+
assert isinstance(ax, plt.Axes)
318343

319344

320345
@pytest.mark.integration
@@ -345,6 +370,9 @@ def test_its():
345370
assert len(result.idata.posterior.coords["chain"]) == sample_kwargs["chains"]
346371
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
347372
result.summary()
373+
fig, ax = result.plot()
374+
assert isinstance(fig, plt.Figure)
375+
assert isinstance(ax, plt.Axes)
348376

349377

350378
@pytest.mark.integration
@@ -376,6 +404,9 @@ def test_its_covid():
376404
assert len(result.idata.posterior.coords["chain"]) == sample_kwargs["chains"]
377405
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
378406
result.summary()
407+
fig, ax = result.plot()
408+
assert isinstance(fig, plt.Figure)
409+
assert isinstance(ax, plt.Axes)
379410

380411

381412
@pytest.mark.integration
@@ -403,6 +434,9 @@ def test_sc():
403434
assert len(result.idata.posterior.coords["chain"]) == sample_kwargs["chains"]
404435
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
405436
result.summary()
437+
fig, ax = result.plot()
438+
assert isinstance(fig, plt.Figure)
439+
assert isinstance(ax, plt.Axes)
406440

407441

408442
@pytest.mark.integration
@@ -442,6 +476,9 @@ def test_sc_brexit():
442476
assert len(result.idata.posterior.coords["chain"]) == sample_kwargs["chains"]
443477
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
444478
result.summary()
479+
fig, ax = result.plot()
480+
assert isinstance(fig, plt.Figure)
481+
assert isinstance(ax, plt.Axes)
445482

446483

447484
@pytest.mark.integration
@@ -468,6 +505,9 @@ def test_ancova():
468505
assert len(result.idata.posterior.coords["chain"]) == sample_kwargs["chains"]
469506
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
470507
result.summary()
508+
fig, ax = result.plot()
509+
assert isinstance(fig, plt.Figure)
510+
assert isinstance(ax, plt.Axes)
471511

472512

473513
@pytest.mark.integration
@@ -499,6 +539,9 @@ def test_geolift1():
499539
assert len(result.idata.posterior.coords["chain"]) == sample_kwargs["chains"]
500540
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
501541
result.summary()
542+
fig, ax = result.plot()
543+
assert isinstance(fig, plt.Figure)
544+
assert isinstance(ax, plt.Axes)
502545

503546

504547
@pytest.mark.integration

causalpy/tests/test_integration_skl_examples.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
import pandas as pd
1515
import pytest
16+
from matplotlib import pyplot as plt
1617
from sklearn.gaussian_process import GaussianProcessRegressor
1718
from sklearn.gaussian_process.kernels import ExpSineSquared, WhiteKernel
1819

@@ -43,6 +44,9 @@ def test_did():
4344
assert isinstance(data, pd.DataFrame)
4445
assert isinstance(result, cp.DifferenceInDifferences)
4546
result.summary()
47+
fig, ax = result.plot()
48+
assert isinstance(fig, plt.Figure)
49+
assert isinstance(ax, plt.Axes)
4650

4751

4852
@pytest.mark.integration
@@ -70,6 +74,9 @@ def test_rd_drinking():
7074
assert isinstance(df, pd.DataFrame)
7175
assert isinstance(result, cp.RegressionDiscontinuity)
7276
result.summary()
77+
fig, ax = result.plot()
78+
assert isinstance(fig, plt.Figure)
79+
assert isinstance(ax, plt.Axes)
7380

7481

7582
@pytest.mark.integration
@@ -97,6 +104,9 @@ def test_its():
97104
assert isinstance(df, pd.DataFrame)
98105
assert isinstance(result, cp.SyntheticControl)
99106
result.summary()
107+
fig, ax = result.plot()
108+
assert isinstance(fig, plt.Figure)
109+
assert isinstance(ax, plt.Axes)
100110

101111

102112
@pytest.mark.integration
@@ -119,6 +129,9 @@ def test_sc():
119129
assert isinstance(df, pd.DataFrame)
120130
assert isinstance(result, cp.SyntheticControl)
121131
result.summary()
132+
fig, ax = result.plot()
133+
assert isinstance(fig, plt.Figure)
134+
assert isinstance(ax, plt.Axes)
122135

123136

124137
@pytest.mark.integration
@@ -141,6 +154,9 @@ def test_rd_linear_main_effects():
141154
assert isinstance(data, pd.DataFrame)
142155
assert isinstance(result, cp.RegressionDiscontinuity)
143156
result.summary()
157+
fig, ax = result.plot()
158+
assert isinstance(fig, plt.Figure)
159+
assert isinstance(ax, plt.Axes)
144160

145161

146162
@pytest.mark.integration
@@ -165,6 +181,9 @@ def test_rd_linear_main_effects_bandwidth():
165181
assert isinstance(data, pd.DataFrame)
166182
assert isinstance(result, cp.RegressionDiscontinuity)
167183
result.summary()
184+
fig, ax = result.plot()
185+
assert isinstance(fig, plt.Figure)
186+
assert isinstance(ax, plt.Axes)
168187

169188

170189
@pytest.mark.integration
@@ -187,6 +206,9 @@ def test_rd_linear_with_interaction():
187206
assert isinstance(data, pd.DataFrame)
188207
assert isinstance(result, cp.RegressionDiscontinuity)
189208
result.summary()
209+
fig, ax = result.plot()
210+
assert isinstance(fig, plt.Figure)
211+
assert isinstance(ax, plt.Axes)
190212

191213

192214
@pytest.mark.integration
@@ -209,3 +231,6 @@ def test_rd_linear_with_gaussian_process():
209231
)
210232
assert isinstance(data, pd.DataFrame)
211233
assert isinstance(result, cp.RegressionDiscontinuity)
234+
fig, ax = result.plot()
235+
assert isinstance(fig, plt.Figure)
236+
assert isinstance(ax, plt.Axes)

0 commit comments

Comments
 (0)