Skip to content

Commit 0a84576

Browse files
Allow multiple observed in AutoRegressive component
1 parent bba8431 commit 0a84576

File tree

2 files changed

+92
-16
lines changed

2 files changed

+92
-16
lines changed

pymc_extras/statespace/models/structural/components/autoregressive.py

Lines changed: 73 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
import pytensor.tensor as pt
23

34
from pymc_extras.statespace.models.structural.core import Component
45
from pymc_extras.statespace.models.structural.utils import order_to_mask
@@ -70,53 +71,109 @@ def __init__(
7071
if observed_state_names is None:
7172
observed_state_names = ["data"]
7273

74+
k_posdef = k_endog = len(observed_state_names)
75+
7376
order = order_to_mask(order)
7477
ar_lags = np.flatnonzero(order).ravel().astype(int) + 1
7578
k_states = len(order)
76-
k_posdef = k_endog = len(observed_state_names)
7779

7880
self.order = order
7981
self.ar_lags = ar_lags
8082

8183
super().__init__(
8284
name=name,
8385
k_endog=k_endog,
84-
k_states=k_states,
86+
k_states=k_states * k_endog,
8587
k_posdef=k_posdef,
8688
measurement_error=True,
8789
combine_hidden_states=True,
8890
observed_state_names=observed_state_names,
89-
obs_state_idxs=np.r_[[1.0], np.zeros(k_states - 1)],
91+
obs_state_idxs=np.tile(np.r_[[1.0], np.zeros(k_states - 1)], k_endog),
9092
)
9193

9294
def populate_component_properties(self):
93-
self.state_names = [f"L{i + 1}.data" for i in range(self.k_states)]
94-
self.shock_names = [f"{self.name}_innovation"]
95+
self.state_names = [
96+
f"L{i + 1}.{state_name}"
97+
for i in range(self.k_states)
98+
for state_name in self.observed_state_names
99+
]
100+
self.shock_names = [f"{name}_{self.name}_innovation" for name in self.observed_state_names]
95101
self.param_names = ["ar_params", "sigma_ar"]
96102
self.param_dims = {"ar_params": (AR_PARAM_DIM,)}
97103
self.coords = {AR_PARAM_DIM: self.ar_lags.tolist()}
98104

105+
if self.k_endog > 1:
106+
self.param_dims["ar_params"] = (
107+
f"{self.name}_endog",
108+
AR_PARAM_DIM,
109+
)
110+
self.param_dims["sigma_ar"] = (f"{self.name}_endog",)
111+
112+
self.coords[f"{self.name}_endog"] = self.observed_state_names
113+
99114
self.param_info = {
100115
"ar_params": {
101-
"shape": (self.k_states,),
116+
"shape": (self.k_states,) if self.k_endog == 1 else (self.k_endog, self.k_states),
102117
"constraints": None,
103-
"dims": (AR_PARAM_DIM,),
118+
"dims": (AR_PARAM_DIM,)
119+
if self.k_endog == 1
120+
else (
121+
f"{self.name}_endog",
122+
AR_PARAM_DIM,
123+
),
124+
},
125+
"sigma_ar": {
126+
"shape": () if self.k_endog == 1 else (self.k_endog,),
127+
"constraints": "Positive",
128+
"dims": None if self.k_endog == 1 else (f"{self.name}_endog",),
104129
},
105-
"sigma_ar": {"shape": (), "constraints": "Positive", "dims": None},
106130
}
107131

108132
def make_symbolic_graph(self) -> None:
133+
k_endog = self.k_endog
134+
k_states = self.k_states // k_endog
135+
k_posdef = self.k_posdef
136+
109137
k_nonzero = int(sum(self.order))
110-
ar_params = self.make_and_register_variable("ar_params", shape=(k_nonzero,))
111-
sigma_ar = self.make_and_register_variable("sigma_ar", shape=())
138+
ar_params = self.make_and_register_variable(
139+
"ar_params", shape=(k_nonzero,) if k_endog == 1 else (k_endog, k_nonzero)
140+
)
141+
sigma_ar = self.make_and_register_variable(
142+
"sigma_ar", shape=() if k_endog == 1 else (k_endog,)
143+
)
144+
145+
if k_endog == 1:
146+
T = pt.eye(k_states, k=-1)
147+
ar_idx = (np.zeros(k_nonzero, dtype="int"), np.nonzero(self.order)[0])
148+
T = T[ar_idx].set(ar_params)
149+
150+
else:
151+
transition_matrices = []
152+
153+
for i in range(k_endog):
154+
T = pt.eye(k_states, k=-1)
155+
ar_idx = (np.zeros(k_nonzero, dtype="int"), np.nonzero(self.order)[0])
156+
T = T[ar_idx].set(ar_params[i])
157+
transition_matrices.append(T)
158+
T = pt.specify_shape(
159+
pt.linalg.block_diag(*transition_matrices), (self.k_states, self.k_states)
160+
)
112161

113-
T = np.eye(self.k_states, k=-1)
114162
self.ssm["transition", :, :] = T
115-
self.ssm["selection", 0, 0] = 1
116-
self.ssm["design", 0, 0] = 1
117163

118-
ar_idx = ("transition", np.zeros(k_nonzero, dtype="int"), np.nonzero(self.order)[0])
119-
self.ssm[ar_idx] = ar_params
164+
R = np.eye(k_states)
165+
R_mask = np.full((k_states), False)
166+
R_mask[0] = True
167+
R = R[:, R_mask]
168+
169+
self.ssm["selection", :, :] = pt.specify_shape(
170+
pt.linalg.block_diag(*[R for _ in range(k_endog)]), (self.k_states, self.k_posdef)
171+
)
172+
173+
Z = pt.zeros((1, k_states))[0, 0].set(1.0)
174+
self.ssm["design", :, :] = pt.specify_shape(
175+
pt.linalg.block_diag(*[Z for _ in range(k_endog)]), (self.k_endog, self.k_states)
176+
)
120177

121-
cov_idx = ("state_cov", *np.diag_indices(1))
178+
cov_idx = ("state_cov", *np.diag_indices(k_posdef))
122179
self.ssm[cov_idx] = sigma_ar**2

tests/statespace/models/structural/components/test_autoregressive.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,22 @@ def test_autoregressive_model(order, rng):
2626
if isinstance(order, list):
2727
lags = lags[np.flatnonzero(order)]
2828
assert_allclose(ar.coords["ar_lag"], lags)
29+
30+
31+
def test_autoregressive_multiple_observed(rng):
32+
ar = st.AutoregressiveComponent(order=3, observed_state_names=["data_1", "data_2"])
33+
mod = ar.build(verbose=False)
34+
35+
params = {
36+
"ar_params": np.full(
37+
(
38+
2,
39+
sum(ar.order),
40+
),
41+
0.5,
42+
dtype=config.floatX,
43+
),
44+
"sigma_ar": np.ones((2,)) * 1e-3,
45+
}
46+
47+
x, y = simulate_from_numpy_model(ar, rng, params, steps=100)

0 commit comments

Comments
 (0)