Skip to content

Commit 46149ac

Browse files
committed
updated kalman filter outputs to use names defined in constants.py, updated sample_filter_outputs to allow sampling any filter outputs defined in constants.py
1 parent 9e78bae commit 46149ac

File tree

3 files changed

+22
-20
lines changed

3 files changed

+22
-20
lines changed

pymc_extras/statespace/core/statespace.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1684,20 +1684,15 @@ def sample_filter_outputs(
16841684
if isinstance(filter_output_names, str):
16851685
filter_output_names = [filter_output_names]
16861686

1687-
drop_keys = {"predicted_observed_states", "predicted_observed_covariances"}
1688-
all_filter_output_dims = {k: v for k, v in FILTER_OUTPUT_DIMS.items() if k not in drop_keys}
1689-
16901687
if filter_output_names is None:
1691-
filter_output_names = list(all_filter_output_dims.keys())
1688+
filter_output_names = list(FILTER_OUTPUT_DIMS.keys())
16921689
else:
16931690
unknown_filter_output_names = np.setdiff1d(
1694-
filter_output_names, list(all_filter_output_dims.keys())
1691+
filter_output_names, list(FILTER_OUTPUT_DIMS.keys())
16951692
)
16961693
if unknown_filter_output_names.size > 0:
16971694
raise ValueError(f"{unknown_filter_output_names} not a valid filter output name!")
1698-
filter_output_names = [
1699-
x for x in all_filter_output_dims.keys() if x in filter_output_names
1700-
]
1695+
filter_output_names = [x for x in FILTER_OUTPUT_DIMS.keys() if x in filter_output_names]
17011696

17021697
compile_kwargs = kwargs.pop("compile_kwargs", {})
17031698
compile_kwargs.setdefault("mode", self.mode)
@@ -1744,7 +1739,7 @@ def sample_filter_outputs(
17441739
filter_outputs = filter_outputs[:-1] + list(smoother_outputs)
17451740
for output in filter_outputs:
17461741
if output.name in filter_output_names:
1747-
dims = all_filter_output_dims[output.name]
1742+
dims = FILTER_OUTPUT_DIMS[output.name]
17481743
pm.Deterministic(output.name, output, dims=dims)
17491744

17501745
with freeze_dims_and_data(m):

pymc_extras/statespace/filters/kalman_filter.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,15 @@
1515
split_vars_into_seq_and_nonseq,
1616
stabilize,
1717
)
18-
from pymc_extras.statespace.utils.constants import JITTER_DEFAULT, MISSING_FILL
18+
from pymc_extras.statespace.utils.constants import (
19+
FILTER_OUTPUT_NAMES,
20+
JITTER_DEFAULT,
21+
MATRIX_NAMES,
22+
MISSING_FILL,
23+
)
1924

2025
MVN_CONST = pt.log(2 * pt.constant(np.pi, dtype="float64"))
21-
PARAM_NAMES = ["c", "d", "T", "Z", "R", "H", "Q"]
26+
PARAM_NAMES = MATRIX_NAMES[2:]
2227

2328
assert_time_varying_dim_correct = Assert(
2429
"The first dimension of a time varying matrix (the time dimension) must be "
@@ -119,7 +124,7 @@ def unpack_args(self, args) -> tuple:
119124
# There are always two outputs_info wedged between the seqs and non_seqs
120125
seqs, (a0, P0), non_seqs = args[:n_seq], args[n_seq : n_seq + 2], args[n_seq + 2 :]
121126
return_ordered = []
122-
for name in ["c", "d", "T", "Z", "R", "H", "Q"]:
127+
for name in PARAM_NAMES:
123128
if name in self.seq_names:
124129
idx = self.seq_names.index(name)
125130
return_ordered.append(seqs[idx])
@@ -253,28 +258,28 @@ def _postprocess_scan_results(self, results, a0, P0, n) -> list[TensorVariable]:
253258
)
254259

255260
filtered_states = pt.specify_shape(filtered_states, (n, self.n_states))
256-
filtered_states.name = "filtered_states"
261+
filtered_states.name = FILTER_OUTPUT_NAMES[0]
257262

258263
predicted_states = pt.specify_shape(predicted_states, (n, self.n_states))
259-
predicted_states.name = "predicted_states"
260-
261-
observed_states = pt.specify_shape(observed_states, (n, self.n_endog))
262-
observed_states.name = "observed_states"
264+
predicted_states.name = FILTER_OUTPUT_NAMES[1]
263265

264266
filtered_covariances = pt.specify_shape(
265267
filtered_covariances, (n, self.n_states, self.n_states)
266268
)
267-
filtered_covariances.name = "filtered_covariances"
269+
filtered_covariances.name = FILTER_OUTPUT_NAMES[2]
268270

269271
predicted_covariances = pt.specify_shape(
270272
predicted_covariances, (n, self.n_states, self.n_states)
271273
)
272-
predicted_covariances.name = "predicted_covariances"
274+
predicted_covariances.name = FILTER_OUTPUT_NAMES[3]
275+
276+
observed_states = pt.specify_shape(observed_states, (n, self.n_endog))
277+
observed_states.name = FILTER_OUTPUT_NAMES[4]
273278

274279
observed_covariances = pt.specify_shape(
275280
observed_covariances, (n, self.n_endog, self.n_endog)
276281
)
277-
observed_covariances.name = "observed_covariances"
282+
observed_covariances.name = FILTER_OUTPUT_NAMES[5]
278283

279284
loglike_obs = pt.specify_shape(loglike_obs.squeeze(), (n,))
280285
loglike_obs.name = "loglike_obs"

pymc_extras/statespace/utils/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@
4242
"predicted_states",
4343
"filtered_covariances",
4444
"predicted_covariances",
45+
"predicted_observed_states",
46+
"predicted_observed_covariances",
4547
]
4648

4749
SMOOTHER_OUTPUT_NAMES = ["smoothed_states", "smoothed_covariances"]

0 commit comments

Comments
 (0)