|  | 
| 1 | 1 | import numpy as np | 
|  | 2 | +import pytensor.tensor as pt | 
| 2 | 3 | 
 | 
| 3 | 4 | from pymc_extras.statespace.models.structural.core import Component | 
| 4 | 5 | from pymc_extras.statespace.models.structural.utils import order_to_mask | 
| @@ -70,53 +71,109 @@ def __init__( | 
| 70 | 71 |         if observed_state_names is None: | 
| 71 | 72 |             observed_state_names = ["data"] | 
| 72 | 73 | 
 | 
|  | 74 | +        k_posdef = k_endog = len(observed_state_names) | 
|  | 75 | + | 
| 73 | 76 |         order = order_to_mask(order) | 
| 74 | 77 |         ar_lags = np.flatnonzero(order).ravel().astype(int) + 1 | 
| 75 | 78 |         k_states = len(order) | 
| 76 |  | -        k_posdef = k_endog = len(observed_state_names) | 
| 77 | 79 | 
 | 
| 78 | 80 |         self.order = order | 
| 79 | 81 |         self.ar_lags = ar_lags | 
| 80 | 82 | 
 | 
| 81 | 83 |         super().__init__( | 
| 82 | 84 |             name=name, | 
| 83 | 85 |             k_endog=k_endog, | 
| 84 |  | -            k_states=k_states, | 
|  | 86 | +            k_states=k_states * k_endog, | 
| 85 | 87 |             k_posdef=k_posdef, | 
| 86 | 88 |             measurement_error=True, | 
| 87 | 89 |             combine_hidden_states=True, | 
| 88 | 90 |             observed_state_names=observed_state_names, | 
| 89 |  | -            obs_state_idxs=np.r_[[1.0], np.zeros(k_states - 1)], | 
|  | 91 | +            obs_state_idxs=np.tile(np.r_[[1.0], np.zeros(k_states - 1)], k_endog), | 
| 90 | 92 |         ) | 
| 91 | 93 | 
 | 
| 92 | 94 |     def populate_component_properties(self): | 
| 93 |  | -        self.state_names = [f"L{i + 1}.data" for i in range(self.k_states)] | 
| 94 |  | -        self.shock_names = [f"{self.name}_innovation"] | 
|  | 95 | +        self.state_names = [ | 
|  | 96 | +            f"L{i + 1}.{state_name}" | 
|  | 97 | +            for i in range(self.k_states) | 
|  | 98 | +            for state_name in self.observed_state_names | 
|  | 99 | +        ] | 
|  | 100 | +        self.shock_names = [f"{name}_{self.name}_innovation" for name in self.observed_state_names] | 
| 95 | 101 |         self.param_names = ["ar_params", "sigma_ar"] | 
| 96 | 102 |         self.param_dims = {"ar_params": (AR_PARAM_DIM,)} | 
| 97 | 103 |         self.coords = {AR_PARAM_DIM: self.ar_lags.tolist()} | 
| 98 | 104 | 
 | 
|  | 105 | +        if self.k_endog > 1: | 
|  | 106 | +            self.param_dims["ar_params"] = ( | 
|  | 107 | +                f"{self.name}_endog", | 
|  | 108 | +                AR_PARAM_DIM, | 
|  | 109 | +            ) | 
|  | 110 | +            self.param_dims["sigma_ar"] = (f"{self.name}_endog",) | 
|  | 111 | + | 
|  | 112 | +            self.coords[f"{self.name}_endog"] = self.observed_state_names | 
|  | 113 | + | 
| 99 | 114 |         self.param_info = { | 
| 100 | 115 |             "ar_params": { | 
| 101 |  | -                "shape": (self.k_states,), | 
|  | 116 | +                "shape": (self.k_states,) if self.k_endog == 1 else (self.k_endog, self.k_states), | 
| 102 | 117 |                 "constraints": None, | 
| 103 |  | -                "dims": (AR_PARAM_DIM,), | 
|  | 118 | +                "dims": (AR_PARAM_DIM,) | 
|  | 119 | +                if self.k_endog == 1 | 
|  | 120 | +                else ( | 
|  | 121 | +                    f"{self.name}_endog", | 
|  | 122 | +                    AR_PARAM_DIM, | 
|  | 123 | +                ), | 
|  | 124 | +            }, | 
|  | 125 | +            "sigma_ar": { | 
|  | 126 | +                "shape": () if self.k_endog == 1 else (self.k_endog,), | 
|  | 127 | +                "constraints": "Positive", | 
|  | 128 | +                "dims": None if self.k_endog == 1 else (f"{self.name}_endog",), | 
| 104 | 129 |             }, | 
| 105 |  | -            "sigma_ar": {"shape": (), "constraints": "Positive", "dims": None}, | 
| 106 | 130 |         } | 
| 107 | 131 | 
 | 
| 108 | 132 |     def make_symbolic_graph(self) -> None: | 
|  | 133 | +        k_endog = self.k_endog | 
|  | 134 | +        k_states = self.k_states // k_endog | 
|  | 135 | +        k_posdef = self.k_posdef | 
|  | 136 | + | 
| 109 | 137 |         k_nonzero = int(sum(self.order)) | 
| 110 |  | -        ar_params = self.make_and_register_variable("ar_params", shape=(k_nonzero,)) | 
| 111 |  | -        sigma_ar = self.make_and_register_variable("sigma_ar", shape=()) | 
|  | 138 | +        ar_params = self.make_and_register_variable( | 
|  | 139 | +            "ar_params", shape=(k_nonzero,) if k_endog == 1 else (k_endog, k_nonzero) | 
|  | 140 | +        ) | 
|  | 141 | +        sigma_ar = self.make_and_register_variable( | 
|  | 142 | +            "sigma_ar", shape=() if k_endog == 1 else (k_endog,) | 
|  | 143 | +        ) | 
|  | 144 | + | 
|  | 145 | +        if k_endog == 1: | 
|  | 146 | +            T = pt.eye(k_states, k=-1) | 
|  | 147 | +            ar_idx = (np.zeros(k_nonzero, dtype="int"), np.nonzero(self.order)[0]) | 
|  | 148 | +            T = T[ar_idx].set(ar_params) | 
|  | 149 | + | 
|  | 150 | +        else: | 
|  | 151 | +            transition_matrices = [] | 
|  | 152 | + | 
|  | 153 | +            for i in range(k_endog): | 
|  | 154 | +                T = pt.eye(k_states, k=-1) | 
|  | 155 | +                ar_idx = (np.zeros(k_nonzero, dtype="int"), np.nonzero(self.order)[0]) | 
|  | 156 | +                T = T[ar_idx].set(ar_params[i]) | 
|  | 157 | +                transition_matrices.append(T) | 
|  | 158 | +            T = pt.specify_shape( | 
|  | 159 | +                pt.linalg.block_diag(*transition_matrices), (self.k_states, self.k_states) | 
|  | 160 | +            ) | 
| 112 | 161 | 
 | 
| 113 |  | -        T = np.eye(self.k_states, k=-1) | 
| 114 | 162 |         self.ssm["transition", :, :] = T | 
| 115 |  | -        self.ssm["selection", 0, 0] = 1 | 
| 116 |  | -        self.ssm["design", 0, 0] = 1 | 
| 117 | 163 | 
 | 
| 118 |  | -        ar_idx = ("transition", np.zeros(k_nonzero, dtype="int"), np.nonzero(self.order)[0]) | 
| 119 |  | -        self.ssm[ar_idx] = ar_params | 
|  | 164 | +        R = np.eye(k_states) | 
|  | 165 | +        R_mask = np.full((k_states), False) | 
|  | 166 | +        R_mask[0] = True | 
|  | 167 | +        R = R[:, R_mask] | 
|  | 168 | + | 
|  | 169 | +        self.ssm["selection", :, :] = pt.specify_shape( | 
|  | 170 | +            pt.linalg.block_diag(*[R for _ in range(k_endog)]), (self.k_states, self.k_posdef) | 
|  | 171 | +        ) | 
|  | 172 | + | 
|  | 173 | +        Z = pt.zeros((1, k_states))[0, 0].set(1.0) | 
|  | 174 | +        self.ssm["design", :, :] = pt.specify_shape( | 
|  | 175 | +            pt.linalg.block_diag(*[Z for _ in range(k_endog)]), (self.k_endog, self.k_states) | 
|  | 176 | +        ) | 
| 120 | 177 | 
 | 
| 121 |  | -        cov_idx = ("state_cov", *np.diag_indices(1)) | 
|  | 178 | +        cov_idx = ("state_cov", *np.diag_indices(k_posdef)) | 
| 122 | 179 |         self.ssm[cov_idx] = sigma_ar**2 | 
0 commit comments