@@ -353,6 +353,7 @@ def test_its():
353
353
2. causalpy.InterruptedTimeSeries returns correct type
354
354
3. the correct number of MCMC chains exists in the posterior inference data
355
355
4. the correct number of MCMC draws exists in the posterior inference data
356
+ 5. the method get_plot_data returns a DataFrame with expected columns
356
357
"""
357
358
df = (
358
359
cp .load_data ("its" )
@@ -378,9 +379,21 @@ def test_its():
378
379
isinstance (item , plt .Axes ) for item in ax
379
380
), "ax must be a numpy.ndarray of plt.Axes"
380
381
plot_data = result .get_plot_data ()
381
- assert isinstance (plot_data , pd .DataFrame ), "The returned object is not a pandas DataFrame"
382
- expected_columns = ['prediction' , 'pred_hdi_lower' , 'pred_hdi_upper' , 'impact' , 'impact_hdi_lower' , 'impact_hdi_upper' ]
383
- assert set (expected_columns ).issubset (set (plot_data .columns )), f"DataFrame is missing expected columns { expected_columns } "
382
+ assert isinstance (plot_data , pd .DataFrame ), (
383
+ "The returned object is not a pandas DataFrame"
384
+ )
385
+ expected_columns = [
386
+ "prediction" ,
387
+ "pred_hdi_lower_94" ,
388
+ "pred_hdi_upper_94" ,
389
+ "impact" ,
390
+ "impact_hdi_lower_94" ,
391
+ "impact_hdi_upper_94" ,
392
+ ]
393
+ assert set (expected_columns ).issubset (set (plot_data .columns )), (
394
+ f"DataFrame is missing expected columns { expected_columns } "
395
+ )
396
+
384
397
385
398
@pytest .mark .integration
386
399
def test_its_covid ():
@@ -392,6 +405,7 @@ def test_its_covid():
392
405
2. causalpy.InterruptedtimeSeries returns correct type
393
406
3. the correct number of MCMC chains exists in the posterior inference data
394
407
4. the correct number of MCMC draws exists in the posterior inference data
408
+ 5. the method get_plot_data returns a DataFrame with expected columns
395
409
"""
396
410
397
411
df = (
@@ -418,9 +432,20 @@ def test_its_covid():
418
432
isinstance (item , plt .Axes ) for item in ax
419
433
), "ax must be a numpy.ndarray of plt.Axes"
420
434
plot_data = result .get_plot_data ()
421
- assert isinstance (plot_data , pd .DataFrame ), "The returned object is not a pandas DataFrame"
422
- expected_columns = ['prediction' , 'pred_hdi_lower' , 'pred_hdi_upper' , 'impact' , 'impact_hdi_lower' , 'impact_hdi_upper' ]
423
- assert set (expected_columns ).issubset (set (plot_data .columns )), f"DataFrame is missing expected columns { expected_columns } "
435
+ assert isinstance (plot_data , pd .DataFrame ), (
436
+ "The returned object is not a pandas DataFrame"
437
+ )
438
+ expected_columns = [
439
+ "prediction" ,
440
+ "pred_hdi_lower_94" ,
441
+ "pred_hdi_upper_94" ,
442
+ "impact" ,
443
+ "impact_hdi_lower_94" ,
444
+ "impact_hdi_upper_94" ,
445
+ ]
446
+ assert set (expected_columns ).issubset (set (plot_data .columns )), (
447
+ f"DataFrame is missing expected columns { expected_columns } "
448
+ )
424
449
425
450
426
451
@pytest .mark .integration
@@ -433,6 +458,7 @@ def test_sc():
433
458
2. causalpy.SyntheticControl returns correct type
434
459
3. the correct number of MCMC chains exists in the posterior inference data
435
460
4. the correct number of MCMC draws exists in the posterior inference data
461
+ 5. the method get_plot_data returns a DataFrame with expected columns
436
462
"""
437
463
438
464
df = cp .load_data ("sc" )
@@ -463,9 +489,21 @@ def test_sc():
463
489
isinstance (item , plt .Axes ) for item in ax
464
490
), "ax must be a numpy.ndarray of plt.Axes"
465
491
plot_data = result .get_plot_data ()
466
- assert isinstance (plot_data , pd .DataFrame ), "The returned object is not a pandas DataFrame"
467
- expected_columns = ['prediction' , 'pred_hdi_lower' , 'pred_hdi_upper' , 'impact' , 'impact_hdi_lower' , 'impact_hdi_upper' ]
468
- assert set (expected_columns ).issubset (set (plot_data .columns )), f"DataFrame is missing expected columns { expected_columns } "
492
+ assert isinstance (plot_data , pd .DataFrame ), (
493
+ "The returned object is not a pandas DataFrame"
494
+ )
495
+ expected_columns = [
496
+ "prediction" ,
497
+ "pred_hdi_lower_94" ,
498
+ "pred_hdi_upper_94" ,
499
+ "impact" ,
500
+ "impact_hdi_lower_94" ,
501
+ "impact_hdi_upper_94" ,
502
+ ]
503
+ assert set (expected_columns ).issubset (set (plot_data .columns )), (
504
+ f"DataFrame is missing expected columns { expected_columns } "
505
+ )
506
+
469
507
470
508
@pytest .mark .integration
471
509
def test_sc_brexit ():
@@ -477,6 +515,7 @@ def test_sc_brexit():
477
515
2. causalpy.SyntheticControl returns correct type
478
516
3. the correct number of MCMC chains exists in the posterior inference data
479
517
4. the correct number of MCMC draws exists in the posterior inference data
518
+ 5. the method get_plot_data returns a DataFrame with expected columns
480
519
"""
481
520
482
521
df = (
@@ -512,9 +551,20 @@ def test_sc_brexit():
512
551
isinstance (item , plt .Axes ) for item in ax
513
552
), "ax must be a numpy.ndarray of plt.Axes"
514
553
plot_data = result .get_plot_data ()
515
- assert isinstance (plot_data , pd .DataFrame ), "The returned object is not a pandas DataFrame"
516
- expected_columns = ['prediction' , 'pred_hdi_lower' , 'pred_hdi_upper' , 'impact' , 'impact_hdi_lower' , 'impact_hdi_upper' ]
517
- assert set (expected_columns ).issubset (set (plot_data .columns )), f"DataFrame is missing expected columns { expected_columns } "
554
+ assert isinstance (plot_data , pd .DataFrame ), (
555
+ "The returned object is not a pandas DataFrame"
556
+ )
557
+ expected_columns = [
558
+ "prediction" ,
559
+ "pred_hdi_lower_94" ,
560
+ "pred_hdi_upper_94" ,
561
+ "impact" ,
562
+ "impact_hdi_lower_94" ,
563
+ "impact_hdi_upper_94" ,
564
+ ]
565
+ assert set (expected_columns ).issubset (set (plot_data .columns )), (
566
+ f"DataFrame is missing expected columns { expected_columns } "
567
+ )
518
568
519
569
520
570
@pytest .mark .integration
0 commit comments