Skip to content

Commit a8564b7

Browse files
Allow multiple observed in TimeSeasonality component
1 parent 0b20dbc commit a8564b7

File tree

2 files changed

+207
-19
lines changed

2 files changed

+207
-19
lines changed

pymc_extras/statespace/models/structural/components/seasonality.py

Lines changed: 40 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -154,27 +154,41 @@ def __init__(
154154
# TODO: Can this be stashed and reconstructed automatically somehow?
155155
state_names.pop(0)
156156

157+
self.provided_state_names = state_names
158+
157159
k_states = season_length - int(self.remove_first_state)
160+
k_endog = len(observed_state_names)
161+
k_posdef = int(innovations)
158162

159163
super().__init__(
160164
name=name,
161-
k_endog=len(observed_state_names),
162-
k_states=k_states,
163-
k_posdef=int(innovations),
164-
state_names=state_names,
165+
k_endog=k_endog,
166+
k_states=k_states * k_endog,
167+
k_posdef=k_posdef * k_endog,
165168
observed_state_names=observed_state_names,
166169
measurement_error=False,
167170
combine_hidden_states=True,
168-
obs_state_idxs=np.r_[[1.0], np.zeros(k_states - 1)],
171+
obs_state_idxs=np.tile(np.array([1.0] + [0.0] * (k_states - 1)), k_endog),
169172
)
170173

171174
def populate_component_properties(self):
175+
k_states = self.k_states // self.k_endog
176+
k_endog = self.k_endog
177+
178+
self.state_names = [
179+
f"{state_name}[{endog_name}]"
180+
for endog_name in self.observed_state_names
181+
for state_name in self.provided_state_names
182+
]
172183
self.param_names = [f"{self.name}_coefs"]
184+
173185
self.param_info = {
174186
f"{self.name}_coefs": {
175-
"shape": (self.k_states,),
187+
"shape": (k_states,) if k_endog == 1 else (k_endog, k_states),
176188
"constraints": None,
177-
"dims": (f"{self.name}_state",),
189+
"dims": (f"{self.name}_state",)
190+
if k_endog == 1
191+
else (f"{self.name}_endog", f"{self.name}_state"),
178192
}
179193
}
180194
self.param_dims = {f"{self.name}_coefs": (f"{self.name}_state",)}
@@ -187,32 +201,41 @@ def populate_component_properties(self):
187201
"constraints": "Positive",
188202
"dims": None,
189203
}
190-
self.shock_names = [f"{self.name}"]
204+
self.shock_names = [f"{self.name}[{name}]" for name in self.observed_state_names]
191205

192206
def make_symbolic_graph(self) -> None:
207+
k_states = self.k_states // self.k_endog
208+
k_posdef = self.k_posdef // self.k_endog
209+
k_endog = self.k_endog
210+
193211
if self.remove_first_state:
194212
# In this case, parameters are normalized to sum to zero, so the current state is the negative sum of
195213
# all previous states.
196-
T = np.eye(self.k_states, k=-1)
214+
T = np.eye(k_states, k=-1)
197215
T[0, :] = -1
198216
else:
199217
# In this case we assume the user to be responsible for ensuring the states sum to zero, so T is just a
200218
# circulant matrix that cycles between the states.
201-
T = np.eye(self.k_states, k=1)
219+
T = np.eye(k_states, k=1)
202220
T[-1, 0] = 1
203221

204-
self.ssm["transition", :, :] = T
205-
self.ssm["design", 0, 0] = 1
222+
self.ssm["transition", :, :] = pt.linalg.block_diag(*[T for _ in range(k_endog)])
223+
224+
Z = pt.zeros((1, k_states))[0, 0].set(1)
225+
self.ssm["design", :, :] = pt.linalg.block_diag(*[Z for _ in range(k_endog)])
206226

207227
initial_states = self.make_and_register_variable(
208-
f"{self.name}_coefs", shape=(self.k_states,)
228+
f"{self.name}_coefs", shape=(k_states,) if k_endog == 1 else (k_endog, k_states)
209229
)
210-
self.ssm["initial_state", np.arange(self.k_states, dtype=int)] = initial_states
230+
self.ssm["initial_state", :] = initial_states.ravel()
211231

