@@ -366,7 +366,7 @@ def test_forecast_index(use_datetime_index):
366
366
367
367
# From start and periods
368
368
start = time_idx [- 1 ]
369
- periods = 11
369
+ periods = 10
370
370
371
371
x0_index , forecast_idx = ss_mod ._build_forecast_index (time_idx , start = start , periods = periods )
372
372
assert start not in forecast_idx
@@ -380,7 +380,7 @@ def test_forecast_index(use_datetime_index):
380
380
381
381
assert x0_index == time_idx [start ]
382
382
assert forecast_idx .shape == (10 ,)
383
- assert (forecast_idx == time_idx [start + 1 : start + periods ]).all ()
383
+ assert (forecast_idx == time_idx [start + 1 : start + periods + 1 ]).all ()
384
384
385
385
# From scenario index
386
386
scenario = pd .DataFrame (0 , index = forecast_idx , columns = [0 , 1 , 2 ])
@@ -500,7 +500,7 @@ def test_finalize_scenario_single(data_type, use_datetime_index):
500
500
scenario = data_type (np .zeros ((10 ,)))
501
501
502
502
scenario = ss_mod ._validate_scenario_data (scenario )
503
- t0 , forecast_idx = ss_mod ._build_forecast_index (time_idx , start = time_idx [- 1 ], periods = 11 )
503
+ t0 , forecast_idx = ss_mod ._build_forecast_index (time_idx , start = time_idx [- 1 ], periods = 10 )
504
504
scenario = ss_mod ._finalize_scenario_initialization (scenario , forecast_index = forecast_idx )
505
505
506
506
assert isinstance (scenario , pd .DataFrame )
@@ -514,11 +514,8 @@ def test_finalize_scenario_single(data_type, use_datetime_index):
514
514
ids = ["series" , "dataframe" , "array" , "list" , "tuple" ],
515
515
)
516
516
@pytest .mark .parametrize ("use_datetime_index" , [True , False ])
517
- def test_finalize_secenario_dict (data_type , use_datetime_index ):
518
- if data_type is pd .DataFrame :
519
- # Ensure dataframes have the correct column name
520
- data_type = partial (pd .DataFrame , columns = ["column_1" ])
521
-
517
+ @pytest .mark .parametrize ("use_scenario_index" , [True , False ])
518
+ def test_finalize_secenario_dict (data_type , use_datetime_index , use_scenario_index ):
522
519
data_info = {
523
520
"a" : {"shape" : (None , 1 ), "dims" : ("time" , "features_a" )},
524
521
"b" : {"shape" : (None , 2 ), "dims" : ("time" , "features_b" )},
@@ -534,13 +531,38 @@ def test_finalize_secenario_dict(data_type, use_datetime_index):
534
531
ss_mod ._fit_coords = dict (features_a = ["column_1" ], features_b = ["column_1" , "column_2" ])
535
532
time_idx = _make_time_idx (ss_mod , use_datetime_index )
536
533
534
+ initial_index = (
535
+ pd .date_range (start = time_idx [- 1 ], periods = 10 , freq = time_idx .freq )
536
+ if use_datetime_index
537
+ else pd .RangeIndex (time_idx [- 1 ], time_idx [- 1 ] + 10 , 1 )
538
+ )
539
+
540
+ if data_type is pd .DataFrame :
541
+ # Ensure dataframes have the correct column name
542
+ data_type = partial (pd .DataFrame , columns = ["column_1" ], index = initial_index )
543
+ elif data_type is pd .Series :
544
+ data_type = partial (pd .Series , index = initial_index )
545
+
537
546
scenario = {
538
547
"a" : data_type (np .zeros ((10 ,))),
539
- "b" : pd .DataFrame (np .zeros ((10 , 2 )), columns = ss_mod ._fit_coords ["features_b" ]),
548
+ "b" : pd .DataFrame (
549
+ np .zeros ((10 , 2 )), columns = ss_mod ._fit_coords ["features_b" ], index = initial_index
550
+ ),
540
551
}
541
552
542
553
scenario = ss_mod ._validate_scenario_data (scenario )
543
- forecast_idx = ss_mod ._build_forecast_index (time_idx , start = time_idx [- 1 ], periods = 10 )
554
+
555
+ if use_scenario_index and data_type not in [np .array , list , tuple ]:
556
+ t0 , forecast_idx = ss_mod ._build_forecast_index (
557
+ time_idx , scenario = scenario , periods = 10 , use_scenario_index = True
558
+ )
559
+ elif use_scenario_index and data_type in [np .array , list , tuple ]:
560
+ t0 , forecast_idx = ss_mod ._build_forecast_index (
561
+ time_idx , scenario = scenario , start = - 1 , periods = 10 , use_scenario_index = True
562
+ )
563
+ else :
564
+ t0 , forecast_idx = ss_mod ._build_forecast_index (time_idx , start = time_idx [- 1 ], periods = 10 )
565
+
544
566
scenario = ss_mod ._finalize_scenario_initialization (scenario , forecast_index = forecast_idx )
545
567
546
568
assert list (scenario .keys ()) == ["a" , "b" ]
@@ -678,18 +700,7 @@ def test_forecast_with_exog_data(rng, exog_ss_mod, idata_exog):
678
700
.assign_coords (state = ["exog[a]" , "exog[b]" , "exog[c]" ])
679
701
)
680
702
681
- print (scenario .index )
682
- print (level .coords )
683
-
684
703
regression_effect = forecast_idata .forecast_observed .isel (observed_state = 0 ) - level
685
704
regression_effect_expected = (betas * scenario_xr ).sum (dim = ["state" ])
686
705
687
706
assert_allclose (regression_effect , regression_effect_expected )
688
-
689
- # t_5 = forecast_idata.forecast_observed.isel(time=7, observed_state=0).to_numpy()
690
- # not_t_5 = (
691
- # forecast_idata.forecast_observed.isel(time=np.arange(10) != 7, observed_state=0)
692
- # .mean(dim="time")
693
- # .to_numpy()
694
- # )
695
- # assert t_5.shape == not_t_5.shape
0 commit comments