Skip to content

Commit 43d778a

Browse files
Dekermanjian authored and jessegrabowski committed
Pr 451 - modified and added tests to statespace (pymc-devs#466)
* Use `set_data` in forecast
* Ignore new numpy matmul warnings in tests
* Tracking down data bug
* resolved merge conflicts environment-test.yml
* added and modified statespace tests
* added logic to update static shape of target when forecasting with exogenous variables
* make sure updated dummy target has correct dimensions
* Revert test env name change
* Add some checks to `test_foreacast_valid_index`

---------

Co-authored-by: jessegrabowski <[email protected]>
1 parent 1eda115 commit 43d778a

File tree

3 files changed

+360
-234
lines changed

3 files changed

+360
-234
lines changed

pymc_extras/statespace/core/statespace.py

Lines changed: 54 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -1699,6 +1699,10 @@ def _validate_forecast_args(
16991699
raise ValueError(
17001700
"Must specify one of either periods or end unless use_scenario_index=True"
17011701
)
1702+
if periods is None and end is None and not use_scenario_index:
1703+
raise ValueError(
1704+
"Must specify one of either periods or end unless use_scenario_index=True"
1705+
)
17021706
if periods is not None and end is not None:
17031707
raise ValueError("Must specify exactly one of either periods or end")
17041708
if scenario is None and use_scenario_index:
@@ -2246,23 +2250,58 @@ def forecast(
22462250
use_scenario_index=use_scenario_index,
22472251
)
22482252
scenario = self._finalize_scenario_initialization(scenario, forecast_index)
2253+
temp_coords = self._fit_coords.copy()
22492254

2250-
forecast_model = self._build_forecast_model(
2251-
time_index=time_index,
2252-
t0=t0,
2253-
forecast_index=forecast_index,
2254-
scenario=scenario,
2255-
filter_output=filter_output,
2256-
mvn_method=mvn_method,
2257-
)
2255+
dims = None
2256+
if all([dim in temp_coords for dim in [filter_time_dim, ALL_STATE_DIM, OBS_STATE_DIM]]):
2257+
dims = [TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM]
22582258

2259-
with forecast_model:
2260-
if scenario is not None:
2261-
dummy_obs_data = np.zeros((len(forecast_index), self.k_endog))
2262-
pm.set_data(
2263-
scenario | {"data": dummy_obs_data},
2264-
coords={"data_time": np.arange(len(forecast_index))},
2265-
)
2259+
t0_idx = np.flatnonzero(time_index == t0)[0]
2260+
2261+
temp_coords["data_time"] = time_index
2262+
temp_coords[TIME_DIM] = forecast_index
2263+
2264+
mu_dims, cov_dims = None, None
2265+
if all([dim in self._fit_coords for dim in [TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM]]):
2266+
mu_dims = ["data_time", ALL_STATE_DIM]
2267+
cov_dims = ["data_time", ALL_STATE_DIM, ALL_STATE_AUX_DIM]
2268+
2269+
with pm.Model(coords=temp_coords) as forecast_model:
2270+
(_, _, *matrices), grouped_outputs = self._kalman_filter_outputs_from_dummy_graph(
2271+
scenario=scenario,
2272+
data_dims=["data_time", OBS_STATE_DIM],
2273+
)
2274+
2275+
for name in self.data_names:
2276+
if name in scenario.keys():
2277+
pm.set_data(
2278+
{"data": np.zeros((len(forecast_index), self.k_endog))},
2279+
coords={"data_time": np.arange(len(forecast_index))},
2280+
)
2281+
break
2282+
2283+
group_idx = FILTER_OUTPUT_TYPES.index(filter_output)
2284+
mu, cov = grouped_outputs[group_idx]
2285+
2286+
x0 = pm.Deterministic(
2287+
"x0_slice", mu[t0_idx], dims=mu_dims[1:] if mu_dims is not None else None
2288+
)
2289+
P0 = pm.Deterministic(
2290+
"P0_slice", cov[t0_idx], dims=cov_dims[1:] if cov_dims is not None else None
2291+
)
2292+
2293+
_ = LinearGaussianStateSpace(
2294+
"forecast",
2295+
x0,
2296+
P0,
2297+
*matrices,
2298+
steps=len(forecast_index),
2299+
dims=dims,
2300+
mode=self._fit_mode,
2301+
sequence_names=self.kalman_filter.seq_names,
2302+
k_endog=self.k_endog,
2303+
append_x0=False,
2304+
)
22662305

22672306
forecast_model.rvs_to_initial_values = {
22682307
k: None for k in forecast_model.rvs_to_initial_values.keys()

tests/statespace/core/test_statespace.py

Lines changed: 16 additions & 219 deletions
Original file line number | Diff line number | Diff line change
@@ -152,28 +152,23 @@ def exog_data(rng):
152152

153153
df.loc[[1, 3, 9], ["y"]] = np.nan
154154
return df.set_index("date")
155-
156-
157-
@pytest.fixture(scope="session")
158-
def exog_data_mv(rng):
155+
def exog_data(rng):
159156
# simulate data
160157
df = pd.DataFrame(
161158
{
162159
"date": pd.date_range(start="2023-05-01", end="2023-05-10", freq="D"),
163160
"x1": rng.choice(2, size=10, replace=True).astype(float),
164-
"y1": rng.normal(size=(10,)),
165-
"y2": rng.normal(size=(10,)),
161+
"y": rng.normal(size=(10,)),
166162
}
167163
)
168164

169-
df.loc[[1, 3, 9], ["y1"]] = np.nan
170-
df.loc[[3, 5, 7], ["y2"]] = np.nan
165+
df.loc[[1, 3, 9], ["y"]] = np.nan
171166
return df.set_index("date")
172167

173168

174169
@pytest.fixture(scope="session")
175170
def exog_ss_mod(exog_data):
176-
level_trend = st.LevelTrendComponent(name="trend", order=1, innovations_order=[0])
171+
level_trend = st.LevelTrendComponent(order=1, innovations_order=[0])
177172
exog = st.RegressionComponent(
178173
name="exog", # Name of this exogenous variable component
179174
k_exog=1, # Only one exogenous variable now
@@ -185,83 +180,23 @@ def exog_ss_mod(exog_data):
185180
return combined_model.build()
186181

187182

188-
@pytest.fixture(scope="session")
189-
def exog_ss_mod_mv(exog_data_mv):
190-
level_trend = st.LevelTrendComponent(
191-
name="trend", order=1, innovations_order=[0], observed_state_names=["y1", "y2"]
192-
)
193-
exog = st.RegressionComponent(
194-
name="exog", # Name of this exogenous variable component
195-
k_exog=1, # Only one exogenous variable now
196-
innovations=False, # Typically fixed effect (no stochastic evolution)
197-
state_names=exog_data_mv[["x1"]].columns.tolist(),
198-
observed_state_names=["y1", "y2"],
199-
)
200-
201-
combined_model = level_trend + exog
202-
return combined_model.build()
203-
204-
205-
@pytest.fixture(scope="session")
206-
def ss_mod_multi_component(rng):
207-
ll = st.LevelTrendComponent(
208-
name="trend", order=2, innovations_order=1, observed_state_names=["y1", "y2"]
209-
)
210-
exog = st.RegressionComponent(
211-
name="exog",
212-
innovations=True,
213-
state_names=["x1"],
214-
)
215-
ar = st.AutoregressiveComponent(observed_state_names=["y1"])
216-
cycle = st.CycleComponent(cycle_length=2, observed_state_names=["y1", "y2"], innovations=True)
217-
season = st.TimeSeasonality(season_length=2, observed_state_names=["y1"], innovations=True)
218-
219-
fseason = st.FrequencySeasonality(
220-
season_length=2, observed_state_names=["y1"], innovations=True
221-
)
222-
measure_error = st.MeasurementError(observed_state_names=["y1", "y2"])
223-
return (ll + exog + ar + cycle + season + fseason + measure_error).build()
224-
225-
226183
@pytest.fixture(scope="session")
227184
def exog_pymc_mod(exog_ss_mod, exog_data):
228185
# define pymc model
229186
with pm.Model(coords=exog_ss_mod.coords) as struct_model:
230187
P0_diag = pm.Gamma("P0_diag", alpha=2, beta=4, dims=["state"])
231188
P0 = pm.Deterministic("P0", pt.diag(P0_diag), dims=["state", "state_aux"])
232189

233-
initial_trend = pm.Normal("initial_trend", mu=[0], sigma=[0.005], dims=["state_trend"])
190+
initial_trend = pm.Normal("initial_trend", mu=[0], sigma=[0.005], dims=["trend_state"])
234191

235192
data_exog = pm.Data(
236-
"data_exog", exog_data["x1"].values[:, None], dims=["time", "state_exog"]
193+
"data_exog", exog_data["x1"].values[:, None], dims=["time", "exog_state"]
237194
)
238-
beta_exog = pm.Normal("beta_exog", mu=0, sigma=1, dims=["state_exog"])
195+
beta_exog = pm.Normal("beta_exog", mu=0, sigma=1, dims=["exog_state"])
239196

240-
exog_ss_mod.build_statespace_graph(exog_data["y"], save_kalman_filter_outputs_in_idata=True)
197+
exog_ss_mod.build_statespace_graph(exog_data["y"])
241198

242199
return struct_model
243-
244-
245-
@pytest.fixture(scope="session")
246-
def exog_pymc_mod_mv(exog_ss_mod_mv, exog_data_mv):
247-
# define pymc model
248-
with pm.Model(coords=exog_ss_mod_mv.coords) as struct_model:
249-
P0_diag = pm.Gamma("P0_diag", alpha=2, beta=4, dims=["state"])
250-
P0 = pm.Deterministic("P0", pt.diag(P0_diag), dims=["state", "state_aux"])
251-
252-
initial_trend = pm.Normal(
253-
"initial_trend", mu=[0], sigma=[0.005], dims=["endog_trend", "state_trend"]
254-
)
255-
256-
data_exog = pm.Data(
257-
"data_exog", exog_data_mv["x1"].values[:, None], dims=["time", "state_exog"]
258-
)
259-
beta_exog = pm.Normal("beta_exog", mu=0, sigma=1, dims=["endog_exog", "state_exog"])
260-
261-
exog_ss_mod_mv.build_statespace_graph(
262-
exog_data_mv[["y1", "y2"]], save_kalman_filter_outputs_in_idata=True
263-
)
264-
265200
return struct_model
266201

267202

@@ -1036,13 +971,19 @@ def test_forecast(filter_output, mod_name, idata_name, start, end, periods, rng,
1036971
assert forecast_idx[0] == (t0 + delta)
1037972

1038973

974+
@pytest.mark.filterwarnings("ignore:Provided data contains missing values")
975+
@pytest.mark.filterwarnings("ignore:The RandomType SharedVariables")
1039976
@pytest.mark.filterwarnings("ignore:Provided data contains missing values")
1040977
@pytest.mark.filterwarnings("ignore:The RandomType SharedVariables")
1041978
@pytest.mark.filterwarnings("ignore:No time index found on the supplied data.")
1042979
@pytest.mark.filterwarnings("ignore:Skipping `CheckAndRaise` Op")
1043980
@pytest.mark.filterwarnings("ignore:No frequency was specific on the data's DateTimeIndex.")
1044981
@pytest.mark.parametrize("start", [None, -1, 5])
982+
@pytest.mark.filterwarnings("ignore:Skipping `CheckAndRaise` Op")
983+
@pytest.mark.filterwarnings("ignore:No frequency was specific on the data's DateTimeIndex.")
984+
@pytest.mark.parametrize("start", [None, -1, 5])
1045985
def test_forecast_with_exog_data(rng, exog_ss_mod, idata_exog, start):
986+
scenario = pd.DataFrame(np.zeros((10, 1)), columns=["x1"])
1046987
scenario = pd.DataFrame(np.zeros((10, 1)), columns=["x1"])
1047988
scenario.iloc[5, 0] = 1e9
1048989

@@ -1051,7 +992,7 @@ def test_forecast_with_exog_data(rng, exog_ss_mod, idata_exog, start):
1051992
)
1052993

1053994
components = exog_ss_mod.extract_components_from_idata(forecast_idata)
1054-
level = components.forecast_latent.sel(state="trend[level]")
995+
level = components.forecast_latent.sel(state="LevelTrend[level]")
1055996
betas = components.forecast_latent.sel(state=["exog[x1]"])
1056997

1057998
scenario.index.name = "time"
@@ -1060,6 +1001,7 @@ def test_forecast_with_exog_data(rng, exog_ss_mod, idata_exog, start):
10601001
.to_xarray()
10611002
.rename({"level_0": "state"})
10621003
.assign_coords(state=["exog[x1]"])
1004+
.assign_coords(state=["exog[x1]"])
10631005
)
10641006

10651007
regression_effect = forecast_idata.forecast_observed.isel(observed_state=0) - level
@@ -1068,138 +1010,6 @@ def test_forecast_with_exog_data(rng, exog_ss_mod, idata_exog, start):
10681010
assert_allclose(regression_effect, regression_effect_expected)
10691011

10701012

1071-
@pytest.mark.filterwarnings("ignore:Provided data contains missing values")
1072-
@pytest.mark.filterwarnings("ignore:The RandomType SharedVariables")
1073-
@pytest.mark.filterwarnings("ignore:No time index found on the supplied data.")
1074-
@pytest.mark.filterwarnings("ignore:Skipping `CheckAndRaise` Op")
1075-
@pytest.mark.filterwarnings("ignore:No frequency was specific on the data's DateTimeIndex.")
1076-
@pytest.mark.parametrize("start", [None, -1, 5])
1077-
def test_forecast_with_exog_data_mv(rng, exog_ss_mod_mv, idata_exog_mv, start):
1078-
scenario = pd.DataFrame(np.zeros((10, 1)), columns=["x1"])
1079-
scenario.iloc[5, 0] = 1e9
1080-
1081-
forecast_idata = exog_ss_mod_mv.forecast(
1082-
idata_exog_mv, start=start, periods=10, random_seed=rng, scenario=scenario
1083-
)
1084-
1085-
components = exog_ss_mod_mv.extract_components_from_idata(forecast_idata)
1086-
level_y1 = components.forecast_latent.sel(state="trend[level[y1]]")
1087-
level_y2 = components.forecast_latent.sel(state="trend[level[y2]]")
1088-
betas_y1 = components.forecast_latent.sel(state=["exog[x1[y1]]"])
1089-
betas_y2 = components.forecast_latent.sel(state=["exog[x1[y2]]"])
1090-
1091-
scenario.index.name = "time"
1092-
scenario_xr_y1 = (
1093-
scenario.unstack()
1094-
.to_xarray()
1095-
.rename({"level_0": "state"})
1096-
.assign_coords(state=["exog[x1[y1]]"])
1097-
)
1098-
1099-
scenario_xr_y2 = (
1100-
scenario.unstack()
1101-
.to_xarray()
1102-
.rename({"level_0": "state"})
1103-
.assign_coords(state=["exog[x1[y2]]"])
1104-
)
1105-
1106-
regression_effect_y1 = forecast_idata.forecast_observed.isel(observed_state=0) - level_y1
1107-
regression_effect_expected_y1 = (betas_y1 * scenario_xr_y1).sum(dim=["state"])
1108-
1109-
regression_effect_y2 = forecast_idata.forecast_observed.isel(observed_state=1) - level_y2
1110-
regression_effect_expected_y2 = (betas_y2 * scenario_xr_y2).sum(dim=["state"])
1111-
1112-
np.testing.assert_allclose(regression_effect_y1, regression_effect_expected_y1)
1113-
np.testing.assert_allclose(regression_effect_y2, regression_effect_expected_y2)
1114-
1115-
1116-
@pytest.mark.filterwarnings("ignore:Provided data contains missing values")
1117-
@pytest.mark.filterwarnings("ignore:The RandomType SharedVariables")
1118-
@pytest.mark.filterwarnings("ignore:No time index found on the supplied data.")
1119-
@pytest.mark.filterwarnings("ignore:Skipping `CheckAndRaise` Op")
1120-
@pytest.mark.filterwarnings("ignore:No frequency was specific on the data's DateTimeIndex.")
1121-
def test_build_forecast_model(rng, exog_ss_mod, exog_pymc_mod, exog_data, idata_exog):
1122-
data_before_build_forecast_model = {d.name: d.get_value() for d in exog_pymc_mod.data_vars}
1123-
1124-
scenario = pd.DataFrame(
1125-
{
1126-
"date": pd.date_range(start="2023-05-11", end="2023-05-20", freq="D"),
1127-
"x1": rng.choice(2, size=10, replace=True).astype(float),
1128-
}
1129-
)
1130-
scenario.set_index("date", inplace=True)
1131-
1132-
time_index = exog_ss_mod._get_fit_time_index()
1133-
t0, forecast_index = exog_ss_mod._build_forecast_index(
1134-
time_index=time_index,
1135-
start=exog_data.index[-1],
1136-
end=scenario.index[-1],
1137-
scenario=scenario,
1138-
)
1139-
1140-
test_forecast_model = exog_ss_mod._build_forecast_model(
1141-
time_index=time_index,
1142-
t0=t0,
1143-
forecast_index=forecast_index,
1144-
scenario=scenario,
1145-
filter_output="predicted",
1146-
mvn_method="svd",
1147-
)
1148-
1149-
frozen_shared_inputs = [
1150-
inpt
1151-
for inpt in graph_inputs([test_forecast_model.x0_slice, test_forecast_model.P0_slice])
1152-
if isinstance(inpt, SharedVariable)
1153-
and not isinstance(inpt.get_value(), np.random.Generator)
1154-
]
1155-
1156-
assert (
1157-
len(frozen_shared_inputs) == 0
1158-
) # check there are no non-random generator SharedVariables in the frozen inputs
1159-
1160-
unfrozen_shared_inputs = [
1161-
inpt
1162-
for inpt in graph_inputs([test_forecast_model.forecast_combined])
1163-
if isinstance(inpt, SharedVariable)
1164-
and not isinstance(inpt.get_value(), np.random.Generator)
1165-
]
1166-
1167-
# Check that there is one (in this case) unfrozen shared input and it corresponds to the exogenous data
1168-
assert len(unfrozen_shared_inputs) == 1
1169-
assert unfrozen_shared_inputs[0].name == "data_exog"
1170-
1171-
data_after_build_forecast_model = {d.name: d.get_value() for d in test_forecast_model.data_vars}
1172-
1173-
with test_forecast_model:
1174-
dummy_obs_data = np.zeros((len(forecast_index), exog_ss_mod.k_endog))
1175-
pm.set_data(
1176-
{"data_exog": scenario} | {"data": dummy_obs_data},
1177-
coords={"data_time": np.arange(len(forecast_index))},
1178-
)
1179-
idata_forecast = pm.sample_posterior_predictive(
1180-
idata_exog, var_names=["x0_slice", "P0_slice"]
1181-
)
1182-
1183-
np.testing.assert_allclose(
1184-
unfrozen_shared_inputs[0].get_value(), scenario["x1"].values.reshape((-1, 1))
1185-
) # ensure the replaced data matches the exogenous data
1186-
1187-
for k in data_before_build_forecast_model.keys():
1188-
assert ( # check that the data needed to init the forecasts doesn't change
1189-
data_before_build_forecast_model[k].mean() == data_after_build_forecast_model[k].mean()
1190-
)
1191-
1192-
# Check that the frozen states and covariances correctly match the sliced index
1193-
np.testing.assert_allclose(
1194-
idata_exog.posterior["predicted_covariance"].sel(time=t0).mean(("chain", "draw")).values,
1195-
idata_forecast.posterior_predictive["P0_slice"].mean(("chain", "draw")).values,
1196-
)
1197-
np.testing.assert_allclose(
1198-
idata_exog.posterior["predicted_state"].sel(time=t0).mean(("chain", "draw")).values,
1199-
idata_forecast.posterior_predictive["x0_slice"].mean(("chain", "draw")).values,
1200-
)
1201-
1202-
12031013
@pytest.mark.filterwarnings("ignore:Provided data contains missing values")
12041014
@pytest.mark.filterwarnings("ignore:The RandomType SharedVariables")
12051015
@pytest.mark.filterwarnings("ignore:No time index found on the supplied data.")
@@ -1231,16 +1041,3 @@ def test_foreacast_valid_index(exog_pymc_mod, exog_ss_mod, exog_data):
12311041

12321042
assert forecasts.forecast_latent.shape[2] == n_periods
12331043
assert forecasts.forecast_observed.shape[2] == n_periods
1234-
1235-
1236-
def test_param_dims_coords(ss_mod_multi_component):
1237-
for param in ss_mod_multi_component.param_names:
1238-
shape = ss_mod_multi_component.param_info[param]["shape"]
1239-
dims = ss_mod_multi_component.param_dims.get(param, None)
1240-
if len(shape) == 0:
1241-
assert dims is None
1242-
continue
1243-
for i, s in zip(shape, dims):
1244-
assert i == len(
1245-
ss_mod_multi_component.coords[s]
1246-
), f"Mismatch between shape {i} and dimension {s}"

0 commit comments

Comments
 (0)