@@ -349,7 +349,7 @@ def test_its():
349
349
350
350
Loads data and checks:
351
351
1. data is a dataframe
352
- 2. pymc_experiments.SyntheticControl returns correct type
352
+ 2. pymc_experiments.InterruptedTimeSeries returns correct type
353
353
3. the correct number of MCMC chains exists in the posterior inference data
354
354
4. the correct number of MCMC draws exists in the posterior inference data
355
355
"""
@@ -359,20 +359,23 @@ def test_its():
359
359
.set_index ("date" )
360
360
)
361
361
treatment_time = pd .to_datetime ("2017-01-01" )
362
- result = cp .SyntheticControl (
362
+ result = cp .InterruptedTimeSeries (
363
363
df ,
364
364
treatment_time ,
365
365
formula = "y ~ 1 + t + C(month)" ,
366
366
model = cp .pymc_models .LinearRegression (sample_kwargs = sample_kwargs ),
367
367
)
368
368
assert isinstance (df , pd .DataFrame )
369
- assert isinstance (result , cp .SyntheticControl )
369
+ assert isinstance (result , cp .InterruptedTimeSeries )
370
370
assert len (result .idata .posterior .coords ["chain" ]) == sample_kwargs ["chains" ]
371
371
assert len (result .idata .posterior .coords ["draw" ]) == sample_kwargs ["draws" ]
372
372
result .summary ()
373
373
fig , ax = result .plot ()
374
374
assert isinstance (fig , plt .Figure )
375
- assert isinstance (ax , plt .Axes )
375
+ # For multi-panel plots, ax should be an array of axes
376
+ assert isinstance (ax , np .ndarray ) and all (
377
+ isinstance (item , plt .Axes ) for item in ax
378
+ ), "ax must be a numpy.ndarray of plt.Axes"
376
379
377
380
378
381
@pytest .mark .integration
@@ -406,7 +409,10 @@ def test_its_covid():
406
409
result .summary ()
407
410
fig , ax = result .plot ()
408
411
assert isinstance (fig , plt .Figure )
409
- assert isinstance (ax , plt .Axes )
412
+ # For multi-panel plots, ax should be an array of axes
413
+ assert isinstance (ax , np .ndarray ) and all (
414
+ isinstance (item , plt .Axes ) for item in ax
415
+ ), "ax must be a numpy.ndarray of plt.Axes"
410
416
411
417
412
418
@pytest .mark .integration
@@ -436,7 +442,10 @@ def test_sc():
436
442
result .summary ()
437
443
fig , ax = result .plot ()
438
444
assert isinstance (fig , plt .Figure )
439
- assert isinstance (ax , plt .Axes )
445
+ # For multi-panel plots, ax should be an array of axes
446
+ assert isinstance (ax , np .ndarray ) and all (
447
+ isinstance (item , plt .Axes ) for item in ax
448
+ ), "ax must be a numpy.ndarray of plt.Axes"
440
449
441
450
442
451
@pytest .mark .integration
@@ -478,7 +487,10 @@ def test_sc_brexit():
478
487
result .summary ()
479
488
fig , ax = result .plot ()
480
489
assert isinstance (fig , plt .Figure )
481
- assert isinstance (ax , plt .Axes )
490
+ # For multi-panel plots, ax should be an array of axes
491
+ assert isinstance (ax , np .ndarray ) and all (
492
+ isinstance (item , plt .Axes ) for item in ax
493
+ ), "ax must be a numpy.ndarray of plt.Axes"
482
494
483
495
484
496
@pytest .mark .integration
0 commit comments