@@ -1678,6 +1678,91 @@ def sample_statespace_matrices(
1678
1678
1679
1679
return matrix_idata
1680
1680
1681
+ def sample_filter_outputs (
1682
+ self , idata , filter_output_names : str | list [str ] | None , group : str = "posterior" , ** kwargs
1683
+ ):
1684
+ compile_kwargs = kwargs .pop ("compile_kwargs" , {})
1685
+ compile_kwargs .setdefault ("mode" , self .mode )
1686
+
1687
+ with pm .Model (coords = self .coords ) as m :
1688
+ pm_mod = modelcontext (None )
1689
+ self ._build_dummy_graph ()
1690
+ self ._insert_random_variables ()
1691
+
1692
+ if self .data_names :
1693
+ for name in self .data_names :
1694
+ pm .Data (** self ._exog_data_info [name ])
1695
+
1696
+ self ._insert_data_variables ()
1697
+
1698
+ x0 , P0 , c , d , T , Z , R , H , Q = self .unpack_statespace ()
1699
+ data = self ._fit_data
1700
+
1701
+ obs_coords = pm_mod .coords .get (OBS_STATE_DIM , None )
1702
+
1703
+ data , nan_mask = register_data_with_pymc (
1704
+ data ,
1705
+ n_obs = self .ssm .k_endog ,
1706
+ obs_coords = obs_coords ,
1707
+ register_data = True ,
1708
+ )
1709
+
1710
+ filter_outputs = self .kalman_filter .build_graph (
1711
+ data ,
1712
+ x0 ,
1713
+ P0 ,
1714
+ c ,
1715
+ d ,
1716
+ T ,
1717
+ Z ,
1718
+ R ,
1719
+ H ,
1720
+ Q ,
1721
+ )
1722
+
1723
+ smoother_outputs = self .kalman_smoother .build_graph (
1724
+ T , R , Q , filter_outputs [0 ], filter_outputs [3 ]
1725
+ )
1726
+
1727
+ all_filter_outputs = filter_outputs [:- 1 ] + list (smoother_outputs )
1728
+
1729
+ if filter_output_names is None :
1730
+ filter_output_names = all_filter_outputs
1731
+ else :
1732
+ unknown_filter_output_names = np .setdiff1d (
1733
+ filter_output_names , [x .name for x in all_filter_outputs ]
1734
+ )
1735
+ if unknown_filter_output_names .size > 0 :
1736
+ raise ValueError (
1737
+ f"{ unknown_filter_output_names } not a valid filter output name!"
1738
+ )
1739
+ filter_output_names = [
1740
+ x for x in all_filter_outputs if x .name in filter_output_names
1741
+ ]
1742
+
1743
+ for output in filter_output_names :
1744
+ match output .name :
1745
+ case "filtered_states" | "predicted_states" | "smoothed_states" :
1746
+ dims = [TIME_DIM , "state" ]
1747
+ case "filtered_covariances" | "predicted_covariances" | "smoothed_covariances" :
1748
+ dims = [TIME_DIM , "state" , "state_aux" ]
1749
+ case "observed_states" :
1750
+ dims = [TIME_DIM , "observed_state" ]
1751
+ case "observed_covariances" :
1752
+ dims = [TIME_DIM , "observed_state" , "observed_state_aux" ]
1753
+
1754
+ pm .Deterministic (output .name , output , dims = dims )
1755
+
1756
+ frozen_model = freeze_dims_and_data (m )
1757
+ with frozen_model :
1758
+ idata_filter = pm .sample_posterior_predictive (
1759
+ idata if group == "posterior" else idata .prior ,
1760
+ var_names = [x .name for x in frozen_model .deterministics ],
1761
+ compile_kwargs = compile_kwargs ,
1762
+ ** kwargs ,
1763
+ )
1764
+ return idata_filter
1765
+
1681
1766
@staticmethod
1682
1767
def _validate_forecast_args (
1683
1768
time_index : pd .RangeIndex | pd .DatetimeIndex ,
0 commit comments