@@ -901,6 +901,7 @@ def test_forecast_with_exog_data(rng, exog_ss_mod, idata_exog, start):
901901@pytest .mark .filterwarnings ("ignore:Skipping `CheckAndRaise` Op" )
902902@pytest .mark .filterwarnings ("ignore:No frequency was specific on the data's DateTimeIndex." )
903903def test_build_forecast_model (rng , exog_ss_mod , exog_pymc_mod , exog_data ):
904+ # Want to make sure this remains the same even after updating data using pm.set_data()
904905 data_before_build_forecast_model = {d .name : d .get_value () for d in exog_pymc_mod .data_vars }
905906
906907 scenario1 = pd .DataFrame (
@@ -919,6 +920,8 @@ def test_build_forecast_model(rng, exog_ss_mod, exog_pymc_mod, exog_data):
919920 )
920921 scenario2 .set_index ("date" , inplace = True )
921922
923+ data_after_build_forecast_model = []
924+
922925 for scenario in [scenario1 , scenario2 ]:
923926 time_index = exog_ss_mod ._get_fit_time_index ()
924927 t0 , forecast_index = exog_ss_mod ._build_forecast_index (
@@ -937,13 +940,23 @@ def test_build_forecast_model(rng, exog_ss_mod, exog_pymc_mod, exog_data):
937940 mvn_method = "svd" ,
938941 )
939942
940- data_after_build_forecast_model = {
941- d .name : d .get_value () for d in test_forecast_model .data_vars
942- }
943- for k in data_before_build_forecast_model .keys ():
943+ data_after_build_forecast_model .append (
944+ {d .name : d .get_value () for d in test_forecast_model .data_vars }
945+ )
946+ # Change the data here
947+ with test_forecast_model :
948+ dummy_obs_data = np .zeros ((len (forecast_index ), exog_ss_mod .k_endog ))
949+ pm .set_data (
950+ {"data_exog" : scenario } | {"data" : dummy_obs_data },
951+ coords = {"data_time" : np .arange (len (forecast_index ))},
952+ )
953+
954+ # Ensure first change in data did not affect second change ( not sure this makes sense since the forecast method will rebuild forecast model every time you call it)
955+ for k in data_before_build_forecast_model .keys ():
956+ for data_before_build_forecast_scenario_specific in data_after_build_forecast_model :
944957 assert (
945958 data_before_build_forecast_model [k ].mean ()
946- == data_after_build_forecast_model [k ].mean ()
959+ == data_before_build_forecast_scenario_specific [k ].mean ()
947960 )
948961
949962
0 commit comments