Skip to content

Commit 8635274

Browse files
committed
made change to test_build_forecast_model() to ensure data is replaced with pm.set_data method
1 parent d41a109 commit 8635274

File tree

1 file changed

+18
-5
lines changed

1 file changed

+18
-5
lines changed

tests/statespace/core/test_statespace.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -901,6 +901,7 @@ def test_forecast_with_exog_data(rng, exog_ss_mod, idata_exog, start):
901901
@pytest.mark.filterwarnings("ignore:Skipping `CheckAndRaise` Op")
902902
@pytest.mark.filterwarnings("ignore:No frequency was specific on the data's DateTimeIndex.")
903903
def test_build_forecast_model(rng, exog_ss_mod, exog_pymc_mod, exog_data):
904+
# Want to make sure this remains the same even after updating data using pm.set_data()
904905
data_before_build_forecast_model = {d.name: d.get_value() for d in exog_pymc_mod.data_vars}
905906

906907
scenario1 = pd.DataFrame(
@@ -919,6 +920,8 @@ def test_build_forecast_model(rng, exog_ss_mod, exog_pymc_mod, exog_data):
919920
)
920921
scenario2.set_index("date", inplace=True)
921922

923+
data_after_build_forecast_model = []
924+
922925
for scenario in [scenario1, scenario2]:
923926
time_index = exog_ss_mod._get_fit_time_index()
924927
t0, forecast_index = exog_ss_mod._build_forecast_index(
@@ -937,13 +940,23 @@ def test_build_forecast_model(rng, exog_ss_mod, exog_pymc_mod, exog_data):
937940
mvn_method="svd",
938941
)
939942

940-
data_after_build_forecast_model = {
941-
d.name: d.get_value() for d in test_forecast_model.data_vars
942-
}
943-
for k in data_before_build_forecast_model.keys():
943+
data_after_build_forecast_model.append(
944+
{d.name: d.get_value() for d in test_forecast_model.data_vars}
945+
)
946+
# Change the data here
947+
with test_forecast_model:
948+
dummy_obs_data = np.zeros((len(forecast_index), exog_ss_mod.k_endog))
949+
pm.set_data(
950+
{"data_exog": scenario} | {"data": dummy_obs_data},
951+
coords={"data_time": np.arange(len(forecast_index))},
952+
)
953+
954+
# Ensure first change in data did not affect second change ( not sure this makes sense since the forecast method will rebuild forecast model every time you call it)
955+
for k in data_before_build_forecast_model.keys():
956+
for data_before_build_forecast_scenario_specific in data_after_build_forecast_model:
944957
assert (
945958
data_before_build_forecast_model[k].mean()
946-
== data_after_build_forecast_model[k].mean()
959+
== data_before_build_forecast_scenario_specific[k].mean()
947960
)
948961

949962

0 commit comments

Comments
 (0)