Skip to content

Commit fcc4f3a

Browse files
committed
Enhance Kalman filter return values with predictions
Added predicted means and covariances to the return values of the Kalman filter.
1 parent 11c1c9f commit fcc4f3a

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

dynamax/linear_gaussian_ssm/inference.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,8 @@ class PosteriorGSSMFiltered(NamedTuple):
118118
:param marginal_loglik: marginal log likelihood, $p(y_{1:T} \mid u_{1:T})$
119119
:param filtered_means: array of filtered means $\mathbb{E}[z_t \mid y_{1:t}, u_{1:t}]$
120120
:param filtered_covariances: array of filtered covariances $\mathrm{Cov}[z_t \mid y_{1:t}, u_{1:t}]$
121+
:param predicted_means: array of predicted means $\mathbb{E}[z_t \mid y_{1:t-1}, u_{1:t-1}]$
122+
:param predicted_covariances: array of predicted covariances $\mathrm{Cov}[z_t \mid y_{1:t-1}, u_{1:t-1}]$
121123
122124
"""
123125
marginal_loglik: Union[Scalar, Float[Array, " ntime"]]
@@ -504,12 +506,12 @@ def _step(carry, t):
504506
# Predict the next state
505507
pred_mean, pred_cov = _predict(filtered_mean, filtered_cov, F, B, b, Q, u)
506508

507-
return (ll, pred_mean, pred_cov), (filtered_mean, filtered_cov)
509+
return (ll, pred_mean, pred_cov), (filtered_mean, filtered_cov, carry[1], carry[2])
508510

509511
# Run the Kalman filter
510512
carry = (0.0, params.initial.mean, params.initial.cov)
511-
(ll, _, _), (filtered_means, filtered_covs) = lax.scan(_step, carry, jnp.arange(num_timesteps))
512-
return PosteriorGSSMFiltered(marginal_loglik=ll, filtered_means=filtered_means, filtered_covariances=filtered_covs)
513+
(ll, _, _), (filtered_means, filtered_covs, predicted_means, predicted_covs) = lax.scan(_step, carry, jnp.arange(num_timesteps))
514+
return PosteriorGSSMFiltered(marginal_loglik=ll, filtered_means=filtered_means, filtered_covariances=filtered_covs, predicted_means=predicted_means, predicted_covariances=predicted_covs)
513515

514516

515517
@preprocess_args

0 commit comments

Comments
 (0)