Skip to content

Commit 4119a0d

Browse files
committed
added additional checks to test_build_forecast_model
1 parent 8635274 commit 4119a0d

File tree

1 file changed

+67 additions, -43 deletions

tests/statespace/core/test_statespace.py

Lines changed: 67 additions & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,8 @@
99
import pytest
1010

1111
from numpy.testing import assert_allclose
12+
from pytensor.compile import SharedVariable
13+
from pytensor.graph.basic import graph_inputs
1214

1315
from pymc_extras.statespace.core.statespace import FILTER_FACTORY, PyMCStateSpace
1416
from pymc_extras.statespace.models import structural as st
@@ -170,7 +172,7 @@ def exog_pymc_mod(exog_ss_mod, exog_data):
170172
)
171173
beta_exog = pm.Normal("beta_exog", mu=0, sigma=1, dims=["exog_state"])
172174

173-
exog_ss_mod.build_statespace_graph(exog_data["y"])
175+
exog_ss_mod.build_statespace_graph(exog_data["y"], save_kalman_filter_outputs_in_idata=True)
174176

175177
return struct_model
176178

@@ -900,64 +902,86 @@ def test_forecast_with_exog_data(rng, exog_ss_mod, idata_exog, start):
900902
@pytest.mark.filterwarnings("ignore:No time index found on the supplied data.")
901903
@pytest.mark.filterwarnings("ignore:Skipping `CheckAndRaise` Op")
902904
@pytest.mark.filterwarnings("ignore:No frequency was specific on the data's DateTimeIndex.")
903-
def test_build_forecast_model(rng, exog_ss_mod, exog_pymc_mod, exog_data):
904-
# Want to make sure this remains the same even after updating data using pm.set_data()
905+
def test_build_forecast_model(rng, exog_ss_mod, exog_pymc_mod, exog_data, idata_exog):
905906
data_before_build_forecast_model = {d.name: d.get_value() for d in exog_pymc_mod.data_vars}
906907

907-
scenario1 = pd.DataFrame(
908+
scenario = pd.DataFrame(
908909
{
909910
"date": pd.date_range(start="2023-05-11", end="2023-05-20", freq="D"),
910911
"x1": rng.choice(2, size=10, replace=True).astype(float),
911912
}
912913
)
913-
scenario1.set_index("date", inplace=True)
914+
scenario.set_index("date", inplace=True)
914915

915-
scenario2 = pd.DataFrame(
916-
{
917-
"date": pd.date_range(start="2023-05-11", end="2023-05-20", freq="D"),
918-
"x1": np.zeros(shape=(10,)),
919-
}
916+
time_index = exog_ss_mod._get_fit_time_index()
917+
t0, forecast_index = exog_ss_mod._build_forecast_index(
918+
time_index=time_index,
919+
start=exog_data.index[-1],
920+
end=scenario.index[-1],
921+
scenario=scenario,
920922
)
921-
scenario2.set_index("date", inplace=True)
922923

923-
data_after_build_forecast_model = []
924+
test_forecast_model = exog_ss_mod._build_forecast_model(
925+
time_index=time_index,
926+
t0=t0,
927+
forecast_index=forecast_index,
928+
scenario=scenario,
929+
filter_output="predicted",
930+
mvn_method="svd",
931+
)
924932

925-
for scenario in [scenario1, scenario2]:
926-
time_index = exog_ss_mod._get_fit_time_index()
927-
t0, forecast_index = exog_ss_mod._build_forecast_index(
928-
time_index=time_index,
929-
start=exog_data.index[-1],
930-
end=scenario.index[-1],
931-
scenario=scenario,
933+
frozen_shared_inputs = [
934+
inpt
935+
for inpt in graph_inputs([test_forecast_model.x0_slice, test_forecast_model.P0_slice])
936+
if isinstance(inpt, SharedVariable)
937+
and not isinstance(inpt.get_value(), np.random.Generator)
938+
]
939+
940+
assert (
941+
len(frozen_shared_inputs) == 0
942+
) # check there are no non-random generator SharedVariables in the frozen inputs
943+
944+
unfrozen_shared_inputs = [
945+
inpt
946+
for inpt in graph_inputs([test_forecast_model.forecast_combined])
947+
if isinstance(inpt, SharedVariable)
948+
and not isinstance(inpt.get_value(), np.random.Generator)
949+
]
950+
951+
# Check that there is one (in this case) unfrozen shared input and it corresponds to the exogenous data
952+
assert len(unfrozen_shared_inputs) == 1
953+
assert unfrozen_shared_inputs[0].name == "data_exog"
954+
955+
data_after_build_forecast_model = {d.name: d.get_value() for d in test_forecast_model.data_vars}
956+
957+
with test_forecast_model:
958+
dummy_obs_data = np.zeros((len(forecast_index), exog_ss_mod.k_endog))
959+
pm.set_data(
960+
{"data_exog": scenario} | {"data": dummy_obs_data},
961+
coords={"data_time": np.arange(len(forecast_index))},
932962
)
933-
934-
test_forecast_model = exog_ss_mod._build_forecast_model(
935-
time_index=time_index,
936-
t0=t0,
937-
forecast_index=forecast_index,
938-
scenario=scenario,
939-
filter_output="smoothed",
940-
mvn_method="svd",
963+
idata_forecast = pm.sample_posterior_predictive(
964+
idata_exog, var_names=["x0_slice", "P0_slice"]
941965
)
942966

943-
data_after_build_forecast_model.append(
944-
{d.name: d.get_value() for d in test_forecast_model.data_vars}
945-
)
946-
# Change the data here
947-
with test_forecast_model:
948-
dummy_obs_data = np.zeros((len(forecast_index), exog_ss_mod.k_endog))
949-
pm.set_data(
950-
{"data_exog": scenario} | {"data": dummy_obs_data},
951-
coords={"data_time": np.arange(len(forecast_index))},
952-
)
967+
np.testing.assert_allclose(
968+
unfrozen_shared_inputs[0].get_value(), scenario["x1"].values.reshape((-1, 1))
969+
) # ensure the replaced data matches the exogenous data
953970

954-
# Ensure first change in data did not affect second change ( not sure this makes sense since the forecast method will rebuild forecast model every time you call it)
955971
for k in data_before_build_forecast_model.keys():
956-
for data_before_build_forecast_scenario_specific in data_after_build_forecast_model:
957-
assert (
958-
data_before_build_forecast_model[k].mean()
959-
== data_before_build_forecast_scenario_specific[k].mean()
960-
)
972+
assert ( # check that the data needed to init the forecasts doesn't change
973+
data_before_build_forecast_model[k].mean() == data_after_build_forecast_model[k].mean()
974+
)
975+
976+
# Check that the frozen states and covariances correctly match the sliced index
977+
np.testing.assert_allclose(
978+
idata_exog.posterior["predicted_covariance"].sel(time=t0).mean(("chain", "draw")).values,
979+
idata_forecast.posterior_predictive["P0_slice"].mean(("chain", "draw")).values,
980+
)
981+
np.testing.assert_allclose(
982+
idata_exog.posterior["predicted_state"].sel(time=t0).mean(("chain", "draw")).values,
983+
idata_forecast.posterior_predictive["x0_slice"].mean(("chain", "draw")).values,
984+
)
961985

962986

963987
@pytest.mark.filterwarnings("ignore:Provided data contains missing values")

0 commit comments

Comments
 (0)