@@ -495,3 +495,124 @@ def test_saturation_curves_scatter_deprecation_warning(mock_suite_with_constant_
495
495
assert isinstance (fig , Figure )
496
496
assert isinstance (axes , np .ndarray )
497
497
assert all (isinstance (ax , Axes ) for ax in axes .flat )
498
+
499
+
500
+ @pytest .fixture (scope = "module" )
501
+ def mock_idata_with_constant_data_single_dim () -> az .InferenceData :
502
+ """Mock InferenceData where channel_data has only ('date','channel') dims."""
503
+ seed = sum (map (ord , "Saturation single-dim tests" ))
504
+ rng = np .random .default_rng (seed )
505
+ normal = rng .normal
506
+
507
+ dates = pd .date_range ("2025-01-01" , periods = 12 , freq = "W-MON" )
508
+ channels = ["channel_1" , "channel_2" , "channel_3" ]
509
+
510
+ posterior = xr .Dataset (
511
+ {
512
+ "channel_contribution" : xr .DataArray (
513
+ normal (size = (2 , 10 , 12 , 3 )),
514
+ dims = ("chain" , "draw" , "date" , "channel" ),
515
+ coords = {
516
+ "chain" : np .arange (2 ),
517
+ "draw" : np .arange (10 ),
518
+ "date" : dates ,
519
+ "channel" : channels ,
520
+ },
521
+ ),
522
+ "channel_contribution_original_scale" : xr .DataArray (
523
+ normal (size = (2 , 10 , 12 , 3 )) * 100.0 ,
524
+ dims = ("chain" , "draw" , "date" , "channel" ),
525
+ coords = {
526
+ "chain" : np .arange (2 ),
527
+ "draw" : np .arange (10 ),
528
+ "date" : dates ,
529
+ "channel" : channels ,
530
+ },
531
+ ),
532
+ }
533
+ )
534
+
535
+ constant_data = xr .Dataset (
536
+ {
537
+ "channel_data" : xr .DataArray (
538
+ rng .uniform (0 , 10 , size = (12 , 3 )),
539
+ dims = ("date" , "channel" ),
540
+ coords = {"date" : dates , "channel" : channels },
541
+ ),
542
+ "channel_scale" : xr .DataArray (
543
+ [100.0 , 150.0 , 200.0 ], dims = ("channel" ,), coords = {"channel" : channels }
544
+ ),
545
+ "target_scale" : xr .DataArray (
546
+ [1000.0 ], dims = "target" , coords = {"target" : ["y" ]}
547
+ ),
548
+ }
549
+ )
550
+
551
+ return az .InferenceData (posterior = posterior , constant_data = constant_data )
552
+
553
+
554
+ @pytest .fixture (scope = "module" )
555
+ def mock_suite_with_constant_data_single_dim (mock_idata_with_constant_data_single_dim ):
556
+ return MMMPlotSuite (idata = mock_idata_with_constant_data_single_dim )
557
+
558
+
559
+ @pytest .fixture (scope = "module" )
560
+ def mock_saturation_curve_single_dim () -> xr .DataArray :
561
+ """Saturation curve with dims ('chain','draw','channel','x')."""
562
+ seed = sum (map (ord , "Saturation curve single-dim" ))
563
+ rng = np .random .default_rng (seed )
564
+ x_values = np .linspace (0 , 1 , 50 )
565
+ channels = ["channel_1" , "channel_2" , "channel_3" ]
566
+
567
+ # shape: (chains=2, draws=10, channel=3, x=50)
568
+ curve_array = np .empty ((2 , 10 , len (channels ), len (x_values )))
569
+ for ci in range (2 ):
570
+ for di in range (10 ):
571
+ for c in range (len (channels )):
572
+ curve_array [ci , di , c , :] = x_values / (1 + x_values ) + rng .normal (
573
+ 0 , 0.02 , size = x_values .shape
574
+ )
575
+
576
+ return xr .DataArray (
577
+ curve_array ,
578
+ dims = ("chain" , "draw" , "channel" , "x" ),
579
+ coords = {
580
+ "chain" : np .arange (2 ),
581
+ "draw" : np .arange (10 ),
582
+ "channel" : channels ,
583
+ "x" : x_values ,
584
+ },
585
+ name = "saturation_curve" ,
586
+ )
587
+
588
+
589
+ def test_saturation_curves_single_dim_axes_shape (
590
+ mock_suite_with_constant_data_single_dim , mock_saturation_curve_single_dim
591
+ ):
592
+ """When there are no extra dims, columns should default to 1 (no ncols=0)."""
593
+ fig , axes = mock_suite_with_constant_data_single_dim .saturation_curves (
594
+ curve = mock_saturation_curve_single_dim , n_samples = 3
595
+ )
596
+
597
+ assert isinstance (fig , Figure )
598
+ assert isinstance (axes , np .ndarray )
599
+ # Expect (n_channels, 1)
600
+ assert axes .shape [1 ] == 1
601
+ assert axes .shape [0 ] == mock_saturation_curve_single_dim .sizes ["channel" ]
602
+
603
+
604
+ def test_saturation_curves_multi_dim_axes_shape (
605
+ mock_suite_with_constant_data , mock_saturation_curve
606
+ ):
607
+ """With an extra dim (e.g., 'country'), expect (n_channels, n_countries)."""
608
+ fig , axes = mock_suite_with_constant_data .saturation_curves (
609
+ curve = mock_saturation_curve , n_samples = 2
610
+ )
611
+
612
+ assert isinstance (fig , Figure )
613
+ assert isinstance (axes , np .ndarray )
614
+ n_channels = mock_saturation_curve .sizes ["channel" ]
615
+ n_countries = mock_suite_with_constant_data .idata .constant_data .channel_data .sizes [
616
+ "country"
617
+ ]
618
+ assert axes .shape == (n_channels , n_countries )
0 commit comments