|  | 
| 1 | 1 | import numpy as np | 
| 2 | 2 | 
 | 
|  | 3 | +from scipy import linalg | 
|  | 4 | + | 
| 3 | 5 | from pymc_extras.statespace.models.structural.core import Component | 
| 4 | 6 | from pymc_extras.statespace.models.structural.utils import order_to_mask | 
| 5 | 7 | from pymc_extras.statespace.utils.constants import POSITION_DERIVATIVE_NAMES | 
| @@ -120,6 +122,7 @@ def __init__( | 
| 120 | 122 | 
 | 
| 121 | 123 |         if observed_state_names is None: | 
| 122 | 124 |             observed_state_names = ["data"] | 
|  | 125 | +        k_endog = len(observed_state_names) | 
| 123 | 126 | 
 | 
| 124 | 127 |         self._order_mask = order_to_mask(order) | 
| 125 | 128 |         max_state = np.flatnonzero(self._order_mask)[-1].item() + 1 | 
| @@ -148,49 +151,83 @@ def __init__( | 
| 148 | 151 | 
 | 
| 149 | 152 |         super().__init__( | 
| 150 | 153 |             name, | 
| 151 |  | -            k_endog=len(observed_state_names), | 
| 152 |  | -            k_states=k_states, | 
| 153 |  | -            k_posdef=k_posdef, | 
|  | 154 | +            k_endog=k_endog, | 
|  | 155 | +            k_states=k_states * k_endog, | 
|  | 156 | +            k_posdef=k_posdef * k_endog, | 
| 154 | 157 |             observed_state_names=observed_state_names, | 
| 155 | 158 |             measurement_error=False, | 
| 156 | 159 |             combine_hidden_states=False, | 
| 157 |  | -            obs_state_idxs=np.array([1.0] + [0.0] * (k_states - 1)), | 
|  | 160 | +            obs_state_idxs=np.tile(np.array([1.0] + [0.0] * (k_states - 1)), k_endog), | 
| 158 | 161 |         ) | 
| 159 | 162 | 
 | 
| 160 | 163 |     def populate_component_properties(self): | 
| 161 |  | -        name_slice = POSITION_DERIVATIVE_NAMES[: self.k_states] | 
|  | 164 | +        k_endog = self.k_endog | 
|  | 165 | +        k_states = self.k_states // k_endog | 
|  | 166 | +        k_posdef = self.k_posdef // k_endog | 
|  | 167 | + | 
|  | 168 | +        name_slice = POSITION_DERIVATIVE_NAMES[:k_states] | 
| 162 | 169 |         self.param_names = ["initial_trend"] | 
| 163 | 170 |         self.state_names = [name for name, mask in zip(name_slice, self._order_mask) if mask] | 
| 164 | 171 |         self.param_dims = {"initial_trend": ("trend_state",)} | 
| 165 | 172 |         self.coords = {"trend_state": self.state_names} | 
| 166 |  | -        self.param_info = {"initial_trend": {"shape": (self.k_states,), "constraints": None}} | 
|  | 173 | + | 
|  | 174 | +        if k_endog > 1: | 
|  | 175 | +            self.param_dims["trend_state"] = ( | 
|  | 176 | +                "trend_endog", | 
|  | 177 | +                "trend_state", | 
|  | 178 | +            ) | 
|  | 179 | +            self.coords["trend_endog"] = self.observed_state_names | 
|  | 180 | + | 
|  | 181 | +        shape = (k_endog, k_states) if k_endog > 1 else (k_states,) | 
|  | 182 | +        self.param_info = {"initial_trend": {"shape": shape, "constraints": None}} | 
| 167 | 183 | 
 | 
| 168 | 184 |         if self.k_posdef > 0: | 
| 169 | 185 |             self.param_names += ["sigma_trend"] | 
| 170 | 186 |             self.shock_names = [ | 
| 171 | 187 |                 name for name, mask in zip(name_slice, self.innovations_order) if mask | 
| 172 | 188 |             ] | 
| 173 |  | -            self.param_dims["sigma_trend"] = ("trend_shock",) | 
|  | 189 | +            self.param_dims["sigma_trend"] = ( | 
|  | 190 | +                ("trend_shock",) if k_endog == 1 else ("trend_endog", "trend_shock") | 
|  | 191 | +            ) | 
| 174 | 192 |             self.coords["trend_shock"] = self.shock_names | 
| 175 |  | -            self.param_info["sigma_trend"] = {"shape": (self.k_posdef,), "constraints": "Positive"} | 
|  | 193 | +            self.param_info["sigma_trend"] = { | 
|  | 194 | +                "shape": (k_posdef,) if k_endog == 1 else (k_endog, k_posdef), | 
|  | 195 | +                "constraints": "Positive", | 
|  | 196 | +            } | 
| 176 | 197 | 
 | 
| 177 | 198 |         for name in self.param_names: | 
| 178 | 199 |             self.param_info[name]["dims"] = self.param_dims[name] | 
| 179 | 200 | 
 | 
| 180 | 201 |     def make_symbolic_graph(self) -> None: | 
| 181 |  | -        initial_trend = self.make_and_register_variable("initial_trend", shape=(self.k_states,)) | 
| 182 |  | -        self.ssm["initial_state", :] = initial_trend | 
| 183 |  | -        triu_idx = np.triu_indices(self.k_states) | 
| 184 |  | -        self.ssm[np.s_["transition", triu_idx[0], triu_idx[1]]] = 1 | 
|  | 202 | +        k_endog = self.k_endog | 
|  | 203 | +        k_states = self.k_states // k_endog | 
|  | 204 | +        k_posdef = self.k_posdef // k_endog | 
| 185 | 205 | 
 | 
| 186 |  | -        R = np.eye(self.k_states) | 
|  | 206 | +        initial_trend = self.make_and_register_variable( | 
|  | 207 | +            "initial_trend", | 
|  | 208 | +            shape=(k_states,) if k_endog == 1 else (k_endog, k_states), | 
|  | 209 | +        ) | 
|  | 210 | +        self.ssm["initial_state", :] = initial_trend.ravel() | 
|  | 211 | + | 
|  | 212 | +        triu_idx = np.triu_indices(k_states) | 
|  | 213 | +        T = np.zeros((k_states, k_states)) | 
|  | 214 | +        T[triu_idx[0], triu_idx[1]] = 1 | 
|  | 215 | + | 
|  | 216 | +        self.ssm["transition"] = linalg.block_diag(*[T for _ in range(k_endog)]) | 
|  | 217 | + | 
|  | 218 | +        R = np.eye(k_states) | 
| 187 | 219 |         R = R[:, self.innovations_order] | 
| 188 |  | -        self.ssm["selection", :, :] = R | 
| 189 | 220 | 
 | 
| 190 |  | -        self.ssm["design", 0, :] = np.array([1.0] + [0.0] * (self.k_states - 1)) | 
|  | 221 | +        self.ssm["selection", :, :] = linalg.block_diag(*[R for _ in range(k_endog)]) | 
| 191 | 222 | 
 | 
| 192 |  | -        if self.k_posdef > 0: | 
| 193 |  | -            sigma_trend = self.make_and_register_variable("sigma_trend", shape=(self.k_posdef,)) | 
| 194 |  | -            diag_idx = np.diag_indices(self.k_posdef) | 
|  | 223 | +        Z = np.array([1.0] + [0.0] * (k_states - 1)).reshape((1, -1)) | 
|  | 224 | +        self.ssm["design"] = linalg.block_diag(*[Z for _ in range(k_endog)]) | 
|  | 225 | + | 
|  | 226 | +        if k_posdef > 0: | 
|  | 227 | +            sigma_trend = self.make_and_register_variable( | 
|  | 228 | +                "sigma_trend", | 
|  | 229 | +                shape=(k_posdef,) if k_endog == 1 else (k_endog, k_posdef), | 
|  | 230 | +            ) | 
|  | 231 | +            diag_idx = np.diag_indices(k_posdef * k_endog) | 
| 195 | 232 |             idx = np.s_["state_cov", diag_idx[0], diag_idx[1]] | 
| 196 |  | -            self.ssm[idx] = sigma_trend**2 | 
|  | 233 | +            self.ssm[idx] = (sigma_trend**2).ravel() | 
0 commit comments