@@ -507,6 +507,47 @@ def test_create_zero_dataset_error_cases(self):
507
507
):
508
508
create_zero_dataset (model , start_date , end_date , date_dim_xr )
509
509
510
+ def test_create_zero_dataset_channel_xr_includes_date_specific_error (self ):
511
+ """Ensure we hit the explicit date-dimension error when date is an allowed model dim."""
512
+
513
+ class FakeMMM_DateDim :
514
+ def __init__ (self ):
515
+ dates = pd .date_range ("2022-01-01" , "2022-01-10" , freq = "D" )
516
+ self .X = pd .DataFrame (
517
+ {
518
+ "date" : dates ,
519
+ "channel1" : np .random .rand (10 ) * 10 ,
520
+ "channel2" : np .random .rand (10 ) * 5 ,
521
+ }
522
+ )
523
+ self .date_column = "date"
524
+ self .channel_columns = ["channel1" , "channel2" ]
525
+ self .control_columns = []
526
+ # Include 'date' as a model dim so the invalid-dims check passes,
527
+ # and we can assert on the specific date-dimension error.
528
+ self .dims = ["date" ]
529
+
530
+ class FakeAdstock :
531
+ l_max = 1
532
+
533
+ self .adstock = FakeAdstock ()
534
+
535
+ model = FakeMMM_DateDim ()
536
+ start_date = "2022-02-01"
537
+ end_date = "2022-02-03"
538
+
539
+ channel_with_date = xr .Dataset (
540
+ data_vars = {
541
+ "channel1" : ("date" , np .array ([1.0 , 2.0 ])),
542
+ },
543
+ coords = {"date" : pd .date_range ("2022-01-01" , periods = 2 , freq = "D" )},
544
+ )
545
+
546
+ with pytest .raises (
547
+ ValueError , match = r"`channel_xr` must NOT include the date dimension\."
548
+ ):
549
+ create_zero_dataset (model , start_date , end_date , channel_with_date )
550
+
510
551
def test_create_zero_dataset_no_dims (self ):
511
552
"""Test create_zero_dataset with a model that has no dimensions."""
512
553
@@ -549,3 +590,89 @@ def test_create_zero_dataset_empty_date_range_error(self):
549
590
550
591
with pytest .raises (ValueError , match = "Generated date range is empty" ):
551
592
create_zero_dataset (model , start_date , end_date )
593
+
594
+ def test_create_zero_dataset_channel_xr_no_dims_all_channels (self ):
595
+ """Channel-only allocation: channel_xr is a 0-dim Dataset with per-channel scalars."""
596
+
597
+ class FakeMMM_NoDims :
598
+ def __init__ (self ):
599
+ dates = pd .date_range ("2022-01-01" , "2022-01-10" , freq = "D" )
600
+ self .X = pd .DataFrame (
601
+ {
602
+ "date" : dates ,
603
+ "channel1" : np .random .rand (10 ) * 10 ,
604
+ "channel2" : np .random .rand (10 ) * 5 ,
605
+ }
606
+ )
607
+ self .date_column = "date"
608
+ self .channel_columns = ["channel1" , "channel2" ]
609
+ self .control_columns = []
610
+ self .dims = [] # No dimensions
611
+
612
+ class FakeAdstock :
613
+ l_max = 3
614
+
615
+ self .adstock = FakeAdstock ()
616
+
617
+ model = FakeMMM_NoDims ()
618
+ start_date = "2022-02-01"
619
+ end_date = "2022-02-05"
620
+
621
+ # 0-dim Dataset: variables are channels with scalar values
622
+ channel_values = xr .Dataset (
623
+ data_vars = {
624
+ "channel1" : 100.0 ,
625
+ "channel2" : 200.0 ,
626
+ }
627
+ )
628
+
629
+ result = create_zero_dataset (model , start_date , end_date , channel_values )
630
+
631
+ # (5 + 3) days = 8 rows
632
+ assert len (result ) == 8
633
+ assert np .all (result ["channel1" ] == 100.0 )
634
+ assert np .all (result ["channel2" ] == 200.0 )
635
+
636
+ def test_create_zero_dataset_channel_xr_no_dims_missing_channel (self ):
637
+ """Channel-only allocation with missing channel var should warn and leave others at 0."""
638
+
639
+ class FakeMMM_NoDims :
640
+ def __init__ (self ):
641
+ dates = pd .date_range ("2022-01-01" , "2022-01-10" , freq = "D" )
642
+ self .X = pd .DataFrame (
643
+ {
644
+ "date" : dates ,
645
+ "channel1" : np .random .rand (10 ) * 10 ,
646
+ "channel2" : np .random .rand (10 ) * 5 ,
647
+ }
648
+ )
649
+ self .date_column = "date"
650
+ self .channel_columns = ["channel1" , "channel2" ]
651
+ self .control_columns = []
652
+ self .dims = []
653
+
654
+ class FakeAdstock :
655
+ l_max = 2
656
+
657
+ self .adstock = FakeAdstock ()
658
+
659
+ model = FakeMMM_NoDims ()
660
+ start_date = "2022-02-01"
661
+ end_date = "2022-02-03"
662
+
663
+ # Provide only one channel as scalar variable in 0-dim Dataset
664
+ channel_values = xr .Dataset (
665
+ data_vars = {
666
+ "channel1" : 50.0 ,
667
+ }
668
+ )
669
+
670
+ with pytest .warns (
671
+ UserWarning , match = "does not supply values for \\ ['channel2'\\ ]"
672
+ ):
673
+ result = create_zero_dataset (model , start_date , end_date , channel_values )
674
+
675
+ # (3 + 2) days = 5 rows
676
+ assert len (result ) == 5
677
+ assert np .all (result ["channel1" ] == 50.0 )
678
+ assert np .all (result ["channel2" ] == 0.0 )
0 commit comments