From 439e980c16f172c4d83c665ef02adbfa49fbf4c7 Mon Sep 17 00:00:00 2001 From: Jonathan Dekermanjian Date: Mon, 9 Jun 2025 04:33:44 -0600 Subject: [PATCH 1/8] fixed bug in statespace forecast method when exogenous variables are present. --- pymc_extras/statespace/core/statespace.py | 45 ++++++++++++++++++----- 1 file changed, 36 insertions(+), 9 deletions(-) diff --git a/pymc_extras/statespace/core/statespace.py b/pymc_extras/statespace/core/statespace.py index 9342ff90d..eb2c65dd7 100644 --- a/pymc_extras/statespace/core/statespace.py +++ b/pymc_extras/statespace/core/statespace.py @@ -2203,21 +2203,34 @@ def forecast( with pm.Model(coords=temp_coords) as forecast_model: (_, _, *matrices), grouped_outputs = self._kalman_filter_outputs_from_dummy_graph( - scenario=scenario, data_dims=["data_time", OBS_STATE_DIM], ) - for name in self.data_names: - if name in scenario.keys(): - pm.set_data( - {"data": np.zeros((len(forecast_index), self.k_endog))}, - coords={"data_time": np.arange(len(forecast_index))}, - ) - break - group_idx = FILTER_OUTPUT_TYPES.index(filter_output) mu, cov = grouped_outputs[group_idx] + if scenario is not None: + sub_dict = { + forecast_model[data_name]: pt.as_tensor_variable( + x=self._exog_data_info[data_name]["value"], name=data_name + ) + for data_name in self.data_names + } + + # Will this always be named "data"? + sub_dict[forecast_model["data"]] = pt.as_tensor_variable( + self._fit_data, name="data" + ) + else: + # same here will it always be named data? + sub_dict = { + forecast_model["data"]: pt.as_tensor_variable( + self._fit_data.astype(np.float64), name="data" + ) + } + + mu, cov = graph_replace([mu, cov], replace=sub_dict, strict=True) + x0 = pm.Deterministic( "x0_slice", mu[t0_idx], dims=mu_dims[1:] if mu_dims is not None else None ) @@ -2225,6 +2238,20 @@ def forecast( "P0_slice", cov[t0_idx], dims=cov_dims[1:] if cov_dims is not None else None ) + if scenario is not None: + for name in self.data_names: + if name in scenario.keys(): + pm.set_data({name: scenario[name]}) + + for name in self.data_names: + if name in scenario.keys(): + # same here will it always be named data? + pm.set_data( + {"data": np.zeros((len(forecast_index), self.k_endog))}, + coords={"data_time": np.arange(len(forecast_index))}, + ) + break + _ = LinearGaussianStateSpace( "forecast", x0, From c2cf547aac4dfb3ee531468c1d973b0f13b3ade5 Mon Sep 17 00:00:00 2001 From: Jonathan Dekermanjian Date: Mon, 9 Jun 2025 04:58:19 -0600 Subject: [PATCH 2/8] updated solution to handle input shapes correctly --- pymc_extras/statespace/core/statespace.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pymc_extras/statespace/core/statespace.py b/pymc_extras/statespace/core/statespace.py index eb2c65dd7..9532345e3 100644 --- a/pymc_extras/statespace/core/statespace.py +++ b/pymc_extras/statespace/core/statespace.py @@ -2212,20 +2212,21 @@ def forecast( if scenario is not None: sub_dict = { forecast_model[data_name]: pt.as_tensor_variable( - x=self._exog_data_info[data_name]["value"], name=data_name + x=np.atleast_2d(self._exog_data_info[data_name]["value"].T).T, + name=data_name, ) for data_name in self.data_names } # Will this always be named "data"? sub_dict[forecast_model["data"]] = pt.as_tensor_variable( - self._fit_data, name="data" + np.atleast_2d(self._fit_data.T).T, name="data" ) else: # same here will it always be named data? sub_dict = { forecast_model["data"]: pt.as_tensor_variable( - self._fit_data.astype(np.float64), name="data" + np.atleast_2d(self._fit_data.T).T, name="data" ) } From 027de4195bafa6e8d36aa309a12aeb9cb579cdff Mon Sep 17 00:00:00 2001 From: Jonathan Dekermanjian Date: Mon, 9 Jun 2025 08:38:41 -0600 Subject: [PATCH 3/8] simplified fix, renamed mu and cov for transparancy and added a check for the graph replacements --- pymc_extras/statespace/core/statespace.py | 51 ++++++++--------------- 1 file changed, 17 insertions(+), 34 deletions(-) diff --git a/pymc_extras/statespace/core/statespace.py b/pymc_extras/statespace/core/statespace.py index 9532345e3..53a4c839d 100644 --- a/pymc_extras/statespace/core/statespace.py +++ b/pymc_extras/statespace/core/statespace.py @@ -2209,49 +2209,32 @@ def forecast( group_idx = FILTER_OUTPUT_TYPES.index(filter_output) mu, cov = grouped_outputs[group_idx] - if scenario is not None: - sub_dict = { - forecast_model[data_name]: pt.as_tensor_variable( - x=np.atleast_2d(self._exog_data_info[data_name]["value"].T).T, - name=data_name, - ) - for data_name in self.data_names - } + sub_dict = { + data_var: pt.as_tensor_variable(data_var.get_value(), name="data") + for data_var in forecast_model.data_vars + } - # Will this always be named "data"? - sub_dict[forecast_model["data"]] = pt.as_tensor_variable( - np.atleast_2d(self._fit_data.T).T, name="data" - ) - else: - # same here will it always be named data? - sub_dict = { - forecast_model["data"]: pt.as_tensor_variable( - np.atleast_2d(self._fit_data.T).T, name="data" - ) - } + replacements_diff = np.setdiff1d( + ar1=[*self.data_names, "data"], ar2=[k.name for k, _ in sub_dict.items()] + ) + if replacements_diff.size > 0: + raise ValueError(f"{replacements_diff} data used for fitting not found!") - mu, cov = graph_replace([mu, cov], replace=sub_dict, strict=True) + mu_frozen, cov_frozen = graph_replace([mu, cov], replace=sub_dict, strict=True) x0 = pm.Deterministic( - "x0_slice", mu[t0_idx], dims=mu_dims[1:] if mu_dims is not None else None + "x0_slice", mu_frozen[t0_idx], dims=mu_dims[1:] if mu_dims is not None else None ) P0 = pm.Deterministic( - "P0_slice", cov[t0_idx], dims=cov_dims[1:] if cov_dims is not None else None + "P0_slice", cov_frozen[t0_idx], dims=cov_dims[1:] if cov_dims is not None else None ) if scenario is not None: - for name in self.data_names: - if name in scenario.keys(): - pm.set_data({name: scenario[name]}) - - for name in self.data_names: - if name in scenario.keys(): - # same here will it always be named data? - pm.set_data( - {"data": np.zeros((len(forecast_index), self.k_endog))}, - coords={"data_time": np.arange(len(forecast_index))}, - ) - break + dummy_obs_data = np.zeros((len(forecast_index), self.k_endog)) + pm.set_data( + scenario | {"data": dummy_obs_data}, + coords={"data_time": np.arange(len(forecast_index))}, + ) _ = LinearGaussianStateSpace( "forecast", From d1964096e340d900b69594dc9156fb584fd8ca21 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Tue, 10 Jun 2025 01:09:39 +0800 Subject: [PATCH 4/8] Refactor model builder logic out of `forecast` method --- pymc_extras/statespace/core/statespace.py | 142 ++++++++++++---------- 1 file changed, 78 insertions(+), 64 deletions(-) diff --git a/pymc_extras/statespace/core/statespace.py b/pymc_extras/statespace/core/statespace.py index 53a4c839d..f0924c885 100644 --- a/pymc_extras/statespace/core/statespace.py +++ b/pymc_extras/statespace/core/statespace.py @@ -2047,6 +2047,76 @@ def _finalize_scenario_initialization( return scenario + def _build_forecast_model( + self, time_index, t0, forecast_index, scenario, filter_output, mvn_method + ): + filter_time_dim = TIME_DIM + temp_coords = self._fit_coords.copy() + + dims = None + if all([dim in temp_coords for dim in [filter_time_dim, ALL_STATE_DIM, OBS_STATE_DIM]]): + dims = [TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM] + + t0_idx = np.flatnonzero(time_index == t0)[0] + + temp_coords["data_time"] = time_index + temp_coords[TIME_DIM] = forecast_index + + mu_dims, cov_dims = None, None + if all([dim in self._fit_coords for dim in [TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM]]): + mu_dims = ["data_time", ALL_STATE_DIM] + cov_dims = ["data_time", ALL_STATE_DIM, ALL_STATE_AUX_DIM] + + with pm.Model(coords=temp_coords) as forecast_model: + (_, _, *matrices), grouped_outputs = self._kalman_filter_outputs_from_dummy_graph( + data_dims=["data_time", OBS_STATE_DIM], + ) + + group_idx = FILTER_OUTPUT_TYPES.index(filter_output) + mu, cov = grouped_outputs[group_idx] + + sub_dict = { + data_var: pt.as_tensor_variable(data_var.get_value(), name="data") + for data_var in forecast_model.data_vars + } + + replacements_diff = np.setdiff1d( + ar1=[*self.data_names, "data"], ar2=[k.name for k, _ in sub_dict.items()] + ) + if replacements_diff.size > 0: + raise ValueError(f"{replacements_diff} data used for fitting not found!") + + mu_frozen, cov_frozen = graph_replace([mu, cov], replace=sub_dict, strict=True) + + x0 = pm.Deterministic( + "x0_slice", mu_frozen[t0_idx], dims=mu_dims[1:] if mu_dims is not None else None + ) + P0 = pm.Deterministic( + "P0_slice", cov_frozen[t0_idx], dims=cov_dims[1:] if cov_dims is not None else None + ) + + if scenario is not None: + dummy_obs_data = np.zeros((len(forecast_index), self.k_endog)) + pm.set_data( + scenario | {"data": dummy_obs_data}, + coords={"data_time": np.arange(len(forecast_index))}, + ) + + _ = LinearGaussianStateSpace( + "forecast", + x0, + P0, + *matrices, + steps=len(forecast_index), + dims=dims, + sequence_names=self.kalman_filter.seq_names, + k_endog=self.k_endog, + append_x0=False, + method=mvn_method, + ) + + return forecast_model + def forecast( self, idata: InferenceData, @@ -2139,8 +2209,6 @@ def forecast( the latent state trajectories: `y[t] = Z @ x[t] + nu[t]`, where `nu ~ N(0, H)`. """ - filter_time_dim = TIME_DIM - _validate_filter_arg(filter_output) compile_kwargs = kwargs.pop("compile_kwargs", {}) @@ -2185,69 +2253,15 @@ def forecast( use_scenario_index=use_scenario_index, ) scenario = self._finalize_scenario_initialization(scenario, forecast_index) - temp_coords = self._fit_coords.copy() - - dims = None - if all([dim in temp_coords for dim in [filter_time_dim, ALL_STATE_DIM, OBS_STATE_DIM]]): - dims = [TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM] - - t0_idx = np.flatnonzero(time_index == t0)[0] - - temp_coords["data_time"] = time_index - temp_coords[TIME_DIM] = forecast_index - - mu_dims, cov_dims = None, None - if all([dim in self._fit_coords for dim in [TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM]]): - mu_dims = ["data_time", ALL_STATE_DIM] - cov_dims = ["data_time", ALL_STATE_DIM, ALL_STATE_AUX_DIM] - - with pm.Model(coords=temp_coords) as forecast_model: - (_, _, *matrices), grouped_outputs = self._kalman_filter_outputs_from_dummy_graph( - data_dims=["data_time", OBS_STATE_DIM], - ) - group_idx = FILTER_OUTPUT_TYPES.index(filter_output) - mu, cov = grouped_outputs[group_idx] - - sub_dict = { - data_var: pt.as_tensor_variable(data_var.get_value(), name="data") - for data_var in forecast_model.data_vars - } - - replacements_diff = np.setdiff1d( - ar1=[*self.data_names, "data"], ar2=[k.name for k, _ in sub_dict.items()] - ) - if replacements_diff.size > 0: - raise ValueError(f"{replacements_diff} data used for fitting not found!") - - mu_frozen, cov_frozen = graph_replace([mu, cov], replace=sub_dict, strict=True) - - x0 = pm.Deterministic( - "x0_slice", mu_frozen[t0_idx], dims=mu_dims[1:] if mu_dims is not None else None - ) - P0 = pm.Deterministic( - "P0_slice", cov_frozen[t0_idx], dims=cov_dims[1:] if cov_dims is not None else None - ) - - if scenario is not None: - dummy_obs_data = np.zeros((len(forecast_index), self.k_endog)) - pm.set_data( - scenario | {"data": dummy_obs_data}, - coords={"data_time": np.arange(len(forecast_index))}, - ) - - _ = LinearGaussianStateSpace( - "forecast", - x0, - P0, - *matrices, - steps=len(forecast_index), - dims=dims, - sequence_names=self.kalman_filter.seq_names, - k_endog=self.k_endog, - append_x0=False, - method=mvn_method, - ) + forecast_model = self._build_forecast_model( + time_index=time_index, + t0=t0, + forecast_index=forecast_index, + scenario=scenario, + filter_output=filter_output, + mvn_method=mvn_method, + ) forecast_model.rvs_to_initial_values = { k: None for k in forecast_model.rvs_to_initial_values.keys() From d41a10994642786f46a411c47baf857b9c86a812 Mon Sep 17 00:00:00 2001 From: Jonathan Dekermanjian Date: Tue, 10 Jun 2025 05:00:13 -0600 Subject: [PATCH 5/8] made slight change with _build_forecast_model and created a test case --- pymc_extras/statespace/core/statespace.py | 21 ++++----- tests/statespace/core/test_statespace.py | 52 +++++++++++++++++++++++ 2 files changed, 63 insertions(+), 10 deletions(-) diff --git a/pymc_extras/statespace/core/statespace.py b/pymc_extras/statespace/core/statespace.py index f0924c885..96e1e9b52 100644 --- a/pymc_extras/statespace/core/statespace.py +++ b/pymc_extras/statespace/core/statespace.py @@ -2080,11 +2080,11 @@ def _build_forecast_model( for data_var in forecast_model.data_vars } - replacements_diff = np.setdiff1d( + missing_data_vars = np.setdiff1d( ar1=[*self.data_names, "data"], ar2=[k.name for k, _ in sub_dict.items()] ) - if replacements_diff.size > 0: - raise ValueError(f"{replacements_diff} data used for fitting not found!") + if missing_data_vars.size > 0: + raise ValueError(f"{missing_data_vars} data used for fitting not found!") mu_frozen, cov_frozen = graph_replace([mu, cov], replace=sub_dict, strict=True) @@ -2095,13 +2095,6 @@ def _build_forecast_model( "P0_slice", cov_frozen[t0_idx], dims=cov_dims[1:] if cov_dims is not None else None ) - if scenario is not None: - dummy_obs_data = np.zeros((len(forecast_index), self.k_endog)) - pm.set_data( - scenario | {"data": dummy_obs_data}, - coords={"data_time": np.arange(len(forecast_index))}, - ) - _ = LinearGaussianStateSpace( "forecast", x0, @@ -2263,6 +2256,14 @@ def forecast( mvn_method=mvn_method, ) + with forecast_model: + if scenario is not None: + dummy_obs_data = np.zeros((len(forecast_index), self.k_endog)) + pm.set_data( + scenario | {"data": dummy_obs_data}, + coords={"data_time": np.arange(len(forecast_index))}, + ) + forecast_model.rvs_to_initial_values = { k: None for k in forecast_model.rvs_to_initial_values.keys() } diff --git a/tests/statespace/core/test_statespace.py b/tests/statespace/core/test_statespace.py index 6a77c1514..1c256841f 100644 --- a/tests/statespace/core/test_statespace.py +++ b/tests/statespace/core/test_statespace.py @@ -895,6 +895,58 @@ def test_forecast_with_exog_data(rng, exog_ss_mod, idata_exog, start): assert_allclose(regression_effect, regression_effect_expected) +@pytest.mark.filterwarnings("ignore:Provided data contains missing values") +@pytest.mark.filterwarnings("ignore:The RandomType SharedVariables") +@pytest.mark.filterwarnings("ignore:No time index found on the supplied data.") +@pytest.mark.filterwarnings("ignore:Skipping `CheckAndRaise` Op") +@pytest.mark.filterwarnings("ignore:No frequency was specific on the data's DateTimeIndex.") +def test_build_forecast_model(rng, exog_ss_mod, exog_pymc_mod, exog_data): + data_before_build_forecast_model = {d.name: d.get_value() for d in exog_pymc_mod.data_vars} + + scenario1 = pd.DataFrame( + { + "date": pd.date_range(start="2023-05-11", end="2023-05-20", freq="D"), + "x1": rng.choice(2, size=10, replace=True).astype(float), + } + ) + scenario1.set_index("date", inplace=True) + + scenario2 = pd.DataFrame( + { + "date": pd.date_range(start="2023-05-11", end="2023-05-20", freq="D"), + "x1": np.zeros(shape=(10,)), + } + ) + scenario2.set_index("date", inplace=True) + + for scenario in [scenario1, scenario2]: + time_index = exog_ss_mod._get_fit_time_index() + t0, forecast_index = exog_ss_mod._build_forecast_index( + time_index=time_index, + start=exog_data.index[-1], + end=scenario.index[-1], + scenario=scenario, + ) + + test_forecast_model = exog_ss_mod._build_forecast_model( + time_index=time_index, + t0=t0, + forecast_index=forecast_index, + scenario=scenario, + filter_output="smoothed", + mvn_method="svd", + ) + + data_after_build_forecast_model = { + d.name: d.get_value() for d in test_forecast_model.data_vars + } + for k in data_before_build_forecast_model.keys(): + assert ( + data_before_build_forecast_model[k].mean() + == data_after_build_forecast_model[k].mean() + ) + + @pytest.mark.filterwarnings("ignore:Provided data contains missing values") @pytest.mark.filterwarnings("ignore:The RandomType SharedVariables") @pytest.mark.filterwarnings("ignore:No time index found on the supplied data.") From 86352747d55e73d062138bc554d5792f9f9e7df5 Mon Sep 17 00:00:00 2001 From: Jonathan Dekermanjian Date: Tue, 10 Jun 2025 19:35:44 -0600 Subject: [PATCH 6/8] made change to test_build_forecast_model() to ensure data is replaced with pm.set_data method --- tests/statespace/core/test_statespace.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/tests/statespace/core/test_statespace.py b/tests/statespace/core/test_statespace.py index 1c256841f..a35430c17 100644 --- a/tests/statespace/core/test_statespace.py +++ b/tests/statespace/core/test_statespace.py @@ -901,6 +901,7 @@ def test_forecast_with_exog_data(rng, exog_ss_mod, idata_exog, start): @pytest.mark.filterwarnings("ignore:Skipping `CheckAndRaise` Op") @pytest.mark.filterwarnings("ignore:No frequency was specific on the data's DateTimeIndex.") def test_build_forecast_model(rng, exog_ss_mod, exog_pymc_mod, exog_data): + # Want to make sure this remains the same even after updating data using pm.set_data() data_before_build_forecast_model = {d.name: d.get_value() for d in exog_pymc_mod.data_vars} scenario1 = pd.DataFrame( @@ -919,6 +920,8 @@ def test_build_forecast_model(rng, exog_ss_mod, exog_pymc_mod, exog_data): ) scenario2.set_index("date", inplace=True) + data_after_build_forecast_model = [] + for scenario in [scenario1, scenario2]: time_index = exog_ss_mod._get_fit_time_index() t0, forecast_index = exog_ss_mod._build_forecast_index( @@ -937,13 +940,23 @@ def test_build_forecast_model(rng, exog_ss_mod, exog_pymc_mod, exog_data): mvn_method="svd", ) - data_after_build_forecast_model = { - d.name: d.get_value() for d in test_forecast_model.data_vars - } - for k in data_before_build_forecast_model.keys(): + data_after_build_forecast_model.append( + {d.name: d.get_value() for d in test_forecast_model.data_vars} + ) + # Change the data here + with test_forecast_model: + dummy_obs_data = np.zeros((len(forecast_index), exog_ss_mod.k_endog)) + pm.set_data( + {"data_exog": scenario} | {"data": dummy_obs_data}, + coords={"data_time": np.arange(len(forecast_index))}, + ) + + # Ensure first change in data did not affect second change ( not sure this makes sense since the forecast method will rebuild forecast model every time you call it) + for k in data_before_build_forecast_model.keys(): + for data_before_build_forecast_scenario_specific in data_after_build_forecast_model: assert ( data_before_build_forecast_model[k].mean() - == data_after_build_forecast_model[k].mean() + == data_before_build_forecast_scenario_specific[k].mean() ) From 4119a0db3e2beec7792be9bcd1554dd17e30ca16 Mon Sep 17 00:00:00 2001 From: Jonathan Dekermanjian Date: Wed, 11 Jun 2025 16:27:54 -0600 Subject: [PATCH 7/8] added additional checks to test_build_forecast_model --- tests/statespace/core/test_statespace.py | 110 ++++++++++++++--------- 1 file changed, 67 insertions(+), 43 deletions(-) diff --git a/tests/statespace/core/test_statespace.py b/tests/statespace/core/test_statespace.py index a35430c17..4bca1bb63 100644 --- a/tests/statespace/core/test_statespace.py +++ b/tests/statespace/core/test_statespace.py @@ -9,6 +9,8 @@ import pytest from numpy.testing import assert_allclose +from pytensor.compile import SharedVariable +from pytensor.graph.basic import graph_inputs from pymc_extras.statespace.core.statespace import FILTER_FACTORY, PyMCStateSpace from pymc_extras.statespace.models import structural as st @@ -170,7 +172,7 @@ def exog_pymc_mod(exog_ss_mod, exog_data): ) beta_exog = pm.Normal("beta_exog", mu=0, sigma=1, dims=["exog_state"]) - exog_ss_mod.build_statespace_graph(exog_data["y"]) + exog_ss_mod.build_statespace_graph(exog_data["y"], save_kalman_filter_outputs_in_idata=True) return struct_model @@ -900,64 +902,86 @@ def test_forecast_with_exog_data(rng, exog_ss_mod, idata_exog, start): @pytest.mark.filterwarnings("ignore:No time index found on the supplied data.") @pytest.mark.filterwarnings("ignore:Skipping `CheckAndRaise` Op") @pytest.mark.filterwarnings("ignore:No frequency was specific on the data's DateTimeIndex.") -def test_build_forecast_model(rng, exog_ss_mod, exog_pymc_mod, exog_data): - # Want to make sure this remains the same even after updating data using pm.set_data() +def test_build_forecast_model(rng, exog_ss_mod, exog_pymc_mod, exog_data, idata_exog): data_before_build_forecast_model = {d.name: d.get_value() for d in exog_pymc_mod.data_vars} - scenario1 = pd.DataFrame( + scenario = pd.DataFrame( { "date": pd.date_range(start="2023-05-11", end="2023-05-20", freq="D"), "x1": rng.choice(2, size=10, replace=True).astype(float), } ) - scenario1.set_index("date", inplace=True) + scenario.set_index("date", inplace=True) - scenario2 = pd.DataFrame( - { - "date": pd.date_range(start="2023-05-11", end="2023-05-20", freq="D"), - "x1": np.zeros(shape=(10,)), - } + time_index = exog_ss_mod._get_fit_time_index() + t0, forecast_index = exog_ss_mod._build_forecast_index( + time_index=time_index, + start=exog_data.index[-1], + end=scenario.index[-1], + scenario=scenario, ) - scenario2.set_index("date", inplace=True) - data_after_build_forecast_model = [] + test_forecast_model = exog_ss_mod._build_forecast_model( + time_index=time_index, + t0=t0, + forecast_index=forecast_index, + scenario=scenario, + filter_output="predicted", + mvn_method="svd", + ) - for scenario in [scenario1, scenario2]: - time_index = exog_ss_mod._get_fit_time_index() - t0, forecast_index = exog_ss_mod._build_forecast_index( - time_index=time_index, - start=exog_data.index[-1], - end=scenario.index[-1], - scenario=scenario, + frozen_shared_inputs = [ + inpt + for inpt in graph_inputs([test_forecast_model.x0_slice, test_forecast_model.P0_slice]) + if isinstance(inpt, SharedVariable) + and not isinstance(inpt.get_value(), np.random.Generator) + ] + + assert ( + len(frozen_shared_inputs) == 0 + ) # check there are no non-random generator SharedVariables in the frozen inputs + + unfrozen_shared_inputs = [ + inpt + for inpt in graph_inputs([test_forecast_model.forecast_combined]) + if isinstance(inpt, SharedVariable) + and not isinstance(inpt.get_value(), np.random.Generator) + ] + + # Check that there is one (in this case) unfrozen shared input and it corresponds to the exogenous data + assert len(unfrozen_shared_inputs) == 1 + assert unfrozen_shared_inputs[0].name == "data_exog" + + data_after_build_forecast_model = {d.name: d.get_value() for d in test_forecast_model.data_vars} + + with test_forecast_model: + dummy_obs_data = np.zeros((len(forecast_index), exog_ss_mod.k_endog)) + pm.set_data( + {"data_exog": scenario} | {"data": dummy_obs_data}, + coords={"data_time": np.arange(len(forecast_index))}, ) - - test_forecast_model = exog_ss_mod._build_forecast_model( - time_index=time_index, - t0=t0, - forecast_index=forecast_index, - scenario=scenario, - filter_output="smoothed", - mvn_method="svd", + idata_forecast = pm.sample_posterior_predictive( + idata_exog, var_names=["x0_slice", "P0_slice"] ) - data_after_build_forecast_model.append( - {d.name: d.get_value() for d in test_forecast_model.data_vars} - ) - # Change the data here - with test_forecast_model: - dummy_obs_data = np.zeros((len(forecast_index), exog_ss_mod.k_endog)) - pm.set_data( - {"data_exog": scenario} | {"data": dummy_obs_data}, - coords={"data_time": np.arange(len(forecast_index))}, - ) + np.testing.assert_allclose( + unfrozen_shared_inputs[0].get_value(), scenario["x1"].values.reshape((-1, 1)) + ) # ensure the replaced data matches the exogenous data - # Ensure first change in data did not affect second change ( not sure this makes sense since the forecast method will rebuild forecast model every time you call it) for k in data_before_build_forecast_model.keys(): - for data_before_build_forecast_scenario_specific in data_after_build_forecast_model: - assert ( - data_before_build_forecast_model[k].mean() - == data_before_build_forecast_scenario_specific[k].mean() - ) + assert ( # check that the data needed to init the forecasts doesn't change + data_before_build_forecast_model[k].mean() == data_after_build_forecast_model[k].mean() + ) + + # Check that the frozen states and covariances correctly match the sliced index + np.testing.assert_allclose( + idata_exog.posterior["predicted_covariance"].sel(time=t0).mean(("chain", "draw")).values, + idata_forecast.posterior_predictive["P0_slice"].mean(("chain", "draw")).values, + ) + np.testing.assert_allclose( + idata_exog.posterior["predicted_state"].sel(time=t0).mean(("chain", "draw")).values, + idata_forecast.posterior_predictive["x0_slice"].mean(("chain", "draw")).values, + ) @pytest.mark.filterwarnings("ignore:Provided data contains missing values") From 252f70ab466823594db5422a099a3f47c32e23f4 Mon Sep 17 00:00:00 2001 From: Jonathan Dekermanjian Date: Thu, 12 Jun 2025 04:45:03 -0600 Subject: [PATCH 8/8] added mock_sample_setup_and_teardown to statespace tests --- tests/statespace/core/test_statespace.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/statespace/core/test_statespace.py b/tests/statespace/core/test_statespace.py index 4bca1bb63..bfcd114ae 100644 --- a/tests/statespace/core/test_statespace.py +++ b/tests/statespace/core/test_statespace.py @@ -9,6 +9,7 @@ import pytest from numpy.testing import assert_allclose +from pymc.testing import mock_sample_setup_and_teardown from pytensor.compile import SharedVariable from pytensor.graph.basic import graph_inputs @@ -32,6 +33,7 @@ floatX = pytensor.config.floatX nile = load_nile_test_data() ALL_SAMPLE_OUTPUTS = MATRIX_NAMES + FILTER_OUTPUT_NAMES + SMOOTHER_OUTPUT_NAMES +mock_pymc_sample = pytest.fixture(scope="session")(mock_sample_setup_and_teardown) def make_statespace_mod(k_endog, k_states, k_posdef, filter_type, verbose=False, data_info=None): @@ -214,7 +216,7 @@ def pymc_mod_no_exog_dt(ss_mod_no_exog_dt, rng): @pytest.fixture(scope="session") -def idata(pymc_mod, rng): +def idata(pymc_mod, rng, mock_pymc_sample): with pymc_mod: idata = pm.sample(draws=10, tune=0, chains=1, random_seed=rng) idata_prior = pm.sample_prior_predictive(draws=10, random_seed=rng) @@ -224,7 +226,7 @@ def idata(pymc_mod, rng): @pytest.fixture(scope="session") -def idata_exog(exog_pymc_mod, rng): +def idata_exog(exog_pymc_mod, rng, mock_pymc_sample): with exog_pymc_mod: idata = pm.sample(draws=10, tune=0, chains=1, random_seed=rng) idata_prior = pm.sample_prior_predictive(draws=10, random_seed=rng) @@ -233,7 +235,7 @@ def idata_exog(exog_pymc_mod, rng): @pytest.fixture(scope="session") -def idata_no_exog(pymc_mod_no_exog, rng): +def idata_no_exog(pymc_mod_no_exog, rng, mock_pymc_sample): with pymc_mod_no_exog: idata = pm.sample(draws=10, tune=0, chains=1, random_seed=rng) idata_prior = pm.sample_prior_predictive(draws=10, random_seed=rng) @@ -242,7 +244,7 @@ def idata_no_exog(pymc_mod_no_exog, rng): @pytest.fixture(scope="session") -def idata_no_exog_dt(pymc_mod_no_exog_dt, rng): +def idata_no_exog_dt(pymc_mod_no_exog_dt, rng, mock_pymc_sample): with pymc_mod_no_exog_dt: idata = pm.sample(draws=10, tune=0, chains=1, random_seed=rng) idata_prior = pm.sample_prior_predictive(draws=10, random_seed=rng)