@@ -1678,6 +1678,91 @@ def sample_statespace_matrices(
16781678
16791679        return  matrix_idata 
16801680
1681+     def  sample_filter_outputs (
1682+         self , idata , filter_output_names : str  |  list [str ] |  None , group : str  =  "posterior" , ** kwargs 
1683+     ):
1684+         compile_kwargs  =  kwargs .pop ("compile_kwargs" , {})
1685+         compile_kwargs .setdefault ("mode" , self .mode )
1686+ 
1687+         with  pm .Model (coords = self .coords ) as  m :
1688+             pm_mod  =  modelcontext (None )
1689+             self ._build_dummy_graph ()
1690+             self ._insert_random_variables ()
1691+ 
1692+             if  self .data_names :
1693+                 for  name  in  self .data_names :
1694+                     pm .Data (** self ._exog_data_info [name ])
1695+ 
1696+             self ._insert_data_variables ()
1697+ 
1698+             x0 , P0 , c , d , T , Z , R , H , Q  =  self .unpack_statespace ()
1699+             data  =  self ._fit_data 
1700+ 
1701+             obs_coords  =  pm_mod .coords .get (OBS_STATE_DIM , None )
1702+ 
1703+             data , nan_mask  =  register_data_with_pymc (
1704+                 data ,
1705+                 n_obs = self .ssm .k_endog ,
1706+                 obs_coords = obs_coords ,
1707+                 register_data = True ,
1708+             )
1709+ 
1710+             filter_outputs  =  self .kalman_filter .build_graph (
1711+                 data ,
1712+                 x0 ,
1713+                 P0 ,
1714+                 c ,
1715+                 d ,
1716+                 T ,
1717+                 Z ,
1718+                 R ,
1719+                 H ,
1720+                 Q ,
1721+             )
1722+ 
1723+             smoother_outputs  =  self .kalman_smoother .build_graph (
1724+                 T , R , Q , filter_outputs [0 ], filter_outputs [3 ]
1725+             )
1726+ 
1727+             all_filter_outputs  =  filter_outputs [:- 1 ] +  list (smoother_outputs )
1728+ 
1729+             if  filter_output_names  is  None :
1730+                 filter_output_names  =  all_filter_outputs 
1731+             else :
1732+                 unknown_filter_output_names  =  np .setdiff1d (
1733+                     filter_output_names , [x .name  for  x  in  all_filter_outputs ]
1734+                 )
1735+                 if  unknown_filter_output_names .size  >  0 :
1736+                     raise  ValueError (
1737+                         f"{ unknown_filter_output_names }  
1738+                     )
1739+                 filter_output_names  =  [
1740+                     x  for  x  in  all_filter_outputs  if  x .name  in  filter_output_names 
1741+                 ]
1742+ 
1743+             for  output  in  filter_output_names :
1744+                 match  output .name :
1745+                     case  "filtered_states"  |  "predicted_states"  |  "smoothed_states" :
1746+                         dims  =  [TIME_DIM , "state" ]
1747+                     case  "filtered_covariances"  |  "predicted_covariances"  |  "smoothed_covariances" :
1748+                         dims  =  [TIME_DIM , "state" , "state_aux" ]
1749+                     case  "observed_states" :
1750+                         dims  =  [TIME_DIM , "observed_state" ]
1751+                     case  "observed_covariances" :
1752+                         dims  =  [TIME_DIM , "observed_state" , "observed_state_aux" ]
1753+ 
1754+                 pm .Deterministic (output .name , output , dims = dims )
1755+ 
1756+         frozen_model  =  freeze_dims_and_data (m )
1757+         with  frozen_model :
1758+             idata_filter  =  pm .sample_posterior_predictive (
1759+                 idata  if  group  ==  "posterior"  else  idata .prior ,
1760+                 var_names = [x .name  for  x  in  frozen_model .deterministics ],
1761+                 compile_kwargs = compile_kwargs ,
1762+                 ** kwargs ,
1763+             )
1764+         return  idata_filter 
1765+ 
16811766    @staticmethod  
16821767    def  _validate_forecast_args (
16831768        time_index : pd .RangeIndex  |  pd .DatetimeIndex ,
0 commit comments