@@ -118,6 +118,8 @@ class PosteriorGSSMFiltered(NamedTuple):
118118 :param marginal_loglik: marginal log likelihood, $p(y_{1:T} \mid u_{1:T})$
119119 :param filtered_means: array of filtered means $\mathbb{E}[z_t \mid y_{1:t}, u_{1:t}]$
120120 :param filtered_covariances: array of filtered covariances $\mathrm{Cov}[z_t \mid y_{1:t}, u_{1:t}]$
121+ :param predicted_means: array of predicted means $\mathbb{E}[z_t \mid y_{1:t-1}, u_{1:t-1}]$
122+ :param predicted_covariances: array of predicted covariances $\mathrm{Cov}[z_t \mid y_{1:t-1}, u_{1:t-1}]$
121123
122124 """
123125 marginal_loglik : Union [Scalar , Float [Array , " ntime" ]]
@@ -504,12 +506,12 @@ def _step(carry, t):
504506 # Predict the next state
505507 pred_mean , pred_cov = _predict (filtered_mean , filtered_cov , F , B , b , Q , u )
506508
507- return (ll , pred_mean , pred_cov ), (filtered_mean , filtered_cov )
509+ return (ll , pred_mean , pred_cov ), (filtered_mean , filtered_cov , carry [ 1 ], carry [ 2 ] )
508510
509511 # Run the Kalman filter
510512 carry = (0.0 , params .initial .mean , params .initial .cov )
511- (ll , _ , _ ), (filtered_means , filtered_covs ) = lax .scan (_step , carry , jnp .arange (num_timesteps ))
512- return PosteriorGSSMFiltered (marginal_loglik = ll , filtered_means = filtered_means , filtered_covariances = filtered_covs )
513+ (ll , _ , _ ), (filtered_means , filtered_covs , predicted_means , predicted_covs ) = lax .scan (_step , carry , jnp .arange (num_timesteps ))
514+ return PosteriorGSSMFiltered (marginal_loglik = ll , filtered_means = filtered_means , filtered_covariances = filtered_covs , predicted_means = predicted_means , predicted_covariances = predicted_covs )
513515
514516
515517@preprocess_args
0 commit comments