@@ -895,6 +895,58 @@ def test_forecast_with_exog_data(rng, exog_ss_mod, idata_exog, start):
895895 assert_allclose (regression_effect , regression_effect_expected )
896896
897897
898+ @pytest .mark .filterwarnings ("ignore:Provided data contains missing values" )
899+ @pytest .mark .filterwarnings ("ignore:The RandomType SharedVariables" )
900+ @pytest .mark .filterwarnings ("ignore:No time index found on the supplied data." )
901+ @pytest .mark .filterwarnings ("ignore:Skipping `CheckAndRaise` Op" )
902+ @pytest .mark .filterwarnings ("ignore:No frequency was specific on the data's DateTimeIndex." )
903+ def test_build_forecast_model (rng , exog_ss_mod , exog_pymc_mod , exog_data ):
904+ data_before_build_forecast_model = {d .name : d .get_value () for d in exog_pymc_mod .data_vars }
905+
906+ scenario1 = pd .DataFrame (
907+ {
908+ "date" : pd .date_range (start = "2023-05-11" , end = "2023-05-20" , freq = "D" ),
909+ "x1" : rng .choice (2 , size = 10 , replace = True ).astype (float ),
910+ }
911+ )
912+ scenario1 .set_index ("date" , inplace = True )
913+
914+ scenario2 = pd .DataFrame (
915+ {
916+ "date" : pd .date_range (start = "2023-05-11" , end = "2023-05-20" , freq = "D" ),
917+ "x1" : np .zeros (shape = (10 ,)),
918+ }
919+ )
920+ scenario2 .set_index ("date" , inplace = True )
921+
922+ for scenario in [scenario1 , scenario2 ]:
923+ time_index = exog_ss_mod ._get_fit_time_index ()
924+ t0 , forecast_index = exog_ss_mod ._build_forecast_index (
925+ time_index = time_index ,
926+ start = exog_data .index [- 1 ],
927+ end = scenario .index [- 1 ],
928+ scenario = scenario ,
929+ )
930+
931+ test_forecast_model = exog_ss_mod ._build_forecast_model (
932+ time_index = time_index ,
933+ t0 = t0 ,
934+ forecast_index = forecast_index ,
935+ scenario = scenario ,
936+ filter_output = "smoothed" ,
937+ mvn_method = "svd" ,
938+ )
939+
940+ data_after_build_forecast_model = {
941+ d .name : d .get_value () for d in test_forecast_model .data_vars
942+ }
943+ for k in data_before_build_forecast_model .keys ():
944+ assert (
945+ data_before_build_forecast_model [k ].mean ()
946+ == data_after_build_forecast_model [k ].mean ()
947+ )
948+
949+
898950@pytest .mark .filterwarnings ("ignore:Provided data contains missing values" )
899951@pytest .mark .filterwarnings ("ignore:The RandomType SharedVariables" )
900952@pytest .mark .filterwarnings ("ignore:No time index found on the supplied data." )
0 commit comments