77
88from pymc_extras .statespace .filters import (
99 KalmanSmoother ,
10+ SquareRootFilter ,
1011 StandardFilter ,
1112 UnivariateFilter ,
1213)
3031RTOL = 1e-6 if floatX .endswith ("64" ) else 1e-3
3132
3233standard_inout = initialize_filter (StandardFilter ())
33- # cholesky_inout = initialize_filter(CholeskyFilter ())
34+ cholesky_inout = initialize_filter (SquareRootFilter ())
3435univariate_inout = initialize_filter (UnivariateFilter ())
3536
3637f_standard = pytensor .function (* standard_inout , on_unused_input = "ignore" )
37- # f_cholesky = pytensor.function(*cholesky_inout, on_unused_input="ignore")
38+ f_cholesky = pytensor .function (* cholesky_inout , on_unused_input = "ignore" )
3839f_univariate = pytensor .function (* univariate_inout , on_unused_input = "ignore" )
3940
40- filter_funcs = [f_standard , f_univariate ]
41+ filter_funcs = [f_standard , f_cholesky , f_univariate ]
4142
4243filter_names = [
4344 "StandardFilter" ,
45+ "CholeskyFilter" ,
4446 "UnivariateFilter" ,
4547]
4648
@@ -229,8 +231,8 @@ def test_last_smoother_is_last_filtered(filter_func, output_idx, rng):
229231@pytest .mark .skipif (floatX == "float32" , reason = "Tests are too sensitive for float32" )
230232def test_filters_match_statsmodel_output (filter_func , filter_name , n_missing , rng ):
231233 fit_sm_mod , [data , a0 , P0 , c , d , T , Z , R , H , Q ] = nile_test_test_helper (rng , n_missing )
232- # if filter_name == "CholeskyFilter":
233- # P0 = np.linalg.cholesky(P0)
234+ if filter_name == "CholeskyFilter" :
235+ P0 = np .linalg .cholesky (P0 )
234236 inputs = [data , a0 , P0 , c , d , T , Z , R , H , Q ]
235237 outputs = filter_func (* inputs )
236238
@@ -278,8 +280,8 @@ def test_all_covariance_matrices_are_PSD(filter_func, filter_name, n_missing, ob
278280 pytest .skip ("Univariate filter not stable at half precision without measurement error" )
279281
280282 fit_sm_mod , [data , a0 , P0 , c , d , T , Z , R , H , Q ] = nile_test_test_helper (rng , n_missing )
281- # if filter_name == "CholeskyFilter":
282- # P0 = np.linalg.cholesky(P0)
283+ if filter_name == "CholeskyFilter" :
284+ P0 = np .linalg .cholesky (P0 )
283285
284286 H *= int (obs_noise )
285287 inputs = [data , a0 , P0 , c , d , T , Z , R , H , Q ]
@@ -301,8 +303,8 @@ def test_all_covariance_matrices_are_PSD(filter_func, filter_name, n_missing, ob
301303
302304@pytest .mark .parametrize (
303305 "filter" ,
304- [StandardFilter ],
305- ids = ["standard" ],
306+ [StandardFilter , SquareRootFilter ],
307+ ids = ["standard" , "cholesky" ],
306308)
307309def test_kalman_filter_jax (filter ):
308310 pytest .importorskip ("jax" )
0 commit comments