Skip to content

Commit 9456bb2

Browse files
committed
Extend tests for the case duration > 1
1 parent dc44ac5 commit 9456bb2

File tree

1 file changed

+77
-61
lines changed

1 file changed

+77
-61
lines changed

tests/statespace/models/structural/components/test_seasonality.py

Lines changed: 77 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -14,62 +14,69 @@
1414

1515

1616
@pytest.mark.parametrize("s", [10, 25, 50])
17+
@pytest.mark.parametrize("d", [1, 2, 3])
1718
@pytest.mark.parametrize("innovations", [True, False])
1819
@pytest.mark.parametrize("remove_first_state", [True, False])
1920
@pytest.mark.filterwarnings(
2021
"ignore:divide by zero encountered in matmul:RuntimeWarning",
2122
"ignore:overflow encountered in matmul:RuntimeWarning",
2223
"ignore:invalid value encountered in matmul:RuntimeWarning",
2324
)
24-
def test_time_seasonality(s, innovations, remove_first_state, rng):
25+
def test_time_seasonality(s, d, innovations, remove_first_state, rng):
2526
def random_word(rng):
2627
return "".join(rng.choice(list("abcdefghijklmnopqrstuvwxyz")) for _ in range(5))
2728

28-
state_names = [random_word(rng) for _ in range(s)]
29+
state_names = [random_word(rng) for _ in range(s * d)]
2930
mod = st.TimeSeasonality(
3031
season_length=s,
32+
duration=d,
3133
innovations=innovations,
3234
name="season",
3335
state_names=state_names,
3436
remove_first_state=remove_first_state,
3537
)
36-
x0 = np.zeros(mod.k_states, dtype=config.floatX)
38+
x0 = np.zeros(mod.k_states // mod.duration, dtype=config.floatX)
3739
x0[0] = 1
3840

3941
params = {"coefs_season": x0}
4042
if innovations:
4143
params["sigma_season"] = 0.0
4244

43-
x, y = simulate_from_numpy_model(mod, rng, params)
45+
x, y = simulate_from_numpy_model(mod, rng, params, steps=100 * mod.duration)
4446
y = y.ravel()
4547
if not innovations:
46-
assert_pattern_repeats(y, s, atol=ATOL, rtol=RTOL)
48+
assert_pattern_repeats(y, s * d, atol=ATOL, rtol=RTOL)
4749

4850
# Check coords
4951
mod = mod.build(verbose=False)
5052
_assert_basic_coords_correct(mod)
51-
test_slice = slice(1, None) if remove_first_state else slice(None)
53+
test_slice = slice(d, None) if remove_first_state else slice(None)
5254
assert mod.coords["state_season"] == state_names[test_slice]
5355

5456

57+
@pytest.mark.parametrize("d", [1, 2, 3])
5558
@pytest.mark.parametrize(
5659
"remove_first_state", [True, False], ids=["remove_first_state", "keep_first_state"]
5760
)
58-
def test_time_seasonality_multiple_observed(rng, remove_first_state):
61+
def test_time_seasonality_multiple_observed(rng, d, remove_first_state):
5962
s = 3
60-
state_names = [f"state_{i}" for i in range(s)]
63+
state_names = [f"state_{i}_{j}" for i in range(s) for j in range(d)]
6164
mod = st.TimeSeasonality(
6265
season_length=s,
66+
duration=d,
6367
innovations=True,
6468
name="season",
6569
state_names=state_names,
6670
observed_state_names=["data_1", "data_2"],
6771
remove_first_state=remove_first_state,
6872
)
69-
x0 = np.zeros((mod.k_endog, mod.k_states // mod.k_endog), dtype=config.floatX)
73+
x0 = np.zeros((mod.k_endog, mod.k_states // mod.k_endog // mod.duration), dtype=config.floatX)
7074

7175
expected_states = [
72-
f"state_{i}[data_{j}]" for j in range(1, 3) for i in range(int(remove_first_state), s)
76+
f"state_{i}_{j}[data_{k}]"
77+
for k in range(1, 3)
78+
for i in range(int(remove_first_state), s)
79+
for j in range(d)
7380
]
7481
assert mod.state_names == expected_states
7582
assert mod.shock_names == ["season[data_1]", "season[data_2]"]
@@ -79,9 +86,9 @@ def test_time_seasonality_multiple_observed(rng, remove_first_state):
7986

8087
params = {"coefs_season": x0, "sigma_season": np.array([0.0, 0.0], dtype=config.floatX)}
8188

82-
x, y = simulate_from_numpy_model(mod, rng, params, steps=123)
83-
assert_pattern_repeats(y[:, 0], s, atol=ATOL, rtol=RTOL)
84-
assert_pattern_repeats(y[:, 1], s, atol=ATOL, rtol=RTOL)
89+
x, y = simulate_from_numpy_model(mod, rng, params, steps=123 * d)
90+
assert_pattern_repeats(y[:, 0], s * d, atol=ATOL, rtol=RTOL)
91+
assert_pattern_repeats(y[:, 1], s * d, atol=ATOL, rtol=RTOL)
8592

8693
mod = mod.build(verbose=False)
8794
x0, *_, T, Z, R, _, Q = mod._unpack_statespace_with_placeholders()
@@ -97,36 +104,38 @@ def test_time_seasonality_multiple_observed(rng, remove_first_state):
97104
params["sigma_season"] = np.array([0.1, 0.8], dtype=config.floatX)
98105
x0, T, Z, R, Q = fn(**params)
99106

107+
# Because the dimension of the observed states is 2,
108+
# the expected T is the diagonal block matrix [[T0, 0], [0, T0]]
109+
# where T0 is the transition matrix we would have if the
110+
# seasonality were not multiple observed.
111+
mod0 = st.TimeSeasonality(season_length=s, duration=d, remove_first_state=remove_first_state)
112+
T0 = mod0.ssm["transition"].eval()
113+
100114
if remove_first_state:
101-
expected_x0 = np.array([1.0, 0.0, 2.0, 0.0])
102-
103-
expected_T = np.array(
104-
[
105-
[-1.0, -1.0, 0.0, 0.0],
106-
[1.0, 0.0, 0.0, 0.0],
107-
[0.0, 0.0, -1.0, -1.0],
108-
[0.0, 0.0, 1.0, 0.0],
109-
]
115+
expected_x0 = np.repeat(np.array([1.0, 0.0, 2.0, 0.0]), d)
116+
expected_T = np.block(
117+
[[T0, np.zeros((d * (s - 1), d * (s - 1)))], [np.zeros((d * (s - 1), d * (s - 1))), T0]]
118+
)
119+
expected_R = np.array(
120+
[[1.0, 1.0]] + [[0.0, 0.0]] * (2 * d - 1) + [[1.0, 1.0]] + [[0.0, 0.0]] * (2 * d - 1)
110121
)
111-
expected_R = np.array([[1.0, 1.0], [0.0, 0.0], [1.0, 1.0], [0.0, 0.0]])
112-
expected_Z = np.array([[1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]])
122+
Z0 = np.zeros((2, d * (s - 1)))
123+
Z0[0, 0] = 1
124+
Z1 = np.zeros((2, d * (s - 1)))
125+
Z1[1, 0] = 1
126+
expected_Z = np.block([[Z0, Z1]])
113127

114128
else:
115-
expected_x0 = np.array([1.0, 0.0, 0.0, 2.0, 0.0, 0.0])
116-
expected_T = np.array(
117-
[
118-
[0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
119-
[0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
120-
[1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
121-
[0.0, 0.0, 0.0, 0.0, 1.0, 0.0],
122-
[0.0, 0.0, 0.0, 0.0, 0.0, 1.0],
123-
[0.0, 0.0, 0.0, 1.0, 0.0, 0.0],
124-
]
125-
)
129+
expected_x0 = np.repeat(np.array([1.0, 0.0, 0.0, 2.0, 0.0, 0.0]), d)
130+
expected_T = np.block([[T0, np.zeros((s * d, s * d))], [np.zeros((s * d, s * d)), T0]])
126131
expected_R = np.array(
127-
[[1.0, 1.0], [0.0, 0.0], [0.0, 0.0], [1.0, 1.0], [0.0, 0.0], [0.0, 0.0]]
132+
[[1.0, 1.0]] + [[0.0, 0.0]] * (s * d - 1) + [[1.0, 1.0]] + [[0.0, 0.0]] * (s * d - 1)
128133
)
129-
expected_Z = np.array([[1.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0, 0.0, 0.0]])
134+
Z0 = np.zeros((2, s * d))
135+
Z0[0, 0] = 1
136+
Z1 = np.zeros((2, s * d))
137+
Z1[1, 0] = 1
138+
expected_Z = np.block([[Z0, Z1]])
130139

131140
expected_Q = np.array([[0.1**2, 0.0], [0.0, 0.8**2]])
132141

@@ -137,20 +146,24 @@ def test_time_seasonality_multiple_observed(rng, remove_first_state):
137146
np.testing.assert_allclose(matrix, expected)
138147

139148

140-
def test_add_two_time_seasonality_different_observed(rng):
149+
@pytest.mark.parametrize("d1", [1, 2, 3])
150+
@pytest.mark.parametrize("d2", [1, 2, 3])
151+
def test_add_two_time_seasonality_different_observed(rng, d1, d2):
141152
mod1 = st.TimeSeasonality(
142153
season_length=3,
154+
duration=d1,
143155
innovations=True,
144156
name="season1",
145-
state_names=[f"state_{i}" for i in range(3)],
157+
state_names=[f"state_{i}_{j}" for i in range(3) for j in range(d1)],
146158
observed_state_names=["data_1"],
147159
remove_first_state=False,
148160
)
149161
mod2 = st.TimeSeasonality(
150162
season_length=5,
163+
duration=d2,
151164
innovations=True,
152165
name="season2",
153-
state_names=[f"state_{i}" for i in range(5)],
166+
state_names=[f"state_{i}_{j}" for i in range(5) for j in range(d2)],
154167
observed_state_names=["data_2"],
155168
)
156169

@@ -164,18 +177,22 @@ def test_add_two_time_seasonality_different_observed(rng):
164177
"initial_state_cov": np.eye(mod.k_states, dtype=config.floatX),
165178
}
166179

167-
x, y = simulate_from_numpy_model(mod, rng, params, steps=3 * 5 * 5)
168-
assert_pattern_repeats(y[:, 0], 3, atol=ATOL, rtol=RTOL)
169-
assert_pattern_repeats(y[:, 1], 5, atol=ATOL, rtol=RTOL)
180+
x, y = simulate_from_numpy_model(mod, rng, params, steps=3 * 5 * 5 * d1 * d2)
181+
assert_pattern_repeats(y[:, 0], 3 * d1, atol=ATOL, rtol=RTOL)
182+
assert_pattern_repeats(y[:, 1], 5 * d2, atol=ATOL, rtol=RTOL)
170183

171184
assert mod.state_names == [
172-
"state_0[data_1]",
173-
"state_1[data_1]",
174-
"state_2[data_1]",
175-
"state_1[data_2]",
176-
"state_2[data_2]",
177-
"state_3[data_2]",
178-
"state_4[data_2]",
185+
item
186+
for sublist in [
187+
[f"state_0_{j}[data_1]" for j in range(d1)],
188+
[f"state_1_{j}[data_1]" for j in range(d1)],
189+
[f"state_2_{j}[data_1]" for j in range(d1)],
190+
[f"state_1_{j}[data_2]" for j in range(d2)],
191+
[f"state_2_{j}[data_2]" for j in range(d2)],
192+
[f"state_3_{j}[data_2]" for j in range(d2)],
193+
[f"state_4_{j}[data_2]" for j in range(d2)],
194+
]
195+
for item in sublist
179196
]
180197

181198
assert mod.shock_names == ["season1[data_1]", "season2[data_2]"]
@@ -194,20 +211,19 @@ def test_add_two_time_seasonality_different_observed(rng):
194211
)
195212

196213
np.testing.assert_allclose(
197-
np.array([1.0, 0.0, 0.0, 3.0, 0.0, 0.0, 1.2]), x0, atol=ATOL, rtol=RTOL
214+
np.repeat(np.array([1.0, 0.0, 0.0, 3.0, 0.0, 0.0, 1.2]), [d1, d1, d1, d2, d2, d2, d2]),
215+
x0,
216+
atol=ATOL,
217+
rtol=RTOL,
198218
)
199219

220+
# The transition matrix T of mod is expected to be [[T1, 0], [0, T2]],
221+
# where T1 and T2 are the transition matrices of mod1 and mod2, respectively.
222+
T1 = mod1.ssm["transition"].eval()
223+
T2 = mod2.ssm["transition"].eval()
200224
np.testing.assert_allclose(
201-
np.array(
202-
[
203-
[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
204-
[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
205-
[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
206-
[0.0, 0.0, 0.0, -1.0, -1.0, -1.0, -1.0],
207-
[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
208-
[0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0],
209-
[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],
210-
]
225+
np.block(
226+
[[T1, np.zeros((T1.shape[0], T2.shape[1]))], [np.zeros((T2.shape[0], T1.shape[1])), T2]]
211227
),
212228
T,
213229
atol=ATOL,

0 commit comments

Comments
 (0)