212232
if self.innovations:
213-
self.ssm["selection", 0, 0] = 1
214-
season_sigma = self.make_and_register_variable(f"sigma_{self.name}", shape=())
215-
cov_idx = ("state_cov", *np.diag_indices(1))
233+
R = pt.zeros((k_states, k_posdef))[0, 0].set(1.0)
234+
self.ssm["selection", :, :] = pt.join(0, *[R for _ in range(k_endog)])
235+
season_sigma = self.make_and_register_variable(
236+
f"sigma_{self.name}", shape=() if k_endog == 1 else (k_endog,)
237+
)
238+
cov_idx = ("state_cov", *np.diag_indices(k_posdef * k_endog))
216239
self.ssm[cov_idx] = season_sigma**2
217240

218241

tests/statespace/models/structural/components/test_seasonality.py

Lines changed: 167 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import numpy as np
2+
import pytensor
23
import pytest
34

45
from pytensor import config
6+
from pytensor.graph.basic import explicit_graph_inputs
57

68
from pymc_extras.statespace.models import structural as st
79
from tests.statespace.models.structural.conftest import _assert_basic_coords_correct
@@ -35,7 +37,7 @@ def random_word(rng):
3537
x0[0] = 1
3638

3739
params = {"season_coefs": x0}
38-
if mod.innovations:
40+
if innovations:
3941
params["sigma_season"] = 0.0
4042

4143
x, y = simulate_from_numpy_model(mod, rng, params)
@@ -44,12 +46,175 @@ def random_word(rng):
4446
assert_pattern_repeats(y, s, atol=ATOL, rtol=RTOL)
4547

4648
# Check coords
47-
mod.build(verbose=False)
49+
mod = mod.build(verbose=False)
4850
_assert_basic_coords_correct(mod)
4951
test_slice = slice(1, None) if remove_first_state else slice(None)
5052
assert mod.coords["season_state"] == state_names[test_slice]
5153

5254

55+
@pytest.mark.parametrize(
56+
"remove_first_state", [True, False], ids=["remove_first_state", "keep_first_state"]
57+
)
58+
def test_time_seasonality_multiple_observed(rng, remove_first_state):
59+
s = 3
60+
state_names = [f"state_{i}" for i in range(s)]
61+
mod = st.TimeSeasonality(
62+
season_length=s,
63+
innovations=True,
64+
name="season",
65+
state_names=state_names,
66+
observed_state_names=["data_1", "data_2"],
67+
remove_first_state=remove_first_state,
68+
)
69+
x0 = np.zeros((mod.k_endog, mod.k_states // mod.k_endog), dtype=config.floatX)
70+
71+
expected_states = [
72+
f"state_{i}[data_{j}]" for j in range(1, 3) for i in range(int(remove_first_state), s)
73+
]
74+
assert mod.state_names == expected_states
75+
assert mod.shock_names == ["season[data_1]", "season[data_2]"]
76+
77+
x0[0, 0] = 1
78+
x0[1, 0] = 2.0
79+
80+
params = {"season_coefs": x0, "sigma_season": np.array([0.0, 0.0], dtype=config.floatX)}
81+
82+
x, y = simulate_from_numpy_model(mod, rng, params, steps=123)
83+
assert_pattern_repeats(y[:, 0], s, atol=ATOL, rtol=RTOL)
84+
assert_pattern_repeats(y[:, 1], s, atol=ATOL, rtol=RTOL)
85+
86+
mod = mod.build(verbose=False)
87+
x0, *_, T, Z, R, _, Q = mod._unpack_statespace_with_placeholders()
88+
89+
input_vars = explicit_graph_inputs([x0, T, Z, R, Q])
90+
91+
fn = pytensor.function(
92+
inputs=list(input_vars),
93+
outputs=[x0, T, Z, R, Q],
94+
mode="FAST_COMPILE",
95+
)
96+
97+
params["sigma_season"] = np.array([0.1, 0.8], dtype=config.floatX)
98+
x0, T, Z, R, Q = fn(**params)
99+
100+
if remove_first_state:
101+
expected_x0 = np.array([1.0, 0.0, 2.0, 0.0])
102+
103+
expected_T = np.array(
104+
[
105+
[-1.0, -1.0, 0.0, 0.0],
106+
[1.0, 0.0, 0.0, 0.0],
107+
[0.0, 0.0, -1.0, -1.0],
108+
[0.0, 0.0, 1.0, 0.0],
109+
]
110+
)
111+
expected_R = np.array([[1.0, 1.0], [0.0, 0.0], [1.0, 1.0], [0.0, 0.0]])
112+
expected_Z = np.array([[1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]])
113+
114+
else:
115+
expected_x0 = np.array([1.0, 0.0, 0.0, 2.0, 0.0, 0.0])
116+
expected_T = np.array(
117+
[
118+
[0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
119+
[0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
120+
[1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
121+
[0.0, 0.0, 0.0, 0.0, 1.0, 0.0],
122+
[0.0, 0.0, 0.0, 0.0, 0.0, 1.0],
123+
[0.0, 0.0, 0.0, 1.0, 0.0, 0.0],
124+
]
125+
)
126+
expected_R = np.array(
127+
[[1.0, 1.0], [0.0, 0.0], [0.0, 0.0], [1.0, 1.0], [0.0, 0.0], [0.0, 0.0]]
128+
)
129+
expected_Z = np.array([[1.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0, 0.0, 0.0]])
130+
131+
expected_Q = np.array([[0.1**2, 0.0], [0.0, 0.8**2]])
132+
133+
for matrix, expected in zip(
134+
[x0, T, Z, R, Q],
135+
[expected_x0, expected_T, expected_Z, expected_R, expected_Q],
136+
):
137+
np.testing.assert_allclose(matrix, expected)
138+
139+
140+
def test_add_two_time_seasonality_different_observed(rng):
141+
mod1 = st.TimeSeasonality(
142+
season_length=3,
143+
innovations=True,
144+
name="season1",
145+
state_names=[f"state_{i}" for i in range(3)],
146+
observed_state_names=["data_1"],
147+
remove_first_state=False,
148+
)
149+
mod2 = st.TimeSeasonality(
150+
season_length=5,
151+
innovations=True,
152+
name="season2",
153+
state_names=[f"state_{i}" for i in range(5)],
154+
observed_state_names=["data_2"],
155+
)
156+
157+
mod = (mod1 + mod2).build(verbose=False)
158+
159+
params = {
160+
"season1_coefs": np.array([1.0, 0.0, 0.0], dtype=config.floatX),
161+
"season2_coefs": np.array([3.0, 0.0, 0.0, 0.0], dtype=config.floatX),
162+
"sigma_season1": np.array(0.0, dtype=config.floatX),
163+
"sigma_season2": np.array(0.0, dtype=config.floatX),
164+
"initial_state_cov": np.eye(mod.k_states, dtype=config.floatX),
165+
}
166+
167+
x, y = simulate_from_numpy_model(mod, rng, params, steps=3 * 5 * 5)
168+
assert_pattern_repeats(y[:, 0], 3, atol=ATOL, rtol=RTOL)
169+
assert_pattern_repeats(y[:, 1], 5, atol=ATOL, rtol=RTOL)
170+
171+
assert mod.state_names == [
172+
"state_0[data_1]",
173+
"state_1[data_1]",
174+
"state_2[data_1]",
175+
"state_1[data_2]",
176+
"state_2[data_2]",
177+
"state_3[data_2]",
178+
"state_4[data_2]",
179+
]
180+
181+
assert mod.shock_names == ["season1[data_1]", "season2[data_2]"]
182+
183+
x0, *_, T = mod._unpack_statespace_with_placeholders()[:5]
184+
input_vars = explicit_graph_inputs([x0, T])
185+
fn = pytensor.function(
186+
inputs=list(input_vars),
187+
outputs=[x0, T],
188+
mode="FAST_COMPILE",
189+
)
190+
191+
x0, T = fn(
192+
season1_coefs=np.array([1.0, 0.0, 0.0], dtype=config.floatX),
193+
season2_coefs=np.array([3.0, 0.0, 0.0, 1.2], dtype=config.floatX),
194+
)
195+
196+
np.testing.assert_allclose(
197+
np.array([1.0, 0.0, 0.0, 3.0, 0.0, 0.0, 1.2]), x0, atol=ATOL, rtol=RTOL
198+
)
199+
200+
np.testing.assert_allclose(
201+
np.array(
202+
[
203+
[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
204+
[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
205+
[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
206+
[0.0, 0.0, 0.0, -1.0, -1.0, -1.0, -1.0],
207+
[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
208+
[0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0],
209+
[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],
210+
]
211+
),
212+
T,
213+
atol=ATOL,
214+
rtol=RTOL,
215+
)
216+
217+
53218
def get_shift_factor(s):
54219
s_str = str(s)
55220
if "." not in s_str:

0 commit comments

Comments
 (0)