@@ -1681,11 +1681,13 @@ def sample_statespace_matrices(
1681
1681
def sample_filter_outputs (
1682
1682
self , idata , filter_output_names : str | list [str ] | None , group : str = "posterior" , ** kwargs
1683
1683
):
1684
+ if isinstance (filter_output_names , str ):
1685
+ filter_output_names = [filter_output_names ]
1686
+
1684
1687
compile_kwargs = kwargs .pop ("compile_kwargs" , {})
1685
1688
compile_kwargs .setdefault ("mode" , self .mode )
1686
1689
1687
1690
with pm .Model (coords = self .coords ) as m :
1688
- pm_mod = modelcontext (None )
1689
1691
self ._build_dummy_graph ()
1690
1692
self ._insert_random_variables ()
1691
1693
@@ -1698,7 +1700,7 @@ def sample_filter_outputs(
1698
1700
x0 , P0 , c , d , T , Z , R , H , Q = self .unpack_statespace ()
1699
1701
data = self ._fit_data
1700
1702
1701
- obs_coords = pm_mod .coords .get (OBS_STATE_DIM , None )
1703
+ obs_coords = m .coords .get (OBS_STATE_DIM , None )
1702
1704
1703
1705
data , nan_mask = register_data_with_pymc (
1704
1706
data ,
@@ -1724,7 +1726,16 @@ def sample_filter_outputs(
1724
1726
T , R , Q , filter_outputs [0 ], filter_outputs [3 ]
1725
1727
)
1726
1728
1729
+ # Filter output names are singular in constants.py but are returned as plural from kalman_.build_graph()
1730
+ filter_output_dims_mapping = {}
1731
+ for k in FILTER_OUTPUT_DIMS .keys ():
1732
+ filter_output_dims_mapping [k + "s" ] = FILTER_OUTPUT_DIMS [k ]
1733
+
1727
1734
all_filter_outputs = filter_outputs [:- 1 ] + list (smoother_outputs )
1735
+ # This excludes observed states and observed covariances from the filter outputs
1736
+ all_filter_outputs = [
1737
+ output for output in all_filter_outputs if output .name in filter_output_dims_mapping
1738
+ ]
1728
1739
1729
1740
if filter_output_names is None :
1730
1741
filter_output_names = all_filter_outputs
@@ -1741,16 +1752,7 @@ def sample_filter_outputs(
1741
1752
]
1742
1753
1743
1754
for output in filter_output_names :
1744
- match output .name :
1745
- case "filtered_states" | "predicted_states" | "smoothed_states" :
1746
- dims = [TIME_DIM , "state" ]
1747
- case "filtered_covariances" | "predicted_covariances" | "smoothed_covariances" :
1748
- dims = [TIME_DIM , "state" , "state_aux" ]
1749
- case "observed_states" :
1750
- dims = [TIME_DIM , "observed_state" ]
1751
- case "observed_covariances" :
1752
- dims = [TIME_DIM , "observed_state" , "observed_state_aux" ]
1753
-
1755
+ dims = filter_output_dims_mapping [output .name ]
1754
1756
pm .Deterministic (output .name , output , dims = dims )
1755
1757
1756
1758
frozen_model = freeze_dims_and_data (m )
0 commit comments