Skip to content

Commit 9e78bae

Browse files
committed
cleaned up commented code, moved internal checks to the top, reduced intermediate variables
1 parent d142a91 commit 9e78bae

File tree

1 file changed

+23
-33
lines changed

1 file changed

+23
-33
lines changed

pymc_extras/statespace/core/statespace.py

Lines changed: 23 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1684,6 +1684,21 @@ def sample_filter_outputs(
16841684
if isinstance(filter_output_names, str):
16851685
filter_output_names = [filter_output_names]
16861686

1687+
drop_keys = {"predicted_observed_states", "predicted_observed_covariances"}
1688+
all_filter_output_dims = {k: v for k, v in FILTER_OUTPUT_DIMS.items() if k not in drop_keys}
1689+
1690+
if filter_output_names is None:
1691+
filter_output_names = list(all_filter_output_dims.keys())
1692+
else:
1693+
unknown_filter_output_names = np.setdiff1d(
1694+
filter_output_names, list(all_filter_output_dims.keys())
1695+
)
1696+
if unknown_filter_output_names.size > 0:
1697+
raise ValueError(f"{unknown_filter_output_names} not a valid filter output name!")
1698+
filter_output_names = [
1699+
x for x in all_filter_output_dims.keys() if x in filter_output_names
1700+
]
1701+
16871702
compile_kwargs = kwargs.pop("compile_kwargs", {})
16881703
compile_kwargs.setdefault("mode", self.mode)
16891704

@@ -1726,44 +1741,19 @@ def sample_filter_outputs(
17261741
T, R, Q, filter_outputs[0], filter_outputs[3]
17271742
)
17281743

1729-
# Filter output names are singular in constants.py but are returned as plural from kalman_.build_graph()
1730-
# filter_output_dims_mapping = {}
1731-
# for k in FILTER_OUTPUT_DIMS.keys():
1732-
# filter_output_dims_mapping[k + "s"] = FILTER_OUTPUT_DIMS[k]
1733-
1734-
all_filter_outputs = filter_outputs[:-1] + list(smoother_outputs)
1735-
# This excludes observed states and observed covariances from the filter outputs
1736-
all_filter_outputs = [
1737-
output for output in all_filter_outputs if output.name in FILTER_OUTPUT_DIMS
1738-
]
1744+
filter_outputs = filter_outputs[:-1] + list(smoother_outputs)
1745+
for output in filter_outputs:
1746+
if output.name in filter_output_names:
1747+
dims = all_filter_output_dims[output.name]
1748+
pm.Deterministic(output.name, output, dims=dims)
17391749

1740-
if filter_output_names is None:
1741-
filter_output_names = all_filter_outputs
1742-
else:
1743-
unknown_filter_output_names = np.setdiff1d(
1744-
filter_output_names, [x.name for x in all_filter_outputs]
1745-
)
1746-
if unknown_filter_output_names.size > 0:
1747-
raise ValueError(
1748-
f"{unknown_filter_output_names} not a valid filter output name!"
1749-
)
1750-
filter_output_names = [
1751-
x for x in all_filter_outputs if x.name in filter_output_names
1752-
]
1753-
1754-
for output in filter_output_names:
1755-
dims = FILTER_OUTPUT_DIMS[output.name]
1756-
pm.Deterministic(output.name, output, dims=dims)
1757-
1758-
frozen_model = freeze_dims_and_data(m)
1759-
with frozen_model:
1760-
idata_filter = pm.sample_posterior_predictive(
1750+
with freeze_dims_and_data(m):
1751+
return pm.sample_posterior_predictive(
17611752
idata if group == "posterior" else idata.prior,
1762-
var_names=[x.name for x in frozen_model.deterministics],
1753+
var_names=filter_output_names,
17631754
compile_kwargs=compile_kwargs,
17641755
**kwargs,
17651756
)
1766-
return idata_filter
17671757

17681758
@staticmethod
17691759
def _validate_forecast_args(

0 commit comments

Comments
 (0)