14
14
import numpy as np
15
15
import pandas as pd
16
16
import pytest
17
+ from matplotlib import pyplot as plt
17
18
18
19
import causalpy as cp
19
20
@@ -44,6 +45,9 @@ def test_did():
44
45
assert len (result .idata .posterior .coords ["chain" ]) == sample_kwargs ["chains" ]
45
46
assert len (result .idata .posterior .coords ["draw" ]) == sample_kwargs ["draws" ]
46
47
result .summary ()
48
+ fig , ax = result .plot ()
49
+ assert isinstance (fig , plt .Figure )
50
+ assert isinstance (ax , plt .Axes )
47
51
48
52
49
53
# TODO: set up fixture for the banks dataset
@@ -100,6 +104,9 @@ def test_did_banks_simple():
100
104
assert len (result .idata .posterior .coords ["chain" ]) == sample_kwargs ["chains" ]
101
105
assert len (result .idata .posterior .coords ["draw" ]) == sample_kwargs ["draws" ]
102
106
result .summary ()
107
+ fig , ax = result .plot ()
108
+ assert isinstance (fig , plt .Figure )
109
+ assert isinstance (ax , plt .Axes )
103
110
104
111
105
112
@pytest .mark .integration
@@ -152,6 +159,9 @@ def test_did_banks_multi():
152
159
assert len (result .idata .posterior .coords ["chain" ]) == sample_kwargs ["chains" ]
153
160
assert len (result .idata .posterior .coords ["draw" ]) == sample_kwargs ["draws" ]
154
161
result .summary ()
162
+ fig , ax = result .plot ()
163
+ assert isinstance (fig , plt .Figure )
164
+ assert isinstance (ax , plt .Axes )
155
165
156
166
157
167
@pytest .mark .integration
@@ -178,6 +188,9 @@ def test_rd():
178
188
assert len (result .idata .posterior .coords ["chain" ]) == sample_kwargs ["chains" ]
179
189
assert len (result .idata .posterior .coords ["draw" ]) == sample_kwargs ["draws" ]
180
190
result .summary ()
191
+ fig , ax = result .plot ()
192
+ assert isinstance (fig , plt .Figure )
193
+ assert isinstance (ax , plt .Axes )
181
194
182
195
183
196
@pytest .mark .integration
@@ -205,6 +218,9 @@ def test_rd_bandwidth():
205
218
assert len (result .idata .posterior .coords ["chain" ]) == sample_kwargs ["chains" ]
206
219
assert len (result .idata .posterior .coords ["draw" ]) == sample_kwargs ["draws" ]
207
220
result .summary ()
221
+ fig , ax = result .plot ()
222
+ assert isinstance (fig , plt .Figure )
223
+ assert isinstance (ax , plt .Axes )
208
224
209
225
210
226
@pytest .mark .integration
@@ -235,6 +251,9 @@ def test_rd_drinking():
235
251
assert len (result .idata .posterior .coords ["chain" ]) == sample_kwargs ["chains" ]
236
252
assert len (result .idata .posterior .coords ["draw" ]) == sample_kwargs ["draws" ]
237
253
result .summary ()
254
+ fig , ax = result .plot ()
255
+ assert isinstance (fig , plt .Figure )
256
+ assert isinstance (ax , plt .Axes )
238
257
239
258
240
259
def setup_regression_kink_data (kink ):
@@ -288,6 +307,9 @@ def test_rkink():
288
307
assert len (result .idata .posterior .coords ["chain" ]) == sample_kwargs ["chains" ]
289
308
assert len (result .idata .posterior .coords ["draw" ]) == sample_kwargs ["draws" ]
290
309
result .summary ()
310
+ fig , ax = result .plot ()
311
+ assert isinstance (fig , plt .Figure )
312
+ assert isinstance (ax , plt .Axes )
291
313
292
314
293
315
@pytest .mark .integration
@@ -315,6 +337,9 @@ def test_rkink_bandwidth():
315
337
assert len (result .idata .posterior .coords ["chain" ]) == sample_kwargs ["chains" ]
316
338
assert len (result .idata .posterior .coords ["draw" ]) == sample_kwargs ["draws" ]
317
339
result .summary ()
340
+ fig , ax = result .plot ()
341
+ assert isinstance (fig , plt .Figure )
342
+ assert isinstance (ax , plt .Axes )
318
343
319
344
320
345
@pytest .mark .integration
@@ -345,6 +370,9 @@ def test_its():
345
370
assert len (result .idata .posterior .coords ["chain" ]) == sample_kwargs ["chains" ]
346
371
assert len (result .idata .posterior .coords ["draw" ]) == sample_kwargs ["draws" ]
347
372
result .summary ()
373
+ fig , ax = result .plot ()
374
+ assert isinstance (fig , plt .Figure )
375
+ assert isinstance (ax , plt .Axes )
348
376
349
377
350
378
@pytest .mark .integration
@@ -376,6 +404,9 @@ def test_its_covid():
376
404
assert len (result .idata .posterior .coords ["chain" ]) == sample_kwargs ["chains" ]
377
405
assert len (result .idata .posterior .coords ["draw" ]) == sample_kwargs ["draws" ]
378
406
result .summary ()
407
+ fig , ax = result .plot ()
408
+ assert isinstance (fig , plt .Figure )
409
+ assert isinstance (ax , plt .Axes )
379
410
380
411
381
412
@pytest .mark .integration
@@ -403,6 +434,9 @@ def test_sc():
403
434
assert len (result .idata .posterior .coords ["chain" ]) == sample_kwargs ["chains" ]
404
435
assert len (result .idata .posterior .coords ["draw" ]) == sample_kwargs ["draws" ]
405
436
result .summary ()
437
+ fig , ax = result .plot ()
438
+ assert isinstance (fig , plt .Figure )
439
+ assert isinstance (ax , plt .Axes )
406
440
407
441
408
442
@pytest .mark .integration
@@ -442,6 +476,9 @@ def test_sc_brexit():
442
476
assert len (result .idata .posterior .coords ["chain" ]) == sample_kwargs ["chains" ]
443
477
assert len (result .idata .posterior .coords ["draw" ]) == sample_kwargs ["draws" ]
444
478
result .summary ()
479
+ fig , ax = result .plot ()
480
+ assert isinstance (fig , plt .Figure )
481
+ assert isinstance (ax , plt .Axes )
445
482
446
483
447
484
@pytest .mark .integration
@@ -468,6 +505,9 @@ def test_ancova():
468
505
assert len (result .idata .posterior .coords ["chain" ]) == sample_kwargs ["chains" ]
469
506
assert len (result .idata .posterior .coords ["draw" ]) == sample_kwargs ["draws" ]
470
507
result .summary ()
508
+ fig , ax = result .plot ()
509
+ assert isinstance (fig , plt .Figure )
510
+ assert isinstance (ax , plt .Axes )
471
511
472
512
473
513
@pytest .mark .integration
@@ -499,6 +539,9 @@ def test_geolift1():
499
539
assert len (result .idata .posterior .coords ["chain" ]) == sample_kwargs ["chains" ]
500
540
assert len (result .idata .posterior .coords ["draw" ]) == sample_kwargs ["draws" ]
501
541
result .summary ()
542
+ fig , ax = result .plot ()
543
+ assert isinstance (fig , plt .Figure )
544
+ assert isinstance (ax , plt .Axes )
502
545
503
546
504
547
@pytest .mark .integration
0 commit comments