import pytensor
import pytensor.tensor as pt

+from pymc.pytensorf import constant_fold
from pytensor.compile.mode import get_mode
from pytensor.graph.basic import Variable
from pytensor.raise_op import Assert
@@ -203,8 +204,11 @@ def build_graph(
        self.missing_fill_value = missing_fill_value
        self.cov_jitter = cov_jitter

-        self.n_states, self.n_shocks = R.shape[-2:]
-        self.n_endog = Z.shape[-2]
+        [R_shape] = constant_fold([R.shape], raise_not_constant=False)
+        [Z_shape] = constant_fold([Z.shape], raise_not_constant=False)
+
+        self.n_states, self.n_shocks = R_shape[-2:]
+        self.n_endog = Z_shape[-2]

        data, a0, P0, *params = self.check_params(data, a0, P0, c, d, T, Z, R, H, Q)
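For context on the new `constant_fold` calls: they attempt to reduce the symbolic shape graphs of `R` and `Z` to concrete integers, so `n_states`, `n_shocks`, and `n_endog` are static whenever the shapes are known at graph-construction time, while `raise_not_constant=False` keeps things symbolic instead of erroring when they are not. A minimal sketch of that behaviour (the toy shapes below are made up for illustration):

```python
import numpy as np
import pytensor.tensor as pt
from pymc.pytensorf import constant_fold

# Toy system matrices with fully known shapes: 3 states, 1 shock, 2 observed series.
R = pt.as_tensor(np.zeros((3, 1)))
Z = pt.as_tensor(np.zeros((2, 3)))

# constant_fold evaluates the symbolic shape graphs down to concrete values where possible;
# with raise_not_constant=False it returns the symbolic shape unchanged instead of raising.
[R_shape] = constant_fold([R.shape], raise_not_constant=False)
[Z_shape] = constant_fold([Z.shape], raise_not_constant=False)

n_states, n_shocks = R_shape[-2:]  # 3, 1
n_endog = Z_shape[-2]              # 2
```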
@@ -408,7 +412,7 @@ def predict(a, P, c, T, R, Q) -> tuple[TensorVariable, TensorVariable]:
    @staticmethod
    def update(
-        a, P, y, c, d, Z, H, all_nan_flag
+        a, P, y, d, Z, H, all_nan_flag
    ) -> tuple[TensorVariable, TensorVariable, TensorVariable, TensorVariable, TensorVariable]:
        """
        Perform the update step of the Kalman filter.
@@ -419,7 +423,7 @@ def update(
        .. math::

            \\begin{align}
-            \\hat{y}_t &= Z_t a_{t | t-1} \\
+            \\hat{y}_t &= Z_t a_{t | t-1} + d_t \\
            v_t &= y_t - \\hat{y}_t \\
            F_t &= Z_t P_{t | t-1} Z_t^T + H_t \\
            a_{t|t} &= a_{t | t-1} + P_{t | t-1} Z_t^T F_t^{-1} v_t \\
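The added `+ d_t` term brings the docstring in line with the implementation, which computes `y_hat = Z @ a + d`. As a plain NumPy sketch of these update equations (a toy helper, not the PyTensor code; the filtered covariance uses the standard form `P - K Z P`):

```python
import numpy as np

def kalman_update(a, P, y, d, Z, H):
    y_hat = Z @ a + d                  # one-step observation forecast, including the intercept d_t
    v = y - y_hat                      # innovation
    F = Z @ P @ Z.T + H                # innovation covariance
    K = P @ Z.T @ np.linalg.inv(F)     # Kalman gain
    a_filtered = a + K @ v
    P_filtered = P - K @ Z @ P
    return a_filtered, P_filtered, y_hat, F
```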
@@ -435,8 +439,6 @@ def update(
            The current covariance matrix estimate, conditioned on information up to time t-1.
        y : TensorVariable
            The observation data at time t.
-        c : TensorVariable
-            The matrix c.
        d : TensorVariable
            The matrix d.
        Z : TensorVariable
@@ -529,7 +531,7 @@ def kalman_step(self, *args) -> tuple:
        y_masked, Z_masked, H_masked, all_nan_flag = self.handle_missing_values(y, Z, H)

        a_filtered, P_filtered, obs_mu, obs_cov, ll = self.update(
-            y=y_masked, a=a, c=c, d=d, P=P, Z=Z_masked, H=H_masked, all_nan_flag=all_nan_flag
+            y=y_masked, a=a, d=d, P=P, Z=Z_masked, H=H_masked, all_nan_flag=all_nan_flag
        )

        P_filtered = stabilize(P_filtered, self.cov_jitter)
@@ -545,7 +547,7 @@ class StandardFilter(BaseFilter):
    Basic Kalman Filter
    """

-    def update(self, a, P, y, c, d, Z, H, all_nan_flag):
+    def update(self, a, P, y, d, Z, H, all_nan_flag):
        """
        Compute one-step forecasts for observed states conditioned on information up to, but not including, the current
        timestep, `y_hat`, along with the forecast covariance matrix, `F`. Marginalize over observed states to obtain
@@ -566,9 +568,6 @@ def update(self, a, P, y, c, d, Z, H, all_nan_flag):
        y : TensorVariable
            Observations at time t.

-        c : TensorVariable
-            Latent state bias term.
-
        d : TensorVariable
            Observed state bias term.
@@ -628,38 +627,128 @@ class SquareRootFilter(BaseFilter):
    """

-    # TODO: Can the entire Kalman filter process be re-written, starting from P0_chol, so it's not necessary to compute
-    # cholesky(F) at every iteration?
+    def predict(self, a, P, c, T, R, Q):
+        """
+        Compute one-step forecasts for the hidden states conditioned on information up to, but not including, the current
+        timestep, `a_hat`, along with the forecast covariance matrix, `P_hat`.
+
+        .. warning::
+            Very important -- In this function, $P$ is the **cholesky factor** of the covariance matrix, not the
+            covariance matrix itself. The name `P` is kept for consistency with the superclass.
+        """
+        # Rename P to P_chol for clarity
+        P_chol = P
+
+        a_hat = T.dot(a) + c
+        Q_chol = pt.linalg.cholesky(Q, lower=True)
+
+        M = pt.horizontal_stack(T @ P_chol, R @ Q_chol).T
+        R_decomp = pt.linalg.qr(M, mode="r")
+        P_chol_hat = R_decomp[: self.n_states, : self.n_states].T
+
+        return a_hat, P_chol_hat
+
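The new `predict` propagates the Cholesky factor directly: stack `T @ L_P` next to `R @ L_Q`, transpose, and keep the R factor of a QR decomposition; its transpose is a triangular square root of `T P Tᵀ + R Q Rᵀ`, so the full covariance is never formed. A NumPy sanity check of that identity (random toy matrices, independent of the PyTensor code):

```python
import numpy as np

rng = np.random.default_rng(0)
n_states, n_shocks = 3, 2

T = rng.normal(size=(n_states, n_states))          # transition matrix
R = rng.normal(size=(n_states, n_shocks))          # shock selection matrix
P = 2.0 * np.eye(n_states)                         # current state covariance
Q = np.diag(rng.uniform(0.5, 1.5, size=n_shocks))  # shock covariance

L_P = np.linalg.cholesky(P)
L_Q = np.linalg.cholesky(Q)

# Stack the two square roots, QR-factorize the transpose, keep the R factor.
M = np.hstack([T @ L_P, R @ L_Q]).T
R_decomp = np.linalg.qr(M, mode="r")
P_chol_hat = R_decomp[:n_states, :n_states].T

# P_chol_hat is triangular and reproduces the predicted covariance T P Tᵀ + R Q Rᵀ.
assert np.allclose(P_chol_hat @ P_chol_hat.T, T @ P @ T.T + R @ Q @ R.T)
```

This is essentially what the removed TODO asked for: the scan now carries `P_chol` instead of `P`, and the only explicit Cholesky factorizations left are the ones of `Q` and `H`.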
+    def update(self, a, P, y, d, Z, H, all_nan_flag):
+        """
+        Compute posterior estimates of the hidden state distributions conditioned on the observed data, up to and
+        including the present timestep. Also compute the log-likelihood of the data given the one-step forecasts.
+
+        .. warning::
+            Very important -- In this function, $P$ is the **cholesky factor** of the covariance matrix, not the
+            covariance matrix itself. The name `P` is kept for consistency with the superclass.
+        """
+
+        # Rename P to P_chol for clarity
+        P_chol = P

-    def update(self, a, P, y, c, d, Z, H, all_nan_flag):
        y_hat = Z.dot(a) + d
        v = y - y_hat

-        PZT = P.dot(Z.T)
+        H_chol = pytensor.ifelse(pt.all(pt.eq(H, 0.0)), H, pt.linalg.cholesky(H, lower=True))
+
+        # The following notation comes from https://ipnpr.jpl.nasa.gov/progress_report/42-233/42-233A.pdf
+        # Construct the upper-triangular block matrix A = [[chol(H), Z @ L_pred],
+        #                                                  [0,       L_pred    ]]
+        # The R factor of the QR decomposition of A^T is upper triangular; its transpose B is
+        # the matrix we actually want:
+        # Structure of B = [[chol(F),     0               ],
+        #                   [K @ chol(F), chol(P_filtered)]]
+        zeros = pt.zeros((self.n_states, self.n_endog))
+        upper = pt.horizontal_stack(H_chol, Z @ P_chol)
+        lower = pt.horizontal_stack(zeros, P_chol)
+        A_T = pt.vertical_stack(upper, lower)
+        B = pt.linalg.qr(A_T.T, mode="r").T
+
+        F_chol = B[: self.n_endog, : self.n_endog]
+        K_F_chol = B[self.n_endog :, : self.n_endog]
+        P_chol_filtered = B[self.n_endog :, self.n_endog :]
+
+        def compute_non_degenerate(P_chol_filtered, F_chol, K_F_chol, v):
+            a_filtered = a + K_F_chol @ solve_triangular(F_chol, v, lower=True)
+
+            # Mahalanobis term v^T F^-1 v, computed with two triangular solves against chol(F)
+            inner_term = solve_triangular(
+                F_chol.T, solve_triangular(F_chol, v, lower=True), lower=False
+            )
+            loss = (v.T @ inner_term).ravel()
+
+            # abs necessary because we're not guaranteed a positive diagonal from the QR decomposition
+            logdet = 2 * pt.log(pt.abs(pt.diag(F_chol))).sum()
+
+            ll = -0.5 * (self.n_endog * MVN_CONST + logdet + loss)[0]
+
+            return [a_filtered, P_chol_filtered, ll]
+
+        def compute_degenerate(P_chol_filtered, F_chol, K_F_chol, v):
+            """
+            If F is zero (usually because there were no observations this period), then we want:
+            K = 0, a = a, P = P, ll = 0
+            """
+            return [a, P_chol, pt.zeros(())]
+
+        [a_filtered, P_chol_filtered, ll] = pytensor.ifelse(
+            pt.eq(all_nan_flag, 1.0),
+            compute_degenerate(P_chol_filtered, F_chol, K_F_chol, v),
+            compute_non_degenerate(P_chol_filtered, F_chol, K_F_chol, v),
+        )

-        # If everything is missing, F will be [[0]] and F_chol will raise an error, so add identity to avoid the error
-        F = Z.dot(PZT) + stabilize(H, self.cov_jitter)
-        F_chol = pt.linalg.cholesky(F)
+        a_filtered = pt.specify_shape(a_filtered, (self.n_states,))
+        P_chol_filtered = pt.specify_shape(P_chol_filtered, (self.n_states, self.n_states))

-        # If everything is missing, K = 0, IKZ = I
-        K = solve_triangular(F_chol.T, solve_triangular(F_chol, PZT.T)).T
-        I_KZ = pt.eye(self.n_states) - K.dot(Z)
+        return a_filtered, P_chol_filtered, y_hat, F_chol, ll

-        a_filtered = a + K.dot(v)
-        P_filtered = quad_form_sym(I_KZ, P) + quad_form_sym(K, H)
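The block construction above follows the array square-root formulation in the linked reference: QR-factorizing the transpose of `[[chol(H), Z @ L_P], [0, L_P]]` yields, in a single pass, `chol(F)`, `K @ chol(F)`, and the Cholesky factor of the filtered covariance. A NumPy sanity check of that block identity (toy matrices, independent of the PyTensor code):

```python
import numpy as np

rng = np.random.default_rng(1)
n_states, n_endog = 3, 2

Z = rng.normal(size=(n_endog, n_states))          # observation matrix
H = np.diag(rng.uniform(0.5, 1.5, size=n_endog))  # observation noise covariance
P = 2.0 * np.eye(n_states)                        # predicted state covariance
L_P, L_H = np.linalg.cholesky(P), np.linalg.cholesky(H)

# A = [[chol(H), Z @ L_P],
#      [0,       L_P    ]]; take the R factor of qr(A.T) and transpose it to get a
# block lower-triangular B with B @ B.T == A @ A.T.
A = np.block([[L_H, Z @ L_P], [np.zeros((n_states, n_endog)), L_P]])
B = np.linalg.qr(A.T, mode="r").T

F_chol = B[:n_endog, :n_endog]
K_F_chol = B[n_endog:, :n_endog]
P_chol_filtered = B[n_endog:, n_endog:]

F = Z @ P @ Z.T + H               # innovation covariance
K = P @ Z.T @ np.linalg.inv(F)    # Kalman gain

assert np.allclose(F_chol @ F_chol.T, F)
assert np.allclose(K_F_chol, K @ F_chol)
assert np.allclose(P_chol_filtered @ P_chol_filtered.T, P - K @ Z @ P)
```

The checks rely only on `B @ B.T == A @ A.T`, so they hold regardless of the sign convention the QR routine picks for the diagonal, which is also why the log-determinant in `compute_non_degenerate` takes `pt.abs` of `pt.diag(F_chol)`.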
+    def _postprocess_scan_results(self, results, a0, P0, n) -> list[TensorVariable]:
+        """
+        Convert the Cholesky factor of the covariance matrix back to the covariance matrix itself.
+        """
+        results = super()._postprocess_scan_results(results, a0, P0, n)
+        (
+            filtered_states,
+            predicted_states,
+            observed_states,
+            filtered_covariances_cholesky,
+            predicted_covariances_cholesky,
+            observed_covariances_cholesky,
+            loglike_obs,
+        ) = results

-        inner_term = solve_triangular(F_chol.T, solve_triangular(F_chol, v))
-        n = y.shape[0]
+        def square_sequence(L):
+            # Batched L @ L.T over the time dimension, with the static shape pinned down
+            X = pt.einsum("...ij,...kj->...ik", L, L.copy())
+            X = pt.specify_shape(X, (n, self.n_states, self.n_states))
+            return X

-        ll = pt.switch(
-            all_nan_flag,
-            0.0,
-            (
-                -0.5 * (n * MVN_CONST + (v.T @ inner_term).ravel()) - pt.log(pt.diag(F_chol)).sum()
-            ).ravel()[0],
-        )
+        filtered_covariances = square_sequence(filtered_covariances_cholesky)
+        predicted_covariances = square_sequence(predicted_covariances_cholesky)
+        observed_covariances = square_sequence(observed_covariances_cholesky)

-        return a_filtered, P_filtered, y_hat, F, ll
+        return [
+            filtered_states,
+            predicted_states,
+            observed_states,
+            filtered_covariances,
+            predicted_covariances,
+            observed_covariances,
+            loglike_obs,
+        ]
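`_postprocess_scan_results` converts the stacked Cholesky factors produced by the scan back into covariance matrices by computing `L @ Lᵀ` per timestep. The NumPy sketch below shows what the batched einsum does (toy shapes; the `.copy()` in the PyTensor version presumably just avoids handing the identical variable to `pt.einsum` twice):

```python
import numpy as np

rng = np.random.default_rng(2)
n, n_states = 5, 3

# A stack of lower-triangular Cholesky factors, one per timestep.
L = np.tril(rng.normal(size=(n, n_states, n_states)))

# "...ij,...kj->...ik" contracts over j, i.e. X[t] = L[t] @ L[t].T for every t.
X = np.einsum("...ij,...kj->...ik", L, L)

assert np.allclose(X, L @ np.swapaxes(L, -1, -2))
```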


class SingleTimeseriesFilter(BaseFilter):
@@ -679,7 +768,7 @@ def check_params(self, data, a0, P0, c, d, T, Z, R, H, Q):
        return data, a0, P0, c, d, T, Z, R, H, Q

-    def update(self, a, P, y, c, d, Z, H, all_nan_flag):
+    def update(self, a, P, y, d, Z, H, all_nan_flag):
        y_hat = d + Z.dot(a)
        v = y - y_hat.ravel()