@@ -1684,6 +1684,21 @@ def sample_filter_outputs(
1684
1684
if isinstance (filter_output_names , str ):
1685
1685
filter_output_names = [filter_output_names ]
1686
1686
1687
+ drop_keys = {"predicted_observed_states" , "predicted_observed_covariances" }
1688
+ all_filter_output_dims = {k : v for k , v in FILTER_OUTPUT_DIMS .items () if k not in drop_keys }
1689
+
1690
+ if filter_output_names is None :
1691
+ filter_output_names = list (all_filter_output_dims .keys ())
1692
+ else :
1693
+ unknown_filter_output_names = np .setdiff1d (
1694
+ filter_output_names , list (all_filter_output_dims .keys ())
1695
+ )
1696
+ if unknown_filter_output_names .size > 0 :
1697
+ raise ValueError (f"{ unknown_filter_output_names } not a valid filter output name!" )
1698
+ filter_output_names = [
1699
+ x for x in all_filter_output_dims .keys () if x in filter_output_names
1700
+ ]
1701
+
1687
1702
compile_kwargs = kwargs .pop ("compile_kwargs" , {})
1688
1703
compile_kwargs .setdefault ("mode" , self .mode )
1689
1704
@@ -1726,44 +1741,19 @@ def sample_filter_outputs(
1726
1741
T , R , Q , filter_outputs [0 ], filter_outputs [3 ]
1727
1742
)
1728
1743
1729
- # Filter output names are singular in constants.py but are returned as plural from kalman_.build_graph()
1730
- # filter_output_dims_mapping = {}
1731
- # for k in FILTER_OUTPUT_DIMS.keys():
1732
- # filter_output_dims_mapping[k + "s"] = FILTER_OUTPUT_DIMS[k]
1733
-
1734
- all_filter_outputs = filter_outputs [:- 1 ] + list (smoother_outputs )
1735
- # This excludes observed states and observed covariances from the filter outputs
1736
- all_filter_outputs = [
1737
- output for output in all_filter_outputs if output .name in FILTER_OUTPUT_DIMS
1738
- ]
1744
+ filter_outputs = filter_outputs [:- 1 ] + list (smoother_outputs )
1745
+ for output in filter_outputs :
1746
+ if output .name in filter_output_names :
1747
+ dims = all_filter_output_dims [output .name ]
1748
+ pm .Deterministic (output .name , output , dims = dims )
1739
1749
1740
- if filter_output_names is None :
1741
- filter_output_names = all_filter_outputs
1742
- else :
1743
- unknown_filter_output_names = np .setdiff1d (
1744
- filter_output_names , [x .name for x in all_filter_outputs ]
1745
- )
1746
- if unknown_filter_output_names .size > 0 :
1747
- raise ValueError (
1748
- f"{ unknown_filter_output_names } not a valid filter output name!"
1749
- )
1750
- filter_output_names = [
1751
- x for x in all_filter_outputs if x .name in filter_output_names
1752
- ]
1753
-
1754
- for output in filter_output_names :
1755
- dims = FILTER_OUTPUT_DIMS [output .name ]
1756
- pm .Deterministic (output .name , output , dims = dims )
1757
-
1758
- frozen_model = freeze_dims_and_data (m )
1759
- with frozen_model :
1760
- idata_filter = pm .sample_posterior_predictive (
1750
+ with freeze_dims_and_data (m ):
1751
+ return pm .sample_posterior_predictive (
1761
1752
idata if group == "posterior" else idata .prior ,
1762
- var_names = [ x . name for x in frozen_model . deterministics ] ,
1753
+ var_names = filter_output_names ,
1763
1754
compile_kwargs = compile_kwargs ,
1764
1755
** kwargs ,
1765
1756
)
1766
- return idata_filter
1767
1757
1768
1758
@staticmethod
1769
1759
def _validate_forecast_args (
0 commit comments