88from pytensor .graph .basic import Variable
99from pytensor .raise_op import Assert
1010from pytensor .tensor import TensorVariable
11- from pytensor .tensor .nlinalg import matrix_dot
12- from pytensor .tensor .slinalg import solve_discrete_are , solve_triangular
11+ from pytensor .tensor .slinalg import solve_triangular
1312
1413from pymc_experimental .statespace .filters .utilities import (
1514 quad_form_sym ,
@@ -55,15 +54,6 @@ def __init__(self, mode=None):
5554 non_seq_names : list[str]
5655 A list of names representing static statespace matrices. That is, inputs that will need to be provided
5756 to the `non_sequences` argument of `pytensor.scan`
58-
59- eye_states : TensorVariable
60- An identity matrix of shape (k_states, k_states), stored for computational efficiency
61-
62- eye_posdef : TensorVariable
63- An identity matrix of shape (k_posdef, k_posdef), stored for computational efficiency
64-
65- eye_endog : TensorVariable
66- An identity matrix of shape (k_endog, k_endog), stored for computational efficiency
6757 """
6858
6959 self .mode : str = mode
@@ -74,44 +64,9 @@ def __init__(self, mode=None):
7464 self .n_posdef = None
7565 self .n_endog = None
7666
77- self .eye_states : TensorVariable | None = None
78- self .eye_posdef : TensorVariable | None = None
79- self .eye_endog : TensorVariable | None = None
8067 self .missing_fill_value : float | None = None
8168 self .cov_jitter = None
8269
83- def initialize_eyes (self , R : TensorVariable , Z : TensorVariable ) -> None :
84- """
85- Initialize identity matrices of the shapes repeatedly used in the Kalman filtering equations and store them.
86-
87- It's surprisingly expensive for pytensor to create an identity matrix every time we need one
88- (see [1] for benchmarks). This function creates some identity matrices of useful sizes for the model
89- to re-use as a small optimization.
90-
91- Parameters
92- ----------
93- R : TensorVariable
94- The tensor representing the selection matrix, called R in [2]
95-
96- Z : TensorVariable
97- The tensor representing the design matrix, called Z in [2].
98-
99- Returns
100- -------
101- None
102-
103- References
104- ----------
105- .. [1] https://gist.github.com/jessegrabowski/acd3235833163943a11654d78a72f04b
106- .. [2] Durbin, J., and S. J. Koopman. Time Series Analysis by State Space Methods.
107- 2nd ed, Oxford University Press, 2012.
108- """
109-
110- self .n_states , self .n_posdef , self .n_endog = R .shape [- 2 ], R .shape [- 1 ], Z .shape [- 2 ]
111- self .eye_states = pt .eye (self .n_states )
112- self .eye_posdef = pt .eye (self .n_posdef )
113- self .eye_endog = pt .eye (self .n_endog )
114-
11570 def check_params (self , data , a0 , P0 , c , d , T , Z , R , H , Q ):
11671 """
11772 Apply any checks on validity of inputs. For most filters this is just the identity function.
@@ -141,10 +96,10 @@ def add_check_on_time_varying_shapes(
14196 list[TensorVariable]
14297 A list of tensors wrapped in an `Assert` `Op` that checks the shape of the 0th dimension on each is equal
14398 to the shape of the 0th dimension on the data.
144-
145- # TODO: The PytensorRepresentation object puts the time dimension last, should the reshaping happen here in
146- the Kalman filter, or in the StateSpaceModel, before passing into the KF?
14799 """
100+ # TODO: The PytensorRepresentation object puts the time dimension last, should the reshaping happen here in
101+ # the Kalman filter, or in the StateSpaceModel, before passing into the KF?
102+
148103 params_with_assert = [
149104 assert_time_varying_dim_correct (param , pt .eq (param .shape [0 ], data .shape [0 ]))
150105 for param in sequence_params
@@ -166,7 +121,7 @@ def unpack_args(self, args) -> tuple:
166121 args = list (args )
167122 n_seq = len (self .seq_names )
168123 if n_seq == 0 :
169- return args
124+ return tuple ( args )
170125
171126 # The first arg is always y
172127 y = args .pop (0 )
@@ -202,7 +157,7 @@ def build_graph(
202157 return_updates = False ,
203158 missing_fill_value = None ,
204159 cov_jitter = None ,
205- ) -> list [TensorVariable ]:
160+ ) -> list [TensorVariable ] | tuple [ list [ TensorVariable ], dict ] :
206161 """
207162 Construct the computation graph for the Kalman filter. See [1] for details.
208163
@@ -246,9 +201,11 @@ def build_graph(
246201
247202 self .mode = mode
248203 self .missing_fill_value = missing_fill_value
249- self .initialize_eyes (R , Z )
250204 self .cov_jitter = cov_jitter
251205
206+ self .n_states , self .n_shocks = R .shape [- 2 :]
207+ self .n_endog = Z .shape [- 2 ]
208+
252209 data , a0 , P0 , * params = self .check_params (data , a0 , P0 , c , d , T , Z , R , H , Q )
253210
254211 sequences , non_sequences , seq_names , non_seq_names = split_vars_into_seq_and_nonseq (
@@ -643,7 +600,7 @@ def update(self, a, P, y, c, d, Z, H, all_nan_flag):
643600 F = Z .dot (PZT ) + stabilize (H , self .cov_jitter )
644601
645602 K = pt .linalg .solve (F .T , PZT .T , assume_a = "pos" , check_finite = False ).T
646- I_KZ = self .eye_states - K .dot (Z )
603+ I_KZ = pt . eye ( self .n_states ) - K .dot (Z )
647604
648605 a_filtered = a + K .dot (v )
649606 P_filtered = quad_form_sym (I_KZ , P ) + quad_form_sym (K , H )
@@ -662,7 +619,7 @@ def update(self, a, P, y, c, d, Z, H, all_nan_flag):
662619 return a_filtered , P_filtered , y_hat , F , ll
663620
664621
665- class CholeskyFilter (BaseFilter ):
622+ class SquareRootFilter (BaseFilter ):
666623 """
667624 Kalman filter with Cholesky factorization
668625
@@ -686,7 +643,7 @@ def update(self, a, P, y, c, d, Z, H, all_nan_flag):
686643
687644 # If everything is missing, K = 0, IKZ = I
688645 K = solve_triangular (F_chol .T , solve_triangular (F_chol , PZT .T )).T
689- I_KZ = self .eye_states - K .dot (Z )
646+ I_KZ = pt . eye ( self .n_states ) - K .dot (Z )
690647
691648 a_filtered = a + K .dot (v )
692649 P_filtered = quad_form_sym (I_KZ , P ) + quad_form_sym (K , H )
@@ -732,7 +689,7 @@ def update(self, a, P, y, c, d, Z, H, all_nan_flag):
732689 F = stabilize (Z .dot (PZT ) + H , self .cov_jitter ).ravel ()
733690
734691 K = PZT / F
735- I_KZ = self .eye_states - K .dot (Z )
692+ I_KZ = pt . eye ( self .n_states ) - K .dot (Z )
736693
737694 a_filtered = a + (K * v ).ravel ()
738695
@@ -743,123 +700,6 @@ def update(self, a, P, y, c, d, Z, H, all_nan_flag):
743700 return a_filtered , P_filtered , pt .atleast_1d (y_hat ), pt .atleast_2d (F ), ll
744701
745702
746- class SteadyStateFilter (BaseFilter ):
747- """
748- Kalman Filter using Steady State Covariance
749-
750- This filter avoids the need to invert the covariance matrix of innovations at each time step by solving the
751- Discrete Algebraic Riccati Equation associated with the filtering problem once and for all at initialization and
752- uses the resulting steady-state covariance matrix in each step.
753-
754- The innovation covariance matrix will always converge to the steady state value as T -> oo, so this filter will
755- only have differences from the standard approach in the early steps (T < 10?). A process of "learning" is lost.
756- """
757-
758- def build_graph (
759- self ,
760- data ,
761- a0 ,
762- P0 ,
763- c ,
764- d ,
765- T ,
766- Z ,
767- R ,
768- H ,
769- Q ,
770- mode = None ,
771- return_updates = False ,
772- missing_fill_value = None ,
773- cov_jitter = None ,
774- ) -> list [TensorVariable ]:
775- """
776- Need to override the base build_graph to compute F_inv once from the steady-state covariance and pass it to every step as a non-sequence.
777- """
778- if missing_fill_value is None :
779- missing_fill_value = MISSING_FILL
780- if cov_jitter is None :
781- cov_jitter = JITTER_DEFAULT
782-
783- self .mode = mode
784- self .missing_fill_value = missing_fill_value
785- self .cov_jitter = cov_jitter
786- self .initialize_eyes (R , Z )
787-
788- data , a0 , P0 , * params = self .check_params (data , a0 , P0 , c , d , T , Z , R , H , Q )
789- sequences , non_sequences , seq_names , non_seq_names = split_vars_into_seq_and_nonseq (
790- params , PARAM_NAMES
791- )
792- self .seq_names = seq_names
793- self .non_seq_names = non_seq_names
794- c , d , T , Z , R , H , Q = params
795-
796- if len (sequences ) > 0 :
797- assert ValueError (
798- "All system matrices must be time-invariant to use the SteadyStateFilter"
799- )
800-
801- P_steady = solve_discrete_are (T .T , Z .T , matrix_dot (R , Q , R .T ), H )
802- F = matrix_dot (Z , P_steady , Z .T ) + H
803- F_inv = pt .linalg .solve (F , pt .eye (F .shape [0 ]), assume_a = "pos" , check_finite = False )
804-
805- results , updates = pytensor .scan (
806- self .kalman_step ,
807- sequences = [data ],
808- outputs_info = [None , a0 , None , None , P_steady , None , None ],
809- non_sequences = [c , d , F_inv , T , Z , R , H , Q ],
810- name = "forward_kalman_pass" ,
811- mode = get_mode (self .mode ),
812- )
813-
814- return self ._postprocess_scan_results (results , a0 , P0 , n = data .shape [0 ])
815-
816- def update (self , a , P , c , d , F_inv , y , Z , H , all_nan_flag ):
817- y_hat = Z .dot (a ) + d
818- v = y - y_hat
819-
820- PZT = P .dot (Z .T )
821-
822- F = Z .dot (PZT ) + stabilize (H , self .cov_jitter )
823- K = PZT .dot (F_inv )
824-
825- I_KZ = self .eye_states - K .dot (Z )
826-
827- a_filtered = a + K .dot (v )
828- P_filtered = quad_form_sym (I_KZ , P ) + quad_form_sym (K , H )
829-
830- inner_term = matrix_dot (v .T , F_inv , v )
831- ll = pt .switch (
832- all_nan_flag ,
833- 0.0 ,
834- - 0.5 * (MVN_CONST + pt .log (pt .linalg .det (F )) + inner_term ).ravel ()[0 ],
835- )
836-
837- return a_filtered , P_filtered , y_hat , F , ll
838-
839- def kalman_step (self , y , a , P , c , d , F_inv , T , Z , R , H , Q ):
840- """
841- Need to override the base step to add an argument to self.update, passing F_inv at every step.
842- """
843-
844- y_masked , Z_masked , H_masked , all_nan_flag = self .handle_missing_values (y , Z , H )
845- a_filtered , P_filtered , obs_mu , obs_cov , ll = self .update (
846- y = y_masked ,
847- a = a ,
848- P = P ,
849- c = c ,
850- d = d ,
851- F_inv = F_inv ,
852- Z = Z_masked ,
853- H = H_masked ,
854- all_nan_flag = all_nan_flag ,
855- )
856-
857- P_filtered = stabilize (P_filtered , self .cov_jitter )
858- a_hat , P_hat = self .predict (a = a_filtered , P = P_filtered , c = c , T = T , R = R , Q = Q )
859-
860- return a_filtered , a_hat , obs_mu , P_filtered , P_hat , obs_cov , ll
861-
862-
863703class UnivariateFilter (BaseFilter ):
864704 """
865705 The univariate kalman filter, described in [1], section 6.4.2, avoids inversion of the F matrix, as well as two
0 commit comments