Skip to content

Commit d41a109

Browse files
committed
Made a slight change to `_build_forecast_model` and added a test case
1 parent d196409 commit d41a109

File tree

2 files changed

+63
-10
lines changed

2 files changed

+63
-10
lines changed

pymc_extras/statespace/core/statespace.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2080,11 +2080,11 @@ def _build_forecast_model(
20802080
for data_var in forecast_model.data_vars
20812081
}
20822082

2083-
replacements_diff = np.setdiff1d(
2083+
missing_data_vars = np.setdiff1d(
20842084
ar1=[*self.data_names, "data"], ar2=[k.name for k, _ in sub_dict.items()]
20852085
)
2086-
if replacements_diff.size > 0:
2087-
raise ValueError(f"{replacements_diff} data used for fitting not found!")
2086+
if missing_data_vars.size > 0:
2087+
raise ValueError(f"{missing_data_vars} data used for fitting not found!")
20882088

20892089
mu_frozen, cov_frozen = graph_replace([mu, cov], replace=sub_dict, strict=True)
20902090

@@ -2095,13 +2095,6 @@ def _build_forecast_model(
20952095
"P0_slice", cov_frozen[t0_idx], dims=cov_dims[1:] if cov_dims is not None else None
20962096
)
20972097

2098-
if scenario is not None:
2099-
dummy_obs_data = np.zeros((len(forecast_index), self.k_endog))
2100-
pm.set_data(
2101-
scenario | {"data": dummy_obs_data},
2102-
coords={"data_time": np.arange(len(forecast_index))},
2103-
)
2104-
21052098
_ = LinearGaussianStateSpace(
21062099
"forecast",
21072100
x0,
@@ -2263,6 +2256,14 @@ def forecast(
22632256
mvn_method=mvn_method,
22642257
)
22652258

2259+
with forecast_model:
2260+
if scenario is not None:
2261+
dummy_obs_data = np.zeros((len(forecast_index), self.k_endog))
2262+
pm.set_data(
2263+
scenario | {"data": dummy_obs_data},
2264+
coords={"data_time": np.arange(len(forecast_index))},
2265+
)
2266+
22662267
forecast_model.rvs_to_initial_values = {
22672268
k: None for k in forecast_model.rvs_to_initial_values.keys()
22682269
}

tests/statespace/core/test_statespace.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -895,6 +895,58 @@ def test_forecast_with_exog_data(rng, exog_ss_mod, idata_exog, start):
895895
assert_allclose(regression_effect, regression_effect_expected)
896896

897897

898+
@pytest.mark.filterwarnings("ignore:Provided data contains missing values")
@pytest.mark.filterwarnings("ignore:The RandomType SharedVariables")
@pytest.mark.filterwarnings("ignore:No time index found on the supplied data.")
@pytest.mark.filterwarnings("ignore:Skipping `CheckAndRaise` Op")
@pytest.mark.filterwarnings("ignore:No frequency was specific on the data's DateTimeIndex.")
def test_build_forecast_model(rng, exog_ss_mod, exog_pymc_mod, exog_data):
    """Check that the forecast model's data variables match the fitted model's.

    The values of every data variable on ``exog_pymc_mod`` are snapshotted,
    then ``_build_forecast_model`` is called once per scenario. For each
    snapshotted name, the data variable of the returned model must hold exactly
    the same array, i.e. building the forecast model must not write the
    scenario values into the data.
    """
    data_before_build_forecast_model = {d.name: d.get_value() for d in exog_pymc_mod.data_vars}

    # Two 10-day scenarios for the exogenous regressor "x1": one random 0/1
    # series and one all-zero series.
    scenario1 = pd.DataFrame(
        {
            "date": pd.date_range(start="2023-05-11", end="2023-05-20", freq="D"),
            "x1": rng.choice(2, size=10, replace=True).astype(float),
        }
    ).set_index("date")

    scenario2 = pd.DataFrame(
        {
            "date": pd.date_range(start="2023-05-11", end="2023-05-20", freq="D"),
            "x1": np.zeros(shape=(10,)),
        }
    ).set_index("date")

    for scenario in [scenario1, scenario2]:
        time_index = exog_ss_mod._get_fit_time_index()
        t0, forecast_index = exog_ss_mod._build_forecast_index(
            time_index=time_index,
            start=exog_data.index[-1],
            end=scenario.index[-1],
            scenario=scenario,
        )

        test_forecast_model = exog_ss_mod._build_forecast_model(
            time_index=time_index,
            t0=t0,
            forecast_index=forecast_index,
            scenario=scenario,
            filter_output="smoothed",
            mvn_method="svd",
        )

        data_after_build_forecast_model = {
            d.name: d.get_value() for d in test_forecast_model.data_vars
        }

        # Compare whole arrays rather than their means: equal means do not imply
        # equal data, and a NaN mean would never compare equal to itself.
        for name, expected in data_before_build_forecast_model.items():
            np.testing.assert_array_equal(
                data_after_build_forecast_model[name],
                expected,
                err_msg=f"data variable {name!r} differs after _build_forecast_model",
            )
949+
898950
@pytest.mark.filterwarnings("ignore:Provided data contains missing values")
899951
@pytest.mark.filterwarnings("ignore:The RandomType SharedVariables")
900952
@pytest.mark.filterwarnings("ignore:No time index found on the supplied data.")

0 commit comments

Comments
 (0)