|
9 | 9 | import pytest |
10 | 10 |
|
11 | 11 | from numpy.testing import assert_allclose |
| 12 | +from pytensor.compile import SharedVariable |
| 13 | +from pytensor.graph.basic import graph_inputs |
12 | 14 |
|
13 | 15 | from pymc_extras.statespace.core.statespace import FILTER_FACTORY, PyMCStateSpace |
14 | 16 | from pymc_extras.statespace.models import structural as st |
@@ -170,7 +172,7 @@ def exog_pymc_mod(exog_ss_mod, exog_data): |
170 | 172 | ) |
171 | 173 | beta_exog = pm.Normal("beta_exog", mu=0, sigma=1, dims=["exog_state"]) |
172 | 174 |
|
173 | | - exog_ss_mod.build_statespace_graph(exog_data["y"]) |
| 175 | + exog_ss_mod.build_statespace_graph(exog_data["y"], save_kalman_filter_outputs_in_idata=True) |
174 | 176 |
|
175 | 177 | return struct_model |
176 | 178 |
|
@@ -900,64 +902,86 @@ def test_forecast_with_exog_data(rng, exog_ss_mod, idata_exog, start): |
900 | 902 | @pytest.mark.filterwarnings("ignore:No time index found on the supplied data.") |
901 | 903 | @pytest.mark.filterwarnings("ignore:Skipping `CheckAndRaise` Op") |
902 | 904 | @pytest.mark.filterwarnings("ignore:No frequency was specific on the data's DateTimeIndex.") |
903 | | -def test_build_forecast_model(rng, exog_ss_mod, exog_pymc_mod, exog_data): |
904 | | - # Want to make sure this remains the same even after updating data using pm.set_data() |
| 905 | +def test_build_forecast_model(rng, exog_ss_mod, exog_pymc_mod, exog_data, idata_exog): |
905 | 906 | data_before_build_forecast_model = {d.name: d.get_value() for d in exog_pymc_mod.data_vars} |
906 | 907 |
|
907 | | - scenario1 = pd.DataFrame( |
| 908 | + scenario = pd.DataFrame( |
908 | 909 | { |
909 | 910 | "date": pd.date_range(start="2023-05-11", end="2023-05-20", freq="D"), |
910 | 911 | "x1": rng.choice(2, size=10, replace=True).astype(float), |
911 | 912 | } |
912 | 913 | ) |
913 | | - scenario1.set_index("date", inplace=True) |
| 914 | + scenario.set_index("date", inplace=True) |
914 | 915 |
|
915 | | - scenario2 = pd.DataFrame( |
916 | | - { |
917 | | - "date": pd.date_range(start="2023-05-11", end="2023-05-20", freq="D"), |
918 | | - "x1": np.zeros(shape=(10,)), |
919 | | - } |
| 916 | + time_index = exog_ss_mod._get_fit_time_index() |
| 917 | + t0, forecast_index = exog_ss_mod._build_forecast_index( |
| 918 | + time_index=time_index, |
| 919 | + start=exog_data.index[-1], |
| 920 | + end=scenario.index[-1], |
| 921 | + scenario=scenario, |
920 | 922 | ) |
921 | | - scenario2.set_index("date", inplace=True) |
922 | 923 |
|
923 | | - data_after_build_forecast_model = [] |
| 924 | + test_forecast_model = exog_ss_mod._build_forecast_model( |
| 925 | + time_index=time_index, |
| 926 | + t0=t0, |
| 927 | + forecast_index=forecast_index, |
| 928 | + scenario=scenario, |
| 929 | + filter_output="predicted", |
| 930 | + mvn_method="svd", |
| 931 | + ) |
924 | 932 |
|
925 | | - for scenario in [scenario1, scenario2]: |
926 | | - time_index = exog_ss_mod._get_fit_time_index() |
927 | | - t0, forecast_index = exog_ss_mod._build_forecast_index( |
928 | | - time_index=time_index, |
929 | | - start=exog_data.index[-1], |
930 | | - end=scenario.index[-1], |
931 | | - scenario=scenario, |
| 933 | + frozen_shared_inputs = [ |
| 934 | + inpt |
| 935 | + for inpt in graph_inputs([test_forecast_model.x0_slice, test_forecast_model.P0_slice]) |
| 936 | + if isinstance(inpt, SharedVariable) |
| 937 | + and not isinstance(inpt.get_value(), np.random.Generator) |
| 938 | + ] |
| 939 | + |
| 940 | + assert ( |
| 941 | + len(frozen_shared_inputs) == 0 |
| 942 | + ) # check there are no non-random generator SharedVariables in the frozen inputs |
| 943 | + |
| 944 | + unfrozen_shared_inputs = [ |
| 945 | + inpt |
| 946 | + for inpt in graph_inputs([test_forecast_model.forecast_combined]) |
| 947 | + if isinstance(inpt, SharedVariable) |
| 948 | + and not isinstance(inpt.get_value(), np.random.Generator) |
| 949 | + ] |
| 950 | + |
| 951 | + # Check that there is one (in this case) unfrozen shared input and it corresponds to the exogenous data |
| 952 | + assert len(unfrozen_shared_inputs) == 1 |
| 953 | + assert unfrozen_shared_inputs[0].name == "data_exog" |
| 954 | + |
| 955 | + data_after_build_forecast_model = {d.name: d.get_value() for d in test_forecast_model.data_vars} |
| 956 | + |
| 957 | + with test_forecast_model: |
| 958 | + dummy_obs_data = np.zeros((len(forecast_index), exog_ss_mod.k_endog)) |
| 959 | + pm.set_data( |
| 960 | + {"data_exog": scenario} | {"data": dummy_obs_data}, |
| 961 | + coords={"data_time": np.arange(len(forecast_index))}, |
932 | 962 | ) |
933 | | - |
934 | | - test_forecast_model = exog_ss_mod._build_forecast_model( |
935 | | - time_index=time_index, |
936 | | - t0=t0, |
937 | | - forecast_index=forecast_index, |
938 | | - scenario=scenario, |
939 | | - filter_output="smoothed", |
940 | | - mvn_method="svd", |
| 963 | + idata_forecast = pm.sample_posterior_predictive( |
| 964 | + idata_exog, var_names=["x0_slice", "P0_slice"] |
941 | 965 | ) |
942 | 966 |
|
943 | | - data_after_build_forecast_model.append( |
944 | | - {d.name: d.get_value() for d in test_forecast_model.data_vars} |
945 | | - ) |
946 | | - # Change the data here |
947 | | - with test_forecast_model: |
948 | | - dummy_obs_data = np.zeros((len(forecast_index), exog_ss_mod.k_endog)) |
949 | | - pm.set_data( |
950 | | - {"data_exog": scenario} | {"data": dummy_obs_data}, |
951 | | - coords={"data_time": np.arange(len(forecast_index))}, |
952 | | - ) |
| 967 | + np.testing.assert_allclose( |
| 968 | + unfrozen_shared_inputs[0].get_value(), scenario["x1"].values.reshape((-1, 1)) |
| 969 | + ) # ensure the replaced data matches the exogenous data |
953 | 970 |
|
954 | | - # Ensure first change in data did not affect second change ( not sure this makes sense since the forecast method will rebuild forecast model every time you call it) |
955 | 971 | for k in data_before_build_forecast_model.keys(): |
956 | | - for data_before_build_forecast_scenario_specific in data_after_build_forecast_model: |
957 | | - assert ( |
958 | | - data_before_build_forecast_model[k].mean() |
959 | | - == data_before_build_forecast_scenario_specific[k].mean() |
960 | | - ) |
| 972 | + assert ( # check that the data needed to init the forecasts doesn't change |
| 973 | + data_before_build_forecast_model[k].mean() == data_after_build_forecast_model[k].mean() |
| 974 | + ) |
| 975 | + |
| 976 | + # Check that the frozen states and covariances correctly match the sliced index |
| 977 | + np.testing.assert_allclose( |
| 978 | + idata_exog.posterior["predicted_covariance"].sel(time=t0).mean(("chain", "draw")).values, |
| 979 | + idata_forecast.posterior_predictive["P0_slice"].mean(("chain", "draw")).values, |
| 980 | + ) |
| 981 | + np.testing.assert_allclose( |
| 982 | + idata_exog.posterior["predicted_state"].sel(time=t0).mean(("chain", "draw")).values, |
| 983 | + idata_forecast.posterior_predictive["x0_slice"].mean(("chain", "draw")).values, |
| 984 | + ) |
961 | 985 |
|
962 | 986 |
|
963 | 987 | @pytest.mark.filterwarnings("ignore:Provided data contains missing values") |
|
0 commit comments