Skip to content
130 changes: 78 additions & 52 deletions pymc_extras/statespace/core/statespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -2047,6 +2047,69 @@ def _finalize_scenario_initialization(

return scenario

def _build_forecast_model(
    self, time_index, t0, forecast_index, scenario, filter_output, mvn_method
):
    """
    Build the PyMC model used to draw forecasts starting from a fitted filter state.

    The Kalman filter graph is rebuilt inside a fresh model, the data variables it reads
    are replaced by constants holding their current values, and the filter mean/covariance
    at the position of ``t0`` are used as the initial state of a ``LinearGaussianStateSpace``
    named ``"forecast"``.

    Parameters
    ----------
    time_index
        Index of the fitted data. It is compared element-wise against ``t0`` and stored as
        the ``"data_time"`` coordinate of the returned model.
    t0
        Value to locate in ``time_index``. The filter output at the first matching position
        initialises the forecast.
    forecast_index
        Stored as the ``TIME_DIM`` coordinate of the returned model; its length is the number
        of forecast steps.
    scenario
        Not read by this method; it remains in the signature for callers that pass it.
    filter_output
        Entry of ``FILTER_OUTPUT_TYPES`` selecting which group of filter outputs supplies the
        initial mean and covariance.
    mvn_method
        Forwarded to ``LinearGaussianStateSpace`` as ``method``.

    Returns
    -------
    forecast_model
        Model containing the deterministics ``"x0_slice"`` and ``"P0_slice"`` and the
        ``"forecast"`` state space distribution.

    Raises
    ------
    ValueError
        If ``t0`` does not occur in ``time_index``, or if any name in ``self.data_names``
        (or ``"data"``) is missing from the data variables of the rebuilt model.
    """
    # Locate t0 before building anything, so a bad start point fails fast with a clear message
    # instead of an IndexError from indexing an empty array.
    t0_matches = np.flatnonzero(time_index == t0)
    if t0_matches.size == 0:
        raise ValueError(
            f"t0={t0!r} was not found in time_index; cannot select the filter output "
            "used to initialise the forecast."
        )
    t0_idx = t0_matches[0]

    fit_coords = self._fit_coords

    # Dims are only attached when the fitted model declared the matching coordinates.
    dims = None
    if all(dim in fit_coords for dim in (TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM)):
        dims = [TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM]

    # x0 / P0 are single time slices, so they carry only the state dims (no time dim).
    x0_dims, P0_dims = None, None
    if all(dim in fit_coords for dim in (TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM)):
        x0_dims = [ALL_STATE_DIM]
        P0_dims = [ALL_STATE_DIM, ALL_STATE_AUX_DIM]

    # Work on a copy so the coordinates saved at fit time are left untouched.
    temp_coords = fit_coords.copy()
    temp_coords["data_time"] = time_index
    temp_coords[TIME_DIM] = forecast_index

    with pm.Model(coords=temp_coords) as forecast_model:
        (_, _, *matrices), grouped_outputs = self._kalman_filter_outputs_from_dummy_graph(
            data_dims=["data_time", OBS_STATE_DIM],
        )

        group_idx = FILTER_OUTPUT_TYPES.index(filter_output)
        mu, cov = grouped_outputs[group_idx]

        # Map every data variable of the model to a constant holding its current value.
        # NOTE: every constant is labelled "data", whichever variable it was taken from.
        sub_dict = {
            data_var: pt.as_tensor_variable(data_var.get_value(), name="data")
            for data_var in forecast_model.data_vars
        }

        missing_data_vars = np.setdiff1d(
            ar1=[*self.data_names, "data"], ar2=[data_var.name for data_var in sub_dict]
        )
        if missing_data_vars.size > 0:
            raise ValueError(f"{missing_data_vars} data used for fitting not found!")

        # With the constants substituted, the initial state no longer reads the data variables.
        mu_frozen, cov_frozen = graph_replace([mu, cov], replace=sub_dict, strict=True)

        x0 = pm.Deterministic("x0_slice", mu_frozen[t0_idx], dims=x0_dims)
        P0 = pm.Deterministic("P0_slice", cov_frozen[t0_idx], dims=P0_dims)

        _ = LinearGaussianStateSpace(
            "forecast",
            x0,
            P0,
            *matrices,
            steps=len(forecast_index),
            dims=dims,
            sequence_names=self.kalman_filter.seq_names,
            k_endog=self.k_endog,
            append_x0=False,
            method=mvn_method,
        )

    return forecast_model

def forecast(
self,
idata: InferenceData,
Expand Down Expand Up @@ -2139,8 +2202,6 @@ def forecast(
the latent state trajectories: `y[t] = Z @ x[t] + nu[t]`, where `nu ~ N(0, H)`.

"""
filter_time_dim = TIME_DIM

_validate_filter_arg(filter_output)

compile_kwargs = kwargs.pop("compile_kwargs", {})
Expand Down Expand Up @@ -2185,58 +2246,23 @@ def forecast(
use_scenario_index=use_scenario_index,
)
scenario = self._finalize_scenario_initialization(scenario, forecast_index)
temp_coords = self._fit_coords.copy()

dims = None
if all([dim in temp_coords for dim in [filter_time_dim, ALL_STATE_DIM, OBS_STATE_DIM]]):
dims = [TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM]

t0_idx = np.flatnonzero(time_index == t0)[0]

temp_coords["data_time"] = time_index
temp_coords[TIME_DIM] = forecast_index

mu_dims, cov_dims = None, None
if all([dim in self._fit_coords for dim in [TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM]]):
mu_dims = ["data_time", ALL_STATE_DIM]
cov_dims = ["data_time", ALL_STATE_DIM, ALL_STATE_AUX_DIM]

with pm.Model(coords=temp_coords) as forecast_model:
(_, _, *matrices), grouped_outputs = self._kalman_filter_outputs_from_dummy_graph(
scenario=scenario,
data_dims=["data_time", OBS_STATE_DIM],
)

for name in self.data_names:
if name in scenario.keys():
pm.set_data(
{"data": np.zeros((len(forecast_index), self.k_endog))},
coords={"data_time": np.arange(len(forecast_index))},
)
break

group_idx = FILTER_OUTPUT_TYPES.index(filter_output)
mu, cov = grouped_outputs[group_idx]

x0 = pm.Deterministic(
"x0_slice", mu[t0_idx], dims=mu_dims[1:] if mu_dims is not None else None
)
P0 = pm.Deterministic(
"P0_slice", cov[t0_idx], dims=cov_dims[1:] if cov_dims is not None else None
)
forecast_model = self._build_forecast_model(
time_index=time_index,
t0=t0,
forecast_index=forecast_index,
scenario=scenario,
filter_output=filter_output,
mvn_method=mvn_method,
)

_ = LinearGaussianStateSpace(
"forecast",
x0,
P0,
*matrices,
steps=len(forecast_index),
dims=dims,
sequence_names=self.kalman_filter.seq_names,
k_endog=self.k_endog,
append_x0=False,
method=mvn_method,
)
with forecast_model:
if scenario is not None:
dummy_obs_data = np.zeros((len(forecast_index), self.k_endog))
pm.set_data(
scenario | {"data": dummy_obs_data},
coords={"data_time": np.arange(len(forecast_index))},
)

forecast_model.rvs_to_initial_values = {
k: None for k in forecast_model.rvs_to_initial_values.keys()
Expand Down
101 changes: 96 additions & 5 deletions tests/statespace/core/test_statespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
import pytest

from numpy.testing import assert_allclose
from pymc.testing import mock_sample_setup_and_teardown
from pytensor.compile import SharedVariable
from pytensor.graph.basic import graph_inputs

from pymc_extras.statespace.core.statespace import FILTER_FACTORY, PyMCStateSpace
from pymc_extras.statespace.models import structural as st
Expand All @@ -30,6 +33,7 @@
floatX = pytensor.config.floatX
nile = load_nile_test_data()
ALL_SAMPLE_OUTPUTS = MATRIX_NAMES + FILTER_OUTPUT_NAMES + SMOOTHER_OUTPUT_NAMES
mock_pymc_sample = pytest.fixture(scope="session")(mock_sample_setup_and_teardown)


def make_statespace_mod(k_endog, k_states, k_posdef, filter_type, verbose=False, data_info=None):
Expand Down Expand Up @@ -170,7 +174,7 @@ def exog_pymc_mod(exog_ss_mod, exog_data):
)
beta_exog = pm.Normal("beta_exog", mu=0, sigma=1, dims=["exog_state"])

exog_ss_mod.build_statespace_graph(exog_data["y"])
exog_ss_mod.build_statespace_graph(exog_data["y"], save_kalman_filter_outputs_in_idata=True)

return struct_model

Expand Down Expand Up @@ -212,7 +216,7 @@ def pymc_mod_no_exog_dt(ss_mod_no_exog_dt, rng):


@pytest.fixture(scope="session")
def idata(pymc_mod, rng):
def idata(pymc_mod, rng, mock_pymc_sample):
with pymc_mod:
idata = pm.sample(draws=10, tune=0, chains=1, random_seed=rng)
idata_prior = pm.sample_prior_predictive(draws=10, random_seed=rng)
Expand All @@ -222,7 +226,7 @@ def idata(pymc_mod, rng):


@pytest.fixture(scope="session")
def idata_exog(exog_pymc_mod, rng):
def idata_exog(exog_pymc_mod, rng, mock_pymc_sample):
with exog_pymc_mod:
idata = pm.sample(draws=10, tune=0, chains=1, random_seed=rng)
idata_prior = pm.sample_prior_predictive(draws=10, random_seed=rng)
Expand All @@ -231,7 +235,7 @@ def idata_exog(exog_pymc_mod, rng):


@pytest.fixture(scope="session")
def idata_no_exog(pymc_mod_no_exog, rng):
def idata_no_exog(pymc_mod_no_exog, rng, mock_pymc_sample):
with pymc_mod_no_exog:
idata = pm.sample(draws=10, tune=0, chains=1, random_seed=rng)
idata_prior = pm.sample_prior_predictive(draws=10, random_seed=rng)
Expand All @@ -240,7 +244,7 @@ def idata_no_exog(pymc_mod_no_exog, rng):


@pytest.fixture(scope="session")
def idata_no_exog_dt(pymc_mod_no_exog_dt, rng):
def idata_no_exog_dt(pymc_mod_no_exog_dt, rng, mock_pymc_sample):
with pymc_mod_no_exog_dt:
idata = pm.sample(draws=10, tune=0, chains=1, random_seed=rng)
idata_prior = pm.sample_prior_predictive(draws=10, random_seed=rng)
Expand Down Expand Up @@ -895,6 +899,93 @@ def test_forecast_with_exog_data(rng, exog_ss_mod, idata_exog, start):
assert_allclose(regression_effect, regression_effect_expected)


@pytest.mark.filterwarnings("ignore:Provided data contains missing values")
@pytest.mark.filterwarnings("ignore:The RandomType SharedVariables")
@pytest.mark.filterwarnings("ignore:No time index found on the supplied data.")
@pytest.mark.filterwarnings("ignore:Skipping `CheckAndRaise` Op")
@pytest.mark.filterwarnings("ignore:No frequency was specific on the data's DateTimeIndex.")
def test_build_forecast_model(rng, exog_ss_mod, exog_pymc_mod, exog_data, idata_exog):
    """
    Check the model returned by ``_build_forecast_model``.

    The initial state/covariance slices must not depend on any data SharedVariable, the
    forecast itself must depend only on ``data_exog``, and the slices must reproduce the
    fitted predicted state and covariance at ``t0``.
    """

    def non_rng_shared_inputs(outputs):
        # SharedVariables feeding `outputs`, leaving out those that hold random generators.
        return [
            var
            for var in graph_inputs(outputs)
            if isinstance(var, SharedVariable)
            and not isinstance(var.get_value(), np.random.Generator)
        ]

    # Snapshot of the fitted model's data, taken before the forecast model is built.
    fit_data = {var.name: var.get_value() for var in exog_pymc_mod.data_vars}

    scenario = pd.DataFrame(
        {
            "date": pd.date_range(start="2023-05-11", end="2023-05-20", freq="D"),
            "x1": rng.choice(2, size=10, replace=True).astype(float),
        }
    )
    scenario = scenario.set_index("date")

    time_index = exog_ss_mod._get_fit_time_index()
    t0, forecast_index = exog_ss_mod._build_forecast_index(
        time_index=time_index,
        start=exog_data.index[-1],
        end=scenario.index[-1],
        scenario=scenario,
    )

    model = exog_ss_mod._build_forecast_model(
        time_index=time_index,
        t0=t0,
        forecast_index=forecast_index,
        scenario=scenario,
        filter_output="predicted",
        mvn_method="svd",
    )

    # The frozen initial state must have no non-random-generator SharedVariable inputs.
    frozen_inputs = non_rng_shared_inputs([model.x0_slice, model.P0_slice])
    assert len(frozen_inputs) == 0

    # The forecast keeps exactly one such input (in this case): the exogenous data.
    live_inputs = non_rng_shared_inputs([model.forecast_combined])
    assert len(live_inputs) == 1
    exog_shared = live_inputs[0]
    assert exog_shared.name == "data_exog"

    # Snapshot of the forecast model's data, taken before any scenario data is set.
    forecast_data = {var.name: var.get_value() for var in model.data_vars}

    n_steps = len(forecast_index)
    with model:
        pm.set_data(
            {"data_exog": scenario, "data": np.zeros((n_steps, exog_ss_mod.k_endog))},
            coords={"data_time": np.arange(n_steps)},
        )
        idata_forecast = pm.sample_posterior_predictive(
            idata_exog, var_names=["x0_slice", "P0_slice"]
        )

    # After set_data, the shared exogenous input must hold the scenario values.
    np.testing.assert_allclose(exog_shared.get_value(), scenario["x1"].values.reshape((-1, 1)))

    # The data needed to initialise the forecast must not have changed.
    for name, fit_values in fit_data.items():
        assert fit_values.mean() == forecast_data[name].mean()

    # The frozen covariance and state must match the fitted filter outputs at t0.
    sample_dims = ("chain", "draw")
    for fitted_name, slice_name in (
        ("predicted_covariance", "P0_slice"),
        ("predicted_state", "x0_slice"),
    ):
        np.testing.assert_allclose(
            idata_exog.posterior[fitted_name].sel(time=t0).mean(sample_dims).values,
            idata_forecast.posterior_predictive[slice_name].mean(sample_dims).values,
        )


@pytest.mark.filterwarnings("ignore:Provided data contains missing values")
@pytest.mark.filterwarnings("ignore:The RandomType SharedVariables")
@pytest.mark.filterwarnings("ignore:No time index found on the supplied data.")
Expand Down