@@ -366,7 +366,7 @@ def test_forecast_index(use_datetime_index):
366366
367367 # From start and periods
368368 start = time_idx [- 1 ]
369- periods = 11
369+ periods = 10
370370
371371 x0_index , forecast_idx = ss_mod ._build_forecast_index (time_idx , start = start , periods = periods )
372372 assert start not in forecast_idx
@@ -380,7 +380,7 @@ def test_forecast_index(use_datetime_index):
380380
381381 assert x0_index == time_idx [start ]
382382 assert forecast_idx .shape == (10 ,)
383- assert (forecast_idx == time_idx [start + 1 : start + periods ]).all ()
383+ assert (forecast_idx == time_idx [start + 1 : start + periods + 1 ]).all ()
384384
385385 # From scenario index
386386 scenario = pd .DataFrame (0 , index = forecast_idx , columns = [0 , 1 , 2 ])
@@ -500,7 +500,7 @@ def test_finalize_scenario_single(data_type, use_datetime_index):
500500 scenario = data_type (np .zeros ((10 ,)))
501501
502502 scenario = ss_mod ._validate_scenario_data (scenario )
503- t0 , forecast_idx = ss_mod ._build_forecast_index (time_idx , start = time_idx [- 1 ], periods = 11 )
503+ t0 , forecast_idx = ss_mod ._build_forecast_index (time_idx , start = time_idx [- 1 ], periods = 10 )
504504 scenario = ss_mod ._finalize_scenario_initialization (scenario , forecast_index = forecast_idx )
505505
506506 assert isinstance (scenario , pd .DataFrame )
@@ -514,11 +514,8 @@ def test_finalize_scenario_single(data_type, use_datetime_index):
514514 ids = ["series" , "dataframe" , "array" , "list" , "tuple" ],
515515)
516516@pytest .mark .parametrize ("use_datetime_index" , [True , False ])
517- def test_finalize_secenario_dict (data_type , use_datetime_index ):
518- if data_type is pd .DataFrame :
519- # Ensure dataframes have the correct column name
520- data_type = partial (pd .DataFrame , columns = ["column_1" ])
521-
517+ @pytest .mark .parametrize ("use_scenario_index" , [True , False ])
518+ def test_finalize_secenario_dict (data_type , use_datetime_index , use_scenario_index ):
522519 data_info = {
523520 "a" : {"shape" : (None , 1 ), "dims" : ("time" , "features_a" )},
524521 "b" : {"shape" : (None , 2 ), "dims" : ("time" , "features_b" )},
@@ -534,13 +531,38 @@ def test_finalize_secenario_dict(data_type, use_datetime_index):
534531 ss_mod ._fit_coords = dict (features_a = ["column_1" ], features_b = ["column_1" , "column_2" ])
535532 time_idx = _make_time_idx (ss_mod , use_datetime_index )
536533
534+ initial_index = (
535+ pd .date_range (start = time_idx [- 1 ], periods = 10 , freq = time_idx .freq )
536+ if use_datetime_index
537+ else pd .RangeIndex (time_idx [- 1 ], time_idx [- 1 ] + 10 , 1 )
538+ )
539+
540+ if data_type is pd .DataFrame :
541+ # Ensure dataframes have the correct column name
542+ data_type = partial (pd .DataFrame , columns = ["column_1" ], index = initial_index )
543+ elif data_type is pd .Series :
544+ data_type = partial (pd .Series , index = initial_index )
545+
537546 scenario = {
538547 "a" : data_type (np .zeros ((10 ,))),
539- "b" : pd .DataFrame (np .zeros ((10 , 2 )), columns = ss_mod ._fit_coords ["features_b" ]),
548+ "b" : pd .DataFrame (
549+ np .zeros ((10 , 2 )), columns = ss_mod ._fit_coords ["features_b" ], index = initial_index
550+ ),
540551 }
541552
542553 scenario = ss_mod ._validate_scenario_data (scenario )
543- forecast_idx = ss_mod ._build_forecast_index (time_idx , start = time_idx [- 1 ], periods = 10 )
554+
555+ if use_scenario_index and data_type not in [np .array , list , tuple ]:
556+ t0 , forecast_idx = ss_mod ._build_forecast_index (
557+ time_idx , scenario = scenario , periods = 10 , use_scenario_index = True
558+ )
559+ elif use_scenario_index and data_type in [np .array , list , tuple ]:
560+ t0 , forecast_idx = ss_mod ._build_forecast_index (
561+ time_idx , scenario = scenario , start = - 1 , periods = 10 , use_scenario_index = True
562+ )
563+ else :
564+ t0 , forecast_idx = ss_mod ._build_forecast_index (time_idx , start = time_idx [- 1 ], periods = 10 )
565+
544566 scenario = ss_mod ._finalize_scenario_initialization (scenario , forecast_index = forecast_idx )
545567
546568 assert list (scenario .keys ()) == ["a" , "b" ]
@@ -678,18 +700,7 @@ def test_forecast_with_exog_data(rng, exog_ss_mod, idata_exog):
678700 .assign_coords (state = ["exog[a]" , "exog[b]" , "exog[c]" ])
679701 )
680702
681- print (scenario .index )
682- print (level .coords )
683-
684703 regression_effect = forecast_idata .forecast_observed .isel (observed_state = 0 ) - level
685704 regression_effect_expected = (betas * scenario_xr ).sum (dim = ["state" ])
686705
687706 assert_allclose (regression_effect , regression_effect_expected )
688-
689- # t_5 = forecast_idata.forecast_observed.isel(time=7, observed_state=0).to_numpy()
690- # not_t_5 = (
691- # forecast_idata.forecast_observed.isel(time=np.arange(10) != 7, observed_state=0)
692- # .mean(dim="time")
693- # .to_numpy()
694- # )
695- # assert t_5.shape == not_t_5.shape
0 commit comments