Skip to content

Commit 7cae487

Browse files
Allow multiple observed in LevelTrend component
1 parent b970a6c commit 7cae487

File tree

3 files changed

+125
-20
lines changed

3 files changed

+125
-20
lines changed

pymc_extras/statespace/models/structural/components/level_trend.py

Lines changed: 56 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import numpy as np
22

3+
from scipy import linalg
4+
35
from pymc_extras.statespace.models.structural.core import Component
46
from pymc_extras.statespace.models.structural.utils import order_to_mask
57
from pymc_extras.statespace.utils.constants import POSITION_DERIVATIVE_NAMES
@@ -120,6 +122,7 @@ def __init__(
120122

121123
if observed_state_names is None:
122124
observed_state_names = ["data"]
125+
k_endog = len(observed_state_names)
123126

124127
self._order_mask = order_to_mask(order)
125128
max_state = np.flatnonzero(self._order_mask)[-1].item() + 1
@@ -148,49 +151,83 @@ def __init__(
148151

149152
super().__init__(
150153
name,
151-
k_endog=len(observed_state_names),
152-
k_states=k_states,
153-
k_posdef=k_posdef,
154+
k_endog=k_endog,
155+
k_states=k_states * k_endog,
156+
k_posdef=k_posdef * k_endog,
154157
observed_state_names=observed_state_names,
155158
measurement_error=False,
156159
combine_hidden_states=False,
157-
obs_state_idxs=np.array([1.0] + [0.0] * (k_states - 1)),
160+
obs_state_idxs=np.tile(np.array([1.0] + [0.0] * (k_states - 1)), k_endog),
158161
)
159162

160163
def populate_component_properties(self):
161-
name_slice = POSITION_DERIVATIVE_NAMES[: self.k_states]
164+
k_endog = self.k_endog
165+
k_states = self.k_states // k_endog
166+
k_posdef = self.k_posdef // k_endog
167+
168+
name_slice = POSITION_DERIVATIVE_NAMES[:k_states]
162169
self.param_names = ["initial_trend"]
163170
self.state_names = [name for name, mask in zip(name_slice, self._order_mask) if mask]
164171
self.param_dims = {"initial_trend": ("trend_state",)}
165172
self.coords = {"trend_state": self.state_names}
166-
self.param_info = {"initial_trend": {"shape": (self.k_states,), "constraints": None}}
173+
174+
if k_endog > 1:
175+
self.param_dims["trend_state"] = (
176+
"trend_endog",
177+
"trend_state",
178+
)
179+
self.coords["trend_endog"] = self.observed_state_names
180+
181+
shape = (k_endog, k_states) if k_endog > 1 else (k_states,)
182+
self.param_info = {"initial_trend": {"shape": shape, "constraints": None}}
167183

168184
if self.k_posdef > 0:
169185
self.param_names += ["sigma_trend"]
170186
self.shock_names = [
171187
name for name, mask in zip(name_slice, self.innovations_order) if mask
172188
]
173-
self.param_dims["sigma_trend"] = ("trend_shock",)
189+
self.param_dims["sigma_trend"] = (
190+
("trend_shock",) if k_endog == 1 else ("trend_endog", "trend_shock")
191+
)
174192
self.coords["trend_shock"] = self.shock_names
175-
self.param_info["sigma_trend"] = {"shape": (self.k_posdef,), "constraints": "Positive"}
193+
self.param_info["sigma_trend"] = {
194+
"shape": (k_posdef,) if k_endog == 1 else (k_endog, k_posdef),
195+
"constraints": "Positive",
196+
}
176197

177198
for name in self.param_names:
178199
self.param_info[name]["dims"] = self.param_dims[name]
179200

180201
def make_symbolic_graph(self) -> None:
181-
initial_trend = self.make_and_register_variable("initial_trend", shape=(self.k_states,))
182-
self.ssm["initial_state", :] = initial_trend
183-
triu_idx = np.triu_indices(self.k_states)
184-
self.ssm[np.s_["transition", triu_idx[0], triu_idx[1]]] = 1
202+
k_endog = self.k_endog
203+
k_states = self.k_states // k_endog
204+
k_posdef = self.k_posdef // k_endog
185205

186-
R = np.eye(self.k_states)
206+
initial_trend = self.make_and_register_variable(
207+
"initial_trend",
208+
shape=(k_states,) if k_endog == 1 else (k_endog, k_states),
209+
)
210+
self.ssm["initial_state", :] = initial_trend.ravel()
211+
212+
triu_idx = np.triu_indices(k_states)
213+
T = np.zeros((k_states, k_states))
214+
T[triu_idx[0], triu_idx[1]] = 1
215+
216+
self.ssm["transition"] = linalg.block_diag(*[T for _ in range(k_endog)])
217+
218+
R = np.eye(k_states)
187219
R = R[:, self.innovations_order]
188-
self.ssm["selection", :, :] = R
189220

190-
self.ssm["design", 0, :] = np.array([1.0] + [0.0] * (self.k_states - 1))
221+
self.ssm["selection", :, :] = linalg.block_diag(*[R for _ in range(k_endog)])
191222

192-
if self.k_posdef > 0:
193-
sigma_trend = self.make_and_register_variable("sigma_trend", shape=(self.k_posdef,))
194-
diag_idx = np.diag_indices(self.k_posdef)
223+
Z = np.array([1.0] + [0.0] * (k_states - 1)).reshape((1, -1))
224+
self.ssm["design"] = linalg.block_diag(*[Z for _ in range(k_endog)])
225+
226+
if k_posdef > 0:
227+
sigma_trend = self.make_and_register_variable(
228+
"sigma_trend",
229+
shape=(k_posdef,) if k_endog == 1 else (k_endog, k_posdef),
230+
)
231+
diag_idx = np.diag_indices(k_posdef * k_endog)
195232
idx = np.s_["state_cov", diag_idx[0], diag_idx[1]]
196-
self.ssm[idx] = sigma_trend**2
233+
self.ssm[idx] = (sigma_trend**2).ravel()

tests/statespace/models/structural/components/test_level_trend.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,70 @@ def test_level_trend_model(rng):
2222
mod = mod.build(verbose=False)
2323
_assert_basic_coords_correct(mod)
2424
assert mod.coords["trend_state"] == ["level", "trend"]
25+
26+
27+
def test_level_trend_multiple_observed_construction():
28+
mod = st.LevelTrendComponent(
29+
order=2, innovations_order=1, observed_state_names=["data_1", "data_2", "data_3"]
30+
)
31+
mod = mod.build(verbose=False)
32+
assert mod.k_endog == 3
33+
assert mod.k_states == 6
34+
assert mod.k_posdef == 3
35+
36+
assert mod.coords["trend_state"] == ["level", "trend"]
37+
assert mod.coords["trend_endog"] == ["data_1", "data_2", "data_3"]
38+
39+
Z = mod.ssm["design"].eval()
40+
T = mod.ssm["transition"].eval()
41+
R = mod.ssm["selection"].eval()
42+
43+
np.testing.assert_allclose(
44+
Z,
45+
np.array(
46+
[
47+
[1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
48+
[0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
49+
[0.0, 0.0, 0.0, 0.0, 1.0, 0.0],
50+
]
51+
),
52+
)
53+
54+
np.testing.assert_allclose(
55+
T,
56+
np.array(
57+
[
58+
[1.0, 1.0, 0.0, 0.0, 0.0, 0.0],
59+
[0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
60+
[0.0, 0.0, 1.0, 1.0, 0.0, 0.0],
61+
[0.0, 0.0, 0.0, 1.0, 0.0, 0.0],
62+
[0.0, 0.0, 0.0, 0.0, 1.0, 1.0],
63+
[0.0, 0.0, 0.0, 0.0, 0.0, 1.0],
64+
]
65+
),
66+
)
67+
68+
np.testing.assert_allclose(
69+
R,
70+
np.array(
71+
[
72+
[1.0, 0.0, 0.0],
73+
[0.0, 0.0, 0.0],
74+
[0.0, 1.0, 0.0],
75+
[0.0, 0.0, 0.0],
76+
[0.0, 0.0, 1.0],
77+
[0.0, 0.0, 0.0],
78+
]
79+
),
80+
)
81+
82+
83+
def test_level_trend_multiple_observed(rng):
84+
mod = st.LevelTrendComponent(
85+
order=2, innovations_order=0, observed_state_names=["data_1", "data_2", "data_3"]
86+
)
87+
params = {"initial_trend": np.array([[0.0, 1.0], [0.0, 2.0], [0.0, 3.0]])}
88+
89+
x, y = simulate_from_numpy_model(mod, rng, params)
90+
assert (np.diff(y, axis=0) == np.array([[1.0, 2.0, 3.0]])).all().all()
91+
assert (np.diff(x, axis=0) == np.array([[1.0, 0.0, 2.0, 0.0, 3.0, 0.0]])).all().all()

tests/statespace/test_utilities.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,11 +242,12 @@ def simulate_from_numpy_model(mod, rng, param_dict, data_dict=None, steps=100):
242242
Helper function to visualize the components outside of a PyMC model context
243243
"""
244244
x0, P0, c, d, T, Z, R, H, Q = unpack_symbolic_matrices_with_params(mod, param_dict, data_dict)
245+
k_endog = mod.k_endog
245246
k_states = mod.k_states
246247
k_posdef = mod.k_posdef
247248

248249
x = np.zeros((steps, k_states))
249-
y = np.zeros(steps)
250+
y = np.zeros((steps, k_endog))
250251

251252
x[0] = x0
252253
y[0] = (Z @ x0).squeeze() if Z.ndim == 2 else (Z[0] @ x0).squeeze()

0 commit comments

Comments
 (0)