8
8
from pytensor .graph .basic import Variable
9
9
from pytensor .raise_op import Assert
10
10
from pytensor .tensor import TensorVariable
11
- from pytensor .tensor .nlinalg import matrix_dot
12
- from pytensor .tensor .slinalg import solve_discrete_are , solve_triangular
11
+ from pytensor .tensor .slinalg import solve_triangular
13
12
14
13
from pymc_experimental .statespace .filters .utilities import (
15
14
quad_form_sym ,
@@ -55,15 +54,6 @@ def __init__(self, mode=None):
55
54
non_seq_names : list[str]
56
55
A list of names representing static statespace matrices. That is, inputs that will need to be provided
57
56
to the `non_sequences` argument of `pytensor.scan`
58
-
59
- eye_states : TensorVariable
60
- An identity matrix of shape (k_states, k_states), stored for computational efficiency
61
-
62
- eye_posdef : TensorVariable
63
- An identity matrix of shape (k_posdef, k_posdef), stored for computational efficiency
64
-
65
- eye_endog : TensorVariable
66
- An identity matrix of shape (k_endog, k_endog), stored for computational efficiency
67
57
"""
68
58
69
59
self .mode : str = mode
@@ -74,44 +64,9 @@ def __init__(self, mode=None):
74
64
self .n_posdef = None
75
65
self .n_endog = None
76
66
77
- self .eye_states : TensorVariable | None = None
78
- self .eye_posdef : TensorVariable | None = None
79
- self .eye_endog : TensorVariable | None = None
80
67
self .missing_fill_value : float | None = None
81
68
self .cov_jitter = None
82
69
83
- def initialize_eyes (self , R : TensorVariable , Z : TensorVariable ) -> None :
84
- """
85
- Initialize identity matrices for of shapes repeated used in the kalman filtering equations and store them.
86
-
87
- It's surprisingly expensive for pytensor to create an identity matrix every time we need one
88
- (see [1] for benchmarks). This function creates some identity matrices of useful sizes for the model
89
- to re-use as a small optimization.
90
-
91
- Parameters
92
- ----------
93
- R : TensorVariable
94
- The tensor representing the selection matrix, called R in [2]
95
-
96
- Z : TensorVariable
97
- The tensor representing the design matrix, called Z in [2].
98
-
99
- Returns
100
- -------
101
- None
102
-
103
- References
104
- ----------
105
- .. [1] https://gist.github.com/jessegrabowski/acd3235833163943a11654d78a72f04b
106
- .. [2] Durbin, J., and S. J. Koopman. Time Series Analysis by State Space Methods.
107
- 2nd ed, Oxford University Press, 2012.
108
- """
109
-
110
- self .n_states , self .n_posdef , self .n_endog = R .shape [- 2 ], R .shape [- 1 ], Z .shape [- 2 ]
111
- self .eye_states = pt .eye (self .n_states )
112
- self .eye_posdef = pt .eye (self .n_posdef )
113
- self .eye_endog = pt .eye (self .n_endog )
114
-
115
70
def check_params (self , data , a0 , P0 , c , d , T , Z , R , H , Q ):
116
71
"""
117
72
Apply any checks on validity of inputs. For most filters this is just the identity function.
@@ -141,10 +96,10 @@ def add_check_on_time_varying_shapes(
141
96
list[TensorVariable]
142
97
A list of tensors wrapped in an `Assert` `Op` that checks the shape of the 0th dimension on each is equal
143
98
to the shape of the 0th dimension on the data.
144
-
145
- # TODO: The PytensorRepresentation object puts the time dimension last, should the reshaping happen here in
146
- the Kalman filter, or in the StateSpaceModel, before passing into the KF?
147
99
"""
100
+ # TODO: The PytensorRepresentation object puts the time dimension last, should the reshaping happen here in
101
+ # the Kalman filter, or in the StateSpaceModel, before passing into the KF?
102
+
148
103
params_with_assert = [
149
104
assert_time_varying_dim_correct (param , pt .eq (param .shape [0 ], data .shape [0 ]))
150
105
for param in sequence_params
@@ -166,7 +121,7 @@ def unpack_args(self, args) -> tuple:
166
121
args = list (args )
167
122
n_seq = len (self .seq_names )
168
123
if n_seq == 0 :
169
- return args
124
+ return tuple ( args )
170
125
171
126
# The first arg is always y
172
127
y = args .pop (0 )
@@ -202,7 +157,7 @@ def build_graph(
202
157
return_updates = False ,
203
158
missing_fill_value = None ,
204
159
cov_jitter = None ,
205
- ) -> list [TensorVariable ]:
160
+ ) -> list [TensorVariable ] | tuple [ list [ TensorVariable ], dict ] :
206
161
"""
207
162
Construct the computation graph for the Kalman filter. See [1] for details.
208
163
@@ -246,9 +201,11 @@ def build_graph(
246
201
247
202
self .mode = mode
248
203
self .missing_fill_value = missing_fill_value
249
- self .initialize_eyes (R , Z )
250
204
self .cov_jitter = cov_jitter
251
205
206
+ self .n_states , self .n_shocks = R .shape [- 2 :]
207
+ self .n_endog = Z .shape [- 2 ]
208
+
252
209
data , a0 , P0 , * params = self .check_params (data , a0 , P0 , c , d , T , Z , R , H , Q )
253
210
254
211
sequences , non_sequences , seq_names , non_seq_names = split_vars_into_seq_and_nonseq (
@@ -643,7 +600,7 @@ def update(self, a, P, y, c, d, Z, H, all_nan_flag):
643
600
F = Z .dot (PZT ) + stabilize (H , self .cov_jitter )
644
601
645
602
K = pt .linalg .solve (F .T , PZT .T , assume_a = "pos" , check_finite = False ).T
646
- I_KZ = self .eye_states - K .dot (Z )
603
+ I_KZ = pt . eye ( self .n_states ) - K .dot (Z )
647
604
648
605
a_filtered = a + K .dot (v )
649
606
P_filtered = quad_form_sym (I_KZ , P ) + quad_form_sym (K , H )
@@ -662,7 +619,7 @@ def update(self, a, P, y, c, d, Z, H, all_nan_flag):
662
619
return a_filtered , P_filtered , y_hat , F , ll
663
620
664
621
665
- class CholeskyFilter (BaseFilter ):
622
+ class SquareRootFilter (BaseFilter ):
666
623
"""
667
624
Kalman filter with Cholesky factorization
668
625
@@ -686,7 +643,7 @@ def update(self, a, P, y, c, d, Z, H, all_nan_flag):
686
643
687
644
# If everything is missing, K = 0, IKZ = I
688
645
K = solve_triangular (F_chol .T , solve_triangular (F_chol , PZT .T )).T
689
- I_KZ = self .eye_states - K .dot (Z )
646
+ I_KZ = pt . eye ( self .n_states ) - K .dot (Z )
690
647
691
648
a_filtered = a + K .dot (v )
692
649
P_filtered = quad_form_sym (I_KZ , P ) + quad_form_sym (K , H )
@@ -732,7 +689,7 @@ def update(self, a, P, y, c, d, Z, H, all_nan_flag):
732
689
F = stabilize (Z .dot (PZT ) + H , self .cov_jitter ).ravel ()
733
690
734
691
K = PZT / F
735
- I_KZ = self .eye_states - K .dot (Z )
692
+ I_KZ = pt . eye ( self .n_states ) - K .dot (Z )
736
693
737
694
a_filtered = a + (K * v ).ravel ()
738
695
@@ -743,123 +700,6 @@ def update(self, a, P, y, c, d, Z, H, all_nan_flag):
743
700
return a_filtered , P_filtered , pt .atleast_1d (y_hat ), pt .atleast_2d (F ), ll
744
701
745
702
746
- class SteadyStateFilter (BaseFilter ):
747
- """
748
- Kalman Filter using Steady State Covariance
749
-
750
- This filter avoids the need to invert the covariance matrix of innovations at each time step by solving the
751
- Discrete Algebraic Riccati Equation associated with the filtering problem once and for all at initialization and
752
- uses the resulting steady-state covariance matrix in each step.
753
-
754
- The innovation covariance matrix will always converge to the steady state value as T -> oo, so this filter will
755
- only have differences from the standard approach in the early steps (T < 10?). A process of "learning" is lost.
756
- """
757
-
758
- def build_graph (
759
- self ,
760
- data ,
761
- a0 ,
762
- P0 ,
763
- c ,
764
- d ,
765
- T ,
766
- Z ,
767
- R ,
768
- H ,
769
- Q ,
770
- mode = None ,
771
- return_updates = False ,
772
- missing_fill_value = None ,
773
- cov_jitter = None ,
774
- ) -> list [TensorVariable ]:
775
- """
776
- Need to override the base step to add an argument to self.update, passing F_inv at every step.
777
- """
778
- if missing_fill_value is None :
779
- missing_fill_value = MISSING_FILL
780
- if cov_jitter is None :
781
- cov_jitter = JITTER_DEFAULT
782
-
783
- self .mode = mode
784
- self .missing_fill_value = missing_fill_value
785
- self .cov_jitter = cov_jitter
786
- self .initialize_eyes (R , Z )
787
-
788
- data , a0 , P0 , * params = self .check_params (data , a0 , P0 , c , d , T , Z , R , H , Q )
789
- sequences , non_sequences , seq_names , non_seq_names = split_vars_into_seq_and_nonseq (
790
- params , PARAM_NAMES
791
- )
792
- self .seq_names = seq_names
793
- self .non_seq_names = non_seq_names
794
- c , d , T , Z , R , H , Q = params
795
-
796
- if len (sequences ) > 0 :
797
- assert ValueError (
798
- "All system matrices must be time-invariant to use the SteadyStateFilter"
799
- )
800
-
801
- P_steady = solve_discrete_are (T .T , Z .T , matrix_dot (R , Q , R .T ), H )
802
- F = matrix_dot (Z , P_steady , Z .T ) + H
803
- F_inv = pt .linalg .solve (F , pt .eye (F .shape [0 ]), assume_a = "pos" , check_finite = False )
804
-
805
- results , updates = pytensor .scan (
806
- self .kalman_step ,
807
- sequences = [data ],
808
- outputs_info = [None , a0 , None , None , P_steady , None , None ],
809
- non_sequences = [c , d , F_inv , T , Z , R , H , Q ],
810
- name = "forward_kalman_pass" ,
811
- mode = get_mode (self .mode ),
812
- )
813
-
814
- return self ._postprocess_scan_results (results , a0 , P0 , n = data .shape [0 ])
815
-
816
- def update (self , a , P , c , d , F_inv , y , Z , H , all_nan_flag ):
817
- y_hat = Z .dot (a ) + d
818
- v = y - y_hat
819
-
820
- PZT = P .dot (Z .T )
821
-
822
- F = Z .dot (PZT ) + stabilize (H , self .cov_jitter )
823
- K = PZT .dot (F_inv )
824
-
825
- I_KZ = self .eye_states - K .dot (Z )
826
-
827
- a_filtered = a + K .dot (v )
828
- P_filtered = quad_form_sym (I_KZ , P ) + quad_form_sym (K , H )
829
-
830
- inner_term = matrix_dot (v .T , F_inv , v )
831
- ll = pt .switch (
832
- all_nan_flag ,
833
- 0.0 ,
834
- - 0.5 * (MVN_CONST + pt .log (pt .linalg .det (F )) + inner_term ).ravel ()[0 ],
835
- )
836
-
837
- return a_filtered , P_filtered , y_hat , F , ll
838
-
839
- def kalman_step (self , y , a , P , c , d , F_inv , T , Z , R , H , Q ):
840
- """
841
- Need to override the base step to add an argument to self.update, passing F_inv at every step.
842
- """
843
-
844
- y_masked , Z_masked , H_masked , all_nan_flag = self .handle_missing_values (y , Z , H )
845
- a_filtered , P_filtered , obs_mu , obs_cov , ll = self .update (
846
- y = y_masked ,
847
- a = a ,
848
- P = P ,
849
- c = c ,
850
- d = d ,
851
- F_inv = F_inv ,
852
- Z = Z_masked ,
853
- H = H_masked ,
854
- all_nan_flag = all_nan_flag ,
855
- )
856
-
857
- P_filtered = stabilize (P_filtered , self .cov_jitter )
858
- a_hat , P_hat = self .predict (a = a_filtered , P = P_filtered , c = c , T = T , R = R , Q = Q )
859
-
860
- return a_filtered , a_hat , obs_mu , P_filtered , P_hat , obs_cov , ll
861
-
862
-
863
703
class UnivariateFilter (BaseFilter ):
864
704
"""
865
705
The univariate kalman filter, described in [1], section 6.4.2, avoids inversion of the F matrix, as well as two
0 commit comments