Skip to content

Commit f90709c

Browse files
Test for stationary initialization
1 parent af9804d commit f90709c

File tree

1 file changed

+44
-2
lines changed

1 file changed

+44
-2
lines changed

tests/statespace/test_ETS.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -235,9 +235,12 @@ def test_statespace_matches_statsmodels(rng, order: tuple[str, str, str], params
235235
for i in range(1, seasonal_periods):
236236
sm_test_values[f"initial_seasonal.L{i}"] = test_values["initial_seasonal"][i]
237237

238-
x0 = np.r_[
239-
0, *[test_values[name] for name in ["initial_level", "initial_trend", "initial_seasonal"]]
238+
vals = [
239+
np.atleast_1d(test_values[name])
240+
for name in ["initial_level", "initial_trend", "initial_seasonal"]
240241
]
242+
x0 = np.concatenate([[0.0], *vals])
243+
241244
mask = [True, True, order[1] != "N", *(order[2] != "N",) * seasonal_periods]
242245

243246
sm_mod.initialize_known(initial_state=x0[mask], initial_state_cov=np.eye(mod.k_states))
@@ -367,3 +370,42 @@ def test_ETS_with_multiple_endog(rng, order, params, dense_cov):
367370
raise ValueError(f"You forgot {name} !")
368371

369372
cursor += single_mod.k_states
373+
374+
375+
def test_ETS_stationary_initialization():
376+
mod = BayesianETS(
377+
order=("A", "Ad", "A"),
378+
seasonal_periods=4,
379+
stationary_initialization=True,
380+
initialization_dampening=0.66,
381+
)
382+
383+
matrices = mod._unpack_statespace_with_placeholders()
384+
inputs = list(explicit_graph_inputs(matrices))
385+
input_names = [x.name for x in inputs]
386+
387+
# Make sure the stationary_dampening dummy variables was completely rewritten away
388+
assert "stationary_dampening" not in input_names
389+
390+
# P0 should have been removed from param names
391+
assert "P0" not in mod.param_names
392+
assert "P0" not in mod.param_info.keys()
393+
394+
f = pytensor.function(inputs, matrices, mode="FAST_COMPILE")
395+
test_values = f(**{x.name: np.full(x.type.shape, 0.5) for x in inputs})
396+
outputs = {name: val for name, val in zip(LONG_MATRIX_NAMES, test_values)}
397+
398+
# Make sure that the transition matrix has ones in the expected positions, not the model dampening factor
399+
assert outputs["transition"][1, 1] == 1.0
400+
assert outputs["transition"][2, 2] == 0.5 # phi = 0.5 -- trend is dampened anyway
401+
assert outputs["transition"][3, -1] == 1.0
402+
403+
# P0 should be equal to the solution to the Lyapunov equation using the dampening factors in the transition matrix
404+
T_stationary = outputs["transition"].copy()
405+
T_stationary[1, 1] = mod.initialization_dampening
406+
T_stationary[3, -1] = mod.initialization_dampening
407+
408+
R, Q = outputs["selection"], outputs["state_cov"]
409+
P0_expected = linalg.solve_discrete_lyapunov(T_stationary, R @ Q @ R.T)
410+
411+
assert_allclose(outputs["initial_state_cov"], P0_expected)

0 commit comments

Comments
 (0)