diff --git a/pymc_extras/statespace/core/statespace.py b/pymc_extras/statespace/core/statespace.py index 9342ff90d..96e1e9b52 100644 --- a/pymc_extras/statespace/core/statespace.py +++ b/pymc_extras/statespace/core/statespace.py @@ -2047,6 +2047,69 @@ def _finalize_scenario_initialization( return scenario + def _build_forecast_model( + self, time_index, t0, forecast_index, scenario, filter_output, mvn_method + ): + filter_time_dim = TIME_DIM + temp_coords = self._fit_coords.copy() + + dims = None + if all([dim in temp_coords for dim in [filter_time_dim, ALL_STATE_DIM, OBS_STATE_DIM]]): + dims = [TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM] + + t0_idx = np.flatnonzero(time_index == t0)[0] + + temp_coords["data_time"] = time_index + temp_coords[TIME_DIM] = forecast_index + + mu_dims, cov_dims = None, None + if all([dim in self._fit_coords for dim in [TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM]]): + mu_dims = ["data_time", ALL_STATE_DIM] + cov_dims = ["data_time", ALL_STATE_DIM, ALL_STATE_AUX_DIM] + + with pm.Model(coords=temp_coords) as forecast_model: + (_, _, *matrices), grouped_outputs = self._kalman_filter_outputs_from_dummy_graph( + data_dims=["data_time", OBS_STATE_DIM], + ) + + group_idx = FILTER_OUTPUT_TYPES.index(filter_output) + mu, cov = grouped_outputs[group_idx] + + sub_dict = { + data_var: pt.as_tensor_variable(data_var.get_value(), name="data") + for data_var in forecast_model.data_vars + } + + missing_data_vars = np.setdiff1d( + ar1=[*self.data_names, "data"], ar2=[k.name for k, _ in sub_dict.items()] + ) + if missing_data_vars.size > 0: + raise ValueError(f"{missing_data_vars} data used for fitting not found!") + + mu_frozen, cov_frozen = graph_replace([mu, cov], replace=sub_dict, strict=True) + + x0 = pm.Deterministic( + "x0_slice", mu_frozen[t0_idx], dims=mu_dims[1:] if mu_dims is not None else None + ) + P0 = pm.Deterministic( + "P0_slice", cov_frozen[t0_idx], dims=cov_dims[1:] if cov_dims is not None else None + ) + + _ = LinearGaussianStateSpace( + "forecast", + x0, + P0, + *matrices, + steps=len(forecast_index), + dims=dims, + sequence_names=self.kalman_filter.seq_names, + k_endog=self.k_endog, + append_x0=False, + method=mvn_method, + ) + + return forecast_model + def forecast( self, idata: InferenceData, @@ -2139,8 +2202,6 @@ def forecast( the latent state trajectories: `y[t] = Z @ x[t] + nu[t]`, where `nu ~ N(0, H)`. """ - filter_time_dim = TIME_DIM - _validate_filter_arg(filter_output) compile_kwargs = kwargs.pop("compile_kwargs", {}) @@ -2185,58 +2246,23 @@ def forecast( use_scenario_index=use_scenario_index, ) scenario = self._finalize_scenario_initialization(scenario, forecast_index) - temp_coords = self._fit_coords.copy() - - dims = None - if all([dim in temp_coords for dim in [filter_time_dim, ALL_STATE_DIM, OBS_STATE_DIM]]): - dims = [TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM] - - t0_idx = np.flatnonzero(time_index == t0)[0] - - temp_coords["data_time"] = time_index - temp_coords[TIME_DIM] = forecast_index - - mu_dims, cov_dims = None, None - if all([dim in self._fit_coords for dim in [TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM]]): - mu_dims = ["data_time", ALL_STATE_DIM] - cov_dims = ["data_time", ALL_STATE_DIM, ALL_STATE_AUX_DIM] - - with pm.Model(coords=temp_coords) as forecast_model: - (_, _, *matrices), grouped_outputs = self._kalman_filter_outputs_from_dummy_graph( - scenario=scenario, - data_dims=["data_time", OBS_STATE_DIM], - ) - - for name in self.data_names: - if name in scenario.keys(): - pm.set_data( - {"data": np.zeros((len(forecast_index), self.k_endog))}, - coords={"data_time": np.arange(len(forecast_index))}, - ) - break - group_idx = FILTER_OUTPUT_TYPES.index(filter_output) - mu, cov = grouped_outputs[group_idx] - - x0 = pm.Deterministic( - "x0_slice", mu[t0_idx], dims=mu_dims[1:] if mu_dims is not None else None - ) - P0 = pm.Deterministic( - "P0_slice", cov[t0_idx], dims=cov_dims[1:] if cov_dims is not None else None - ) + forecast_model = self._build_forecast_model( + time_index=time_index, + t0=t0, + forecast_index=forecast_index, + scenario=scenario, + filter_output=filter_output, + mvn_method=mvn_method, + ) - _ = LinearGaussianStateSpace( - "forecast", - x0, - P0, - *matrices, - steps=len(forecast_index), - dims=dims, - sequence_names=self.kalman_filter.seq_names, - k_endog=self.k_endog, - append_x0=False, - method=mvn_method, - ) + with forecast_model: + if scenario is not None: + dummy_obs_data = np.zeros((len(forecast_index), self.k_endog)) + pm.set_data( + scenario | {"data": dummy_obs_data}, + coords={"data_time": np.arange(len(forecast_index))}, + ) forecast_model.rvs_to_initial_values = { k: None for k in forecast_model.rvs_to_initial_values.keys() diff --git a/tests/statespace/core/test_statespace.py b/tests/statespace/core/test_statespace.py index 6a77c1514..bfcd114ae 100644 --- a/tests/statespace/core/test_statespace.py +++ b/tests/statespace/core/test_statespace.py @@ -9,6 +9,9 @@ import pytest from numpy.testing import assert_allclose +from pymc.testing import mock_sample_setup_and_teardown +from pytensor.compile import SharedVariable +from pytensor.graph.basic import graph_inputs from pymc_extras.statespace.core.statespace import FILTER_FACTORY, PyMCStateSpace from pymc_extras.statespace.models import structural as st @@ -30,6 +33,7 @@ floatX = pytensor.config.floatX nile = load_nile_test_data() ALL_SAMPLE_OUTPUTS = MATRIX_NAMES + FILTER_OUTPUT_NAMES + SMOOTHER_OUTPUT_NAMES +mock_pymc_sample = pytest.fixture(scope="session")(mock_sample_setup_and_teardown) def make_statespace_mod(k_endog, k_states, k_posdef, filter_type, verbose=False, data_info=None): @@ -170,7 +174,7 @@ def exog_pymc_mod(exog_ss_mod, exog_data): ) beta_exog = pm.Normal("beta_exog", mu=0, sigma=1, dims=["exog_state"]) - exog_ss_mod.build_statespace_graph(exog_data["y"]) + exog_ss_mod.build_statespace_graph(exog_data["y"], save_kalman_filter_outputs_in_idata=True) return struct_model @@ -212,7 +216,7 @@ def pymc_mod_no_exog_dt(ss_mod_no_exog_dt, rng): @pytest.fixture(scope="session") -def idata(pymc_mod, rng): +def idata(pymc_mod, rng, mock_pymc_sample): with pymc_mod: idata = pm.sample(draws=10, tune=0, chains=1, random_seed=rng) idata_prior = pm.sample_prior_predictive(draws=10, random_seed=rng) @@ -222,7 +226,7 @@ def idata(pymc_mod, rng): @pytest.fixture(scope="session") -def idata_exog(exog_pymc_mod, rng): +def idata_exog(exog_pymc_mod, rng, mock_pymc_sample): with exog_pymc_mod: idata = pm.sample(draws=10, tune=0, chains=1, random_seed=rng) idata_prior = pm.sample_prior_predictive(draws=10, random_seed=rng) @@ -231,7 +235,7 @@ def idata_exog(exog_pymc_mod, rng): @pytest.fixture(scope="session") -def idata_no_exog(pymc_mod_no_exog, rng): +def idata_no_exog(pymc_mod_no_exog, rng, mock_pymc_sample): with pymc_mod_no_exog: idata = pm.sample(draws=10, tune=0, chains=1, random_seed=rng) idata_prior = pm.sample_prior_predictive(draws=10, random_seed=rng) @@ -240,7 +244,7 @@ def idata_no_exog(pymc_mod_no_exog, rng): @pytest.fixture(scope="session") -def idata_no_exog_dt(pymc_mod_no_exog_dt, rng): +def idata_no_exog_dt(pymc_mod_no_exog_dt, rng, mock_pymc_sample): with pymc_mod_no_exog_dt: idata = pm.sample(draws=10, tune=0, chains=1, random_seed=rng) idata_prior = pm.sample_prior_predictive(draws=10, random_seed=rng) @@ -895,6 +899,93 @@ def test_forecast_with_exog_data(rng, exog_ss_mod, idata_exog, start): assert_allclose(regression_effect, regression_effect_expected) +@pytest.mark.filterwarnings("ignore:Provided data contains missing values") +@pytest.mark.filterwarnings("ignore:The RandomType SharedVariables") +@pytest.mark.filterwarnings("ignore:No time index found on the supplied data.") +@pytest.mark.filterwarnings("ignore:Skipping `CheckAndRaise` Op") +@pytest.mark.filterwarnings("ignore:No frequency was specific on the data's DateTimeIndex.") +def test_build_forecast_model(rng, exog_ss_mod, exog_pymc_mod, exog_data, idata_exog): + data_before_build_forecast_model = {d.name: d.get_value() for d in exog_pymc_mod.data_vars} + + scenario = pd.DataFrame( + { + "date": pd.date_range(start="2023-05-11", end="2023-05-20", freq="D"), + "x1": rng.choice(2, size=10, replace=True).astype(float), + } + ) + scenario.set_index("date", inplace=True) + + time_index = exog_ss_mod._get_fit_time_index() + t0, forecast_index = exog_ss_mod._build_forecast_index( + time_index=time_index, + start=exog_data.index[-1], + end=scenario.index[-1], + scenario=scenario, + ) + + test_forecast_model = exog_ss_mod._build_forecast_model( + time_index=time_index, + t0=t0, + forecast_index=forecast_index, + scenario=scenario, + filter_output="predicted", + mvn_method="svd", + ) + + frozen_shared_inputs = [ + inpt + for inpt in graph_inputs([test_forecast_model.x0_slice, test_forecast_model.P0_slice]) + if isinstance(inpt, SharedVariable) + and not isinstance(inpt.get_value(), np.random.Generator) + ] + + assert ( + len(frozen_shared_inputs) == 0 + ) # check there are no non-random generator SharedVariables in the frozen inputs + + unfrozen_shared_inputs = [ + inpt + for inpt in graph_inputs([test_forecast_model.forecast_combined]) + if isinstance(inpt, SharedVariable) + and not isinstance(inpt.get_value(), np.random.Generator) + ] + + # Check that there is one (in this case) unfrozen shared input and it corresponds to the exogenous data + assert len(unfrozen_shared_inputs) == 1 + assert unfrozen_shared_inputs[0].name == "data_exog" + + data_after_build_forecast_model = {d.name: d.get_value() for d in test_forecast_model.data_vars} + + with test_forecast_model: + dummy_obs_data = np.zeros((len(forecast_index), exog_ss_mod.k_endog)) + pm.set_data( + {"data_exog": scenario} | {"data": dummy_obs_data}, + coords={"data_time": np.arange(len(forecast_index))}, + ) + idata_forecast = pm.sample_posterior_predictive( + idata_exog, var_names=["x0_slice", "P0_slice"] + ) + + np.testing.assert_allclose( + unfrozen_shared_inputs[0].get_value(), scenario["x1"].values.reshape((-1, 1)) + ) # ensure the replaced data matches the exogenous data + + for k in data_before_build_forecast_model.keys(): + assert ( # check that the data needed to init the forecasts doesn't change + data_before_build_forecast_model[k].mean() == data_after_build_forecast_model[k].mean() + ) + + # Check that the frozen states and covariances correctly match the sliced index + np.testing.assert_allclose( + idata_exog.posterior["predicted_covariance"].sel(time=t0).mean(("chain", "draw")).values, + idata_forecast.posterior_predictive["P0_slice"].mean(("chain", "draw")).values, + ) + np.testing.assert_allclose( + idata_exog.posterior["predicted_state"].sel(time=t0).mean(("chain", "draw")).values, + idata_forecast.posterior_predictive["x0_slice"].mean(("chain", "draw")).values, + ) + + @pytest.mark.filterwarnings("ignore:Provided data contains missing values") @pytest.mark.filterwarnings("ignore:The RandomType SharedVariables") @pytest.mark.filterwarnings("ignore:No time index found on the supplied data.")