 import pytensor
 import pytensor.tensor as pt
 
+from pymc.pytensorf import constant_fold
 from pytensor.compile.mode import get_mode
 from pytensor.graph.basic import Variable
 from pytensor.raise_op import Assert
@@ -203,8 +204,11 @@ def build_graph(
         self.missing_fill_value = missing_fill_value
         self.cov_jitter = cov_jitter
 
-        self.n_states, self.n_shocks = R.shape[-2:]
-        self.n_endog = Z.shape[-2]
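+        # Fold the shape graphs down to constants where possible, so that n_states,
+        # n_shocks, and n_endog are static; raise_not_constant=False falls back to
+        # symbolic shapes when folding is impossible.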
+        [R_shape] = constant_fold([R.shape], raise_not_constant=False)
+        [Z_shape] = constant_fold([Z.shape], raise_not_constant=False)
+
+        self.n_states, self.n_shocks = R_shape[-2:]
+        self.n_endog = Z_shape[-2]
 
         data, a0, P0, *params = self.check_params(data, a0, P0, c, d, T, Z, R, H, Q)
 
@@ -408,7 +412,7 @@ def predict(a, P, c, T, R, Q) -> tuple[TensorVariable, TensorVariable]:
 
     @staticmethod
     def update(
-        a, P, y, c, d, Z, H, all_nan_flag
+        a, P, y, d, Z, H, all_nan_flag
     ) -> tuple[TensorVariable, TensorVariable, TensorVariable, TensorVariable, TensorVariable]:
         """
         Perform the update step of the Kalman filter.
@@ -419,7 +423,7 @@ def update(
         .. math::
 
             \\begin{align}
-            \\hat{y}_t &= Z_t a_{t | t-1} \\
+            \\hat{y}_t &= Z_t a_{t | t-1} + d_t \\
             v_t &= y_t - \\hat{y}_t \\
             F_t &= Z_t P_{t | t-1} Z_t^T + H_t \\
             a_{t|t} &= a_{t | t-1} + P_{t | t-1} Z_t^T F_t^{-1} v_t \\
@@ -435,8 +439,6 @@ def update(
             The current covariance matrix estimate, conditioned on information up to time t-1.
         y : TensorVariable
             The observation data at time t.
-        c : TensorVariable
-            The matrix c.
         d : TensorVariable
             The observation bias term d.
         Z : TensorVariable
@@ -529,7 +531,7 @@ def kalman_step(self, *args) -> tuple:
         y_masked, Z_masked, H_masked, all_nan_flag = self.handle_missing_values(y, Z, H)
 
         a_filtered, P_filtered, obs_mu, obs_cov, ll = self.update(
-            y=y_masked, a=a, c=c, d=d, P=P, Z=Z_masked, H=H_masked, all_nan_flag=all_nan_flag
+            y=y_masked, a=a, d=d, P=P, Z=Z_masked, H=H_masked, all_nan_flag=all_nan_flag
         )
 
         P_filtered = stabilize(P_filtered, self.cov_jitter)
@@ -545,7 +547,7 @@ class StandardFilter(BaseFilter):
     Basic Kalman Filter
     """
 
-    def update(self, a, P, y, c, d, Z, H, all_nan_flag):
+    def update(self, a, P, y, d, Z, H, all_nan_flag):
         """
         Compute one-step forecasts for observed states conditioned on information up to, but not including, the current
         timestep, `y_hat`, along with the forecast covariance matrix, `F`. Marginalize over observed states to obtain
@@ -566,9 +568,6 @@ def update(self, a, P, y, c, d, Z, H, all_nan_flag):
         y : TensorVariable
             Observations at time t.
 
-        c : TensorVariable
-            Latent state bias term.
-
         d : TensorVariable
             Observed state bias term.
 
@@ -628,38 +627,128 @@ class SquareRootFilter(BaseFilter):
 
     """
 
-    # TODO: Can the entire Kalman filter process be re-written, starting from P0_chol, so it's not necessary to compute
-    #  cholesky(F) at every iteration?
+    def predict(self, a, P, c, T, R, Q):
+        """
+        Compute one-step forecasts for the hidden states conditioned on information up to, but not including, the
+        current timestep, `a_hat`, along with the forecast covariance matrix, `P_hat`.
+
+        .. warning::
+            Very important -- in this function, $P$ is the **Cholesky factor** of the covariance matrix, not the
+            covariance matrix itself. The name `P` is kept for consistency with the superclass.
+        """
+        # Rename P to P_chol for clarity
+        P_chol = P
+
+        a_hat = T.dot(a) + c
+        Q_chol = pt.linalg.cholesky(Q, lower=True)
+
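+        # For M = [T @ L_P, R @ L_Q]^T, the R factor of M's QR decomposition satisfies
+        # R^T R = M^T M = T P T^T + R Q R^T, so R^T is a lower-triangular Cholesky factor
+        # of the one-step forecast covariance.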
+        M = pt.horizontal_stack(T @ P_chol, R @ Q_chol).T
+        R_decomp = pt.linalg.qr(M, mode="r")
+        P_chol_hat = R_decomp[: self.n_states, : self.n_states].T
+
+        return a_hat, P_chol_hat
+
+    def update(self, a, P, y, d, Z, H, all_nan_flag):
+        """
+        Compute posterior estimates of the hidden state distributions conditioned on the observed data, up to and
+        including the present timestep. Also compute the log-likelihood of the data given the one-step forecasts.
+
+        .. warning::
+            Very important -- in this function, $P$ is the **Cholesky factor** of the covariance matrix, not the
+            covariance matrix itself. The name `P` is kept for consistency with the superclass.
+        """
+
+        # Rename P to P_chol for clarity
+        P_chol = P
 
-    def update(self, a, P, y, c, d, Z, H, all_nan_flag):
         y_hat = Z.dot(a) + d
         v = y - y_hat
 
-        PZT = P.dot(Z.T)
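+        # H is identically zero when the model has no measurement error (or every
+        # observation is missing this step); cholesky would fail on a zero matrix,
+        # so pass H through unchanged in that case.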
+        H_chol = pytensor.ifelse(pt.all(pt.eq(H, 0.0)), H, pt.linalg.cholesky(H, lower=True))
+
+        # The following notation comes from https://ipnpr.jpl.nasa.gov/progress_report/42-233/42-233A.pdf
+        # Construct the block matrix A_T = [[chol(H), Z @ L_pred],
+        #                                   [0,       L_pred    ]]
+        # B is the transposed R factor of the QR decomposition of A_T.T; it is lower
+        # triangular, with block structure:
+        #     B = [[chol(F),     0               ],
+        #          [K @ chol(F), chol(P_filtered)]]
+        zeros = pt.zeros((self.n_states, self.n_endog))
+        upper = pt.horizontal_stack(H_chol, Z @ P_chol)
+        lower = pt.horizontal_stack(zeros, P_chol)
+        A_T = pt.vertical_stack(upper, lower)
+        B = pt.linalg.qr(A_T.T, mode="r").T
+
+        F_chol = B[: self.n_endog, : self.n_endog]
+        K_F_chol = B[self.n_endog :, : self.n_endog]
+        P_chol_filtered = B[self.n_endog :, self.n_endog :]
+
+        def compute_non_degenerate(P_chol_filtered, F_chol, K_F_chol, v):
+            a_filtered = a + K_F_chol @ solve_triangular(F_chol, v, lower=True)
+
+            # inner_term = F^{-1} v, via two triangular solves with chol(F)
+            inner_term = solve_triangular(
+                F_chol.T, solve_triangular(F_chol, v, lower=True), lower=False
+            )
+            loss = (v.T @ inner_term).ravel()
+
+            # abs is necessary because the QR decomposition does not guarantee a positive diagonal
+            logdet = 2 * pt.log(pt.abs(pt.diag(F_chol))).sum()
+
+            ll = -0.5 * (self.n_endog * MVN_CONST + logdet + loss)[0]
+
+            return [a_filtered, P_chol_filtered, ll]
+
+        def compute_degenerate(P_chol_filtered, F_chol, K_F_chol, v):
+            """
+            If F is zero (usually because there were no observations this period), then we want:
+            K = 0, a = a, P = P, ll = 0
+            """
+            return [a, P_chol, pt.zeros(())]
+
+        [a_filtered, P_chol_filtered, ll] = pytensor.ifelse(
+            pt.eq(all_nan_flag, 1.0),
+            compute_degenerate(P_chol_filtered, F_chol, K_F_chol, v),
+            compute_non_degenerate(P_chol_filtered, F_chol, K_F_chol, v),
+        )
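+        # Note: ifelse is lazy, so only the branch selected by all_nan_flag is actually
+        # evaluated at runtime.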
 
-        # If everything is missing, F will be [[0]] and F_chol will raise an error, so add identity to avoid the error
-        F = Z.dot(PZT) + stabilize(H, self.cov_jitter)
-        F_chol = pt.linalg.cholesky(F)
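+        # Reassert static shapes on the ifelse outputs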
+        a_filtered = pt.specify_shape(a_filtered, (self.n_states,))
+        P_chol_filtered = pt.specify_shape(P_chol_filtered, (self.n_states, self.n_states))
 
-        # If everything is missing, K = 0, IKZ = I
-        K = solve_triangular(F_chol.T, solve_triangular(F_chol, PZT.T)).T
-        I_KZ = pt.eye(self.n_states) - K.dot(Z)
+        return a_filtered, P_chol_filtered, y_hat, F_chol, ll
 
-        a_filtered = a + K.dot(v)
-        P_filtered = quad_form_sym(I_KZ, P) + quad_form_sym(K, H)
+    def _postprocess_scan_results(self, results, a0, P0, n) -> list[TensorVariable]:
+        """
+        Convert the Cholesky factors of the covariance matrices back to the covariance matrices themselves.
+        """
+        results = super()._postprocess_scan_results(results, a0, P0, n)
+        (
+            filtered_states,
+            predicted_states,
+            observed_states,
+            filtered_covariances_cholesky,
+            predicted_covariances_cholesky,
+            observed_covariances_cholesky,
+            loglike_obs,
+        ) = results
 
-        inner_term = solve_triangular(F_chol.T, solve_triangular(F_chol, v))
-        n = y.shape[0]
+        def square_sequence(L):
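+            # Batched L @ L^T over the leading (time) dimension recovers each covariance
+            # matrix from its Cholesky factor.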
+            X = pt.einsum("...ij,...kj->...ik", L, L.copy())
+            X = pt.specify_shape(X, (n, self.n_states, self.n_states))
+            return X
 
-        ll = pt.switch(
-            all_nan_flag,
-            0.0,
-            (
-                -0.5 * (n * MVN_CONST + (v.T @ inner_term).ravel()) - pt.log(pt.diag(F_chol)).sum()
-            ).ravel()[0],
-        )
+        filtered_covariances = square_sequence(filtered_covariances_cholesky)
+        predicted_covariances = square_sequence(predicted_covariances_cholesky)
+        observed_covariances = square_sequence(observed_covariances_cholesky)
 
-        return a_filtered, P_filtered, y_hat, F, ll
+        return [
+            filtered_states,
+            predicted_states,
+            observed_states,
+            filtered_covariances,
+            predicted_covariances,
+            observed_covariances,
+            loglike_obs,
+        ]
 
 
 class SingleTimeseriesFilter(BaseFilter):
@@ -679,7 +768,7 @@ def check_params(self, data, a0, P0, c, d, T, Z, R, H, Q):
 
         return data, a0, P0, c, d, T, Z, R, H, Q
 
-    def update(self, a, P, y, c, d, Z, H, all_nan_flag):
+    def update(self, a, P, y, d, Z, H, all_nan_flag):
         y_hat = d + Z.dot(a)
         v = y - y_hat.ravel()
 