|
1 | 1 | import numpy as np
|
2 | 2 |
|
| 3 | +from scipy import linalg |
| 4 | + |
3 | 5 | from pymc_extras.statespace.models.structural.core import Component
|
4 | 6 | from pymc_extras.statespace.models.structural.utils import order_to_mask
|
5 | 7 | from pymc_extras.statespace.utils.constants import POSITION_DERIVATIVE_NAMES
|
@@ -120,6 +122,7 @@ def __init__(
|
120 | 122 |
|
121 | 123 | if observed_state_names is None:
|
122 | 124 | observed_state_names = ["data"]
|
| 125 | + k_endog = len(observed_state_names) |
123 | 126 |
|
124 | 127 | self._order_mask = order_to_mask(order)
|
125 | 128 | max_state = np.flatnonzero(self._order_mask)[-1].item() + 1
|
@@ -148,49 +151,83 @@ def __init__(
|
148 | 151 |
|
149 | 152 | super().__init__(
|
150 | 153 | name,
|
151 |
| - k_endog=len(observed_state_names), |
152 |
| - k_states=k_states, |
153 |
| - k_posdef=k_posdef, |
| 154 | + k_endog=k_endog, |
| 155 | + k_states=k_states * k_endog, |
| 156 | + k_posdef=k_posdef * k_endog, |
154 | 157 | observed_state_names=observed_state_names,
|
155 | 158 | measurement_error=False,
|
156 | 159 | combine_hidden_states=False,
|
157 |
| - obs_state_idxs=np.array([1.0] + [0.0] * (k_states - 1)), |
| 160 | + obs_state_idxs=np.tile(np.array([1.0] + [0.0] * (k_states - 1)), k_endog), |
158 | 161 | )
|
159 | 162 |
|
160 | 163 | def populate_component_properties(self):
|
161 |
| - name_slice = POSITION_DERIVATIVE_NAMES[: self.k_states] |
| 164 | + k_endog = self.k_endog |
| 165 | + k_states = self.k_states // k_endog |
| 166 | + k_posdef = self.k_posdef // k_endog |
| 167 | + |
| 168 | + name_slice = POSITION_DERIVATIVE_NAMES[:k_states] |
162 | 169 | self.param_names = ["initial_trend"]
|
163 | 170 | self.state_names = [name for name, mask in zip(name_slice, self._order_mask) if mask]
|
164 | 171 | self.param_dims = {"initial_trend": ("trend_state",)}
|
165 | 172 | self.coords = {"trend_state": self.state_names}
|
166 |
| - self.param_info = {"initial_trend": {"shape": (self.k_states,), "constraints": None}} |
| 173 | + |
| 174 | + if k_endog > 1: |
| 175 | + self.param_dims["trend_state"] = ( |
| 176 | + "trend_endog", |
| 177 | + "trend_state", |
| 178 | + ) |
| 179 | + self.coords["trend_endog"] = self.observed_state_names |
| 180 | + |
| 181 | + shape = (k_endog, k_states) if k_endog > 1 else (k_states,) |
| 182 | + self.param_info = {"initial_trend": {"shape": shape, "constraints": None}} |
167 | 183 |
|
168 | 184 | if self.k_posdef > 0:
|
169 | 185 | self.param_names += ["sigma_trend"]
|
170 | 186 | self.shock_names = [
|
171 | 187 | name for name, mask in zip(name_slice, self.innovations_order) if mask
|
172 | 188 | ]
|
173 |
| - self.param_dims["sigma_trend"] = ("trend_shock",) |
| 189 | + self.param_dims["sigma_trend"] = ( |
| 190 | + ("trend_shock",) if k_endog == 1 else ("trend_endog", "trend_shock") |
| 191 | + ) |
174 | 192 | self.coords["trend_shock"] = self.shock_names
|
175 |
| - self.param_info["sigma_trend"] = {"shape": (self.k_posdef,), "constraints": "Positive"} |
| 193 | + self.param_info["sigma_trend"] = { |
| 194 | + "shape": (k_posdef,) if k_endog == 1 else (k_endog, k_posdef), |
| 195 | + "constraints": "Positive", |
| 196 | + } |
176 | 197 |
|
177 | 198 | for name in self.param_names:
|
178 | 199 | self.param_info[name]["dims"] = self.param_dims[name]
|
179 | 200 |
|
180 | 201 | def make_symbolic_graph(self) -> None:
|
181 |
| - initial_trend = self.make_and_register_variable("initial_trend", shape=(self.k_states,)) |
182 |
| - self.ssm["initial_state", :] = initial_trend |
183 |
| - triu_idx = np.triu_indices(self.k_states) |
184 |
| - self.ssm[np.s_["transition", triu_idx[0], triu_idx[1]]] = 1 |
| 202 | + k_endog = self.k_endog |
| 203 | + k_states = self.k_states // k_endog |
| 204 | + k_posdef = self.k_posdef // k_endog |
185 | 205 |
|
186 |
| - R = np.eye(self.k_states) |
| 206 | + initial_trend = self.make_and_register_variable( |
| 207 | + "initial_trend", |
| 208 | + shape=(k_states,) if k_endog == 1 else (k_endog, k_states), |
| 209 | + ) |
| 210 | + self.ssm["initial_state", :] = initial_trend.ravel() |
| 211 | + |
| 212 | + triu_idx = np.triu_indices(k_states) |
| 213 | + T = np.zeros((k_states, k_states)) |
| 214 | + T[triu_idx[0], triu_idx[1]] = 1 |
| 215 | + |
| 216 | + self.ssm["transition"] = linalg.block_diag(*[T for _ in range(k_endog)]) |
| 217 | + |
| 218 | + R = np.eye(k_states) |
187 | 219 | R = R[:, self.innovations_order]
|
188 |
| - self.ssm["selection", :, :] = R |
189 | 220 |
|
190 |
| - self.ssm["design", 0, :] = np.array([1.0] + [0.0] * (self.k_states - 1)) |
| 221 | + self.ssm["selection", :, :] = linalg.block_diag(*[R for _ in range(k_endog)]) |
191 | 222 |
|
192 |
| - if self.k_posdef > 0: |
193 |
| - sigma_trend = self.make_and_register_variable("sigma_trend", shape=(self.k_posdef,)) |
194 |
| - diag_idx = np.diag_indices(self.k_posdef) |
| 223 | + Z = np.array([1.0] + [0.0] * (k_states - 1)).reshape((1, -1)) |
| 224 | + self.ssm["design"] = linalg.block_diag(*[Z for _ in range(k_endog)]) |
| 225 | + |
| 226 | + if k_posdef > 0: |
| 227 | + sigma_trend = self.make_and_register_variable( |
| 228 | + "sigma_trend", |
| 229 | + shape=(k_posdef,) if k_endog == 1 else (k_endog, k_posdef), |
| 230 | + ) |
| 231 | + diag_idx = np.diag_indices(k_posdef * k_endog) |
195 | 232 | idx = np.s_["state_cov", diag_idx[0], diag_idx[1]]
|
196 |
| - self.ssm[idx] = sigma_trend**2 |
| 233 | + self.ssm[idx] = (sigma_trend**2).ravel() |
0 commit comments