|
15 | 15 | split_vars_into_seq_and_nonseq, |
16 | 16 | stabilize, |
17 | 17 | ) |
18 | | -from pymc_extras.statespace.utils.constants import JITTER_DEFAULT, MISSING_FILL |
| 18 | +from pymc_extras.statespace.utils.constants import ( |
| 19 | + FILTER_OUTPUT_NAMES, |
| 20 | + JITTER_DEFAULT, |
| 21 | + MATRIX_NAMES, |
| 22 | + MISSING_FILL, |
| 23 | +) |
19 | 24 |
|
20 | 25 | MVN_CONST = pt.log(2 * pt.constant(np.pi, dtype="float64")) |
21 | | -PARAM_NAMES = ["c", "d", "T", "Z", "R", "H", "Q"] |
| 26 | +PARAM_NAMES = MATRIX_NAMES[2:] |
22 | 27 |
|
23 | 28 | assert_time_varying_dim_correct = Assert( |
24 | 29 | "The first dimension of a time varying matrix (the time dimension) must be " |
@@ -119,7 +124,7 @@ def unpack_args(self, args) -> tuple: |
119 | 124 | # There are always two outputs_info wedged between the seqs and non_seqs |
120 | 125 | seqs, (a0, P0), non_seqs = args[:n_seq], args[n_seq : n_seq + 2], args[n_seq + 2 :] |
121 | 126 | return_ordered = [] |
122 | | - for name in ["c", "d", "T", "Z", "R", "H", "Q"]: |
| 127 | + for name in PARAM_NAMES: |
123 | 128 | if name in self.seq_names: |
124 | 129 | idx = self.seq_names.index(name) |
125 | 130 | return_ordered.append(seqs[idx]) |
@@ -253,28 +258,28 @@ def _postprocess_scan_results(self, results, a0, P0, n) -> list[TensorVariable]: |
253 | 258 | ) |
254 | 259 |
|
255 | 260 | filtered_states = pt.specify_shape(filtered_states, (n, self.n_states)) |
256 | | - filtered_states.name = "filtered_states" |
| 261 | + filtered_states.name = FILTER_OUTPUT_NAMES[0] |
257 | 262 |
|
258 | 263 | predicted_states = pt.specify_shape(predicted_states, (n, self.n_states)) |
259 | | - predicted_states.name = "predicted_states" |
260 | | - |
261 | | - observed_states = pt.specify_shape(observed_states, (n, self.n_endog)) |
262 | | - observed_states.name = "observed_states" |
| 264 | + predicted_states.name = FILTER_OUTPUT_NAMES[1] |
263 | 265 |
|
264 | 266 | filtered_covariances = pt.specify_shape( |
265 | 267 | filtered_covariances, (n, self.n_states, self.n_states) |
266 | 268 | ) |
267 | | - filtered_covariances.name = "filtered_covariances" |
| 269 | + filtered_covariances.name = FILTER_OUTPUT_NAMES[2] |
268 | 270 |
|
269 | 271 | predicted_covariances = pt.specify_shape( |
270 | 272 | predicted_covariances, (n, self.n_states, self.n_states) |
271 | 273 | ) |
272 | | - predicted_covariances.name = "predicted_covariances" |
| 274 | + predicted_covariances.name = FILTER_OUTPUT_NAMES[3] |
| 275 | + |
| 276 | + observed_states = pt.specify_shape(observed_states, (n, self.n_endog)) |
| 277 | + observed_states.name = FILTER_OUTPUT_NAMES[4] |
273 | 278 |
|
274 | 279 | observed_covariances = pt.specify_shape( |
275 | 280 | observed_covariances, (n, self.n_endog, self.n_endog) |
276 | 281 | ) |
277 | | - observed_covariances.name = "observed_covariances" |
| 282 | + observed_covariances.name = FILTER_OUTPUT_NAMES[5] |
278 | 283 |
|
279 | 284 | loglike_obs = pt.specify_shape(loglike_obs.squeeze(), (n,)) |
280 | 285 | loglike_obs.name = "loglike_obs" |
|
0 commit comments