Skip to content

Commit ca9b8b6

Browse files
Remove SteadyStateFilder
Rename `CholeskyFilter` to `SquareRootFilter` to match the literature
1 parent 4deeec6 commit ca9b8b6

File tree

4 files changed

+22
-187
lines changed

4 files changed

+22
-187
lines changed

pymc_experimental/statespace/core/statespace.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@
1818

1919
from pymc_experimental.statespace.core.representation import PytensorRepresentation
2020
from pymc_experimental.statespace.filters import (
21-
CholeskyFilter,
2221
KalmanSmoother,
2322
SingleTimeseriesFilter,
23+
SquareRootFilter,
2424
StandardFilter,
2525
SteadyStateFilter,
2626
UnivariateFilter,
@@ -55,7 +55,7 @@
5555
"univariate": UnivariateFilter,
5656
"steady_state": SteadyStateFilter,
5757
"single": SingleTimeseriesFilter,
58-
"cholesky": CholeskyFilter,
58+
"cholesky": SquareRootFilter,
5959
}
6060

6161

pymc_experimental/statespace/filters/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from pymc_experimental.statespace.filters.distributions import LinearGaussianStateSpace
22
from pymc_experimental.statespace.filters.kalman_filter import (
3-
CholeskyFilter,
43
SingleTimeseriesFilter,
4+
SquareRootFilter,
55
StandardFilter,
66
SteadyStateFilter,
77
UnivariateFilter,
@@ -14,6 +14,6 @@
1414
"SteadyStateFilter",
1515
"KalmanSmoother",
1616
"SingleTimeseriesFilter",
17-
"CholeskyFilter",
17+
"SquareRootFilter",
1818
"LinearGaussianStateSpace",
1919
]

pymc_experimental/statespace/filters/kalman_filter.py

Lines changed: 13 additions & 173 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@
88
from pytensor.graph.basic import Variable
99
from pytensor.raise_op import Assert
1010
from pytensor.tensor import TensorVariable
11-
from pytensor.tensor.nlinalg import matrix_dot
12-
from pytensor.tensor.slinalg import solve_discrete_are, solve_triangular
11+
from pytensor.tensor.slinalg import solve_triangular
1312

1413
from pymc_experimental.statespace.filters.utilities import (
1514
quad_form_sym,
@@ -55,15 +54,6 @@ def __init__(self, mode=None):
5554
non_seq_names : list[str]
5655
A list of names representing static statespace matrices. That is, inputs that will need to be provided
5756
to the `non_sequences` argument of `pytensor.scan`
58-
59-
eye_states : TensorVariable
60-
An identity matrix of shape (k_states, k_states), stored for computational efficiency
61-
62-
eye_posdef : TensorVariable
63-
An identity matrix of shape (k_posdef, k_posdef), stored for computational efficiency
64-
65-
eye_endog : TensorVariable
66-
An identity matrix of shape (k_endog, k_endog), stored for computational efficiency
6757
"""
6858

6959
self.mode: str = mode
@@ -74,44 +64,9 @@ def __init__(self, mode=None):
7464
self.n_posdef = None
7565
self.n_endog = None
7666

77-
self.eye_states: TensorVariable | None = None
78-
self.eye_posdef: TensorVariable | None = None
79-
self.eye_endog: TensorVariable | None = None
8067
self.missing_fill_value: float | None = None
8168
self.cov_jitter = None
8269

83-
def initialize_eyes(self, R: TensorVariable, Z: TensorVariable) -> None:
84-
"""
85-
Initialize identity matrices for of shapes repeated used in the kalman filtering equations and store them.
86-
87-
It's surprisingly expensive for pytensor to create an identity matrix every time we need one
88-
(see [1] for benchmarks). This function creates some identity matrices of useful sizes for the model
89-
to re-use as a small optimization.
90-
91-
Parameters
92-
----------
93-
R : TensorVariable
94-
The tensor representing the selection matrix, called R in [2]
95-
96-
Z : TensorVariable
97-
The tensor representing the design matrix, called Z in [2].
98-
99-
Returns
100-
-------
101-
None
102-
103-
References
104-
----------
105-
.. [1] https://gist.github.com/jessegrabowski/acd3235833163943a11654d78a72f04b
106-
.. [2] Durbin, J., and S. J. Koopman. Time Series Analysis by State Space Methods.
107-
2nd ed, Oxford University Press, 2012.
108-
"""
109-
110-
self.n_states, self.n_posdef, self.n_endog = R.shape[-2], R.shape[-1], Z.shape[-2]
111-
self.eye_states = pt.eye(self.n_states)
112-
self.eye_posdef = pt.eye(self.n_posdef)
113-
self.eye_endog = pt.eye(self.n_endog)
114-
11570
def check_params(self, data, a0, P0, c, d, T, Z, R, H, Q):
11671
"""
11772
Apply any checks on validity of inputs. For most filters this is just the identity function.
@@ -141,10 +96,10 @@ def add_check_on_time_varying_shapes(
14196
list[TensorVariable]
14297
A list of tensors wrapped in an `Assert` `Op` that checks the shape of the 0th dimension on each is equal
14398
to the shape of the 0th dimension on the data.
144-
145-
# TODO: The PytensorRepresentation object puts the time dimension last, should the reshaping happen here in
146-
the Kalman filter, or in the StateSpaceModel, before passing into the KF?
14799
"""
100+
# TODO: The PytensorRepresentation object puts the time dimension last, should the reshaping happen here in
101+
# the Kalman filter, or in the StateSpaceModel, before passing into the KF?
102+
148103
params_with_assert = [
149104
assert_time_varying_dim_correct(param, pt.eq(param.shape[0], data.shape[0]))
150105
for param in sequence_params
@@ -166,7 +121,7 @@ def unpack_args(self, args) -> tuple:
166121
args = list(args)
167122
n_seq = len(self.seq_names)
168123
if n_seq == 0:
169-
return args
124+
return tuple(args)
170125

171126
# The first arg is always y
172127
y = args.pop(0)
@@ -202,7 +157,7 @@ def build_graph(
202157
return_updates=False,
203158
missing_fill_value=None,
204159
cov_jitter=None,
205-
) -> list[TensorVariable]:
160+
) -> list[TensorVariable] | tuple[list[TensorVariable], dict]:
206161
"""
207162
Construct the computation graph for the Kalman filter. See [1] for details.
208163
@@ -246,9 +201,11 @@ def build_graph(
246201

247202
self.mode = mode
248203
self.missing_fill_value = missing_fill_value
249-
self.initialize_eyes(R, Z)
250204
self.cov_jitter = cov_jitter
251205

206+
self.n_states, self.n_shocks = R.shape[-2:]
207+
self.n_endog = Z.shape[-2]
208+
252209
data, a0, P0, *params = self.check_params(data, a0, P0, c, d, T, Z, R, H, Q)
253210

254211
sequences, non_sequences, seq_names, non_seq_names = split_vars_into_seq_and_nonseq(
@@ -643,7 +600,7 @@ def update(self, a, P, y, c, d, Z, H, all_nan_flag):
643600
F = Z.dot(PZT) + stabilize(H, self.cov_jitter)
644601

645602
K = pt.linalg.solve(F.T, PZT.T, assume_a="pos", check_finite=False).T
646-
I_KZ = self.eye_states - K.dot(Z)
603+
I_KZ = pt.eye(self.n_states) - K.dot(Z)
647604

648605
a_filtered = a + K.dot(v)
649606
P_filtered = quad_form_sym(I_KZ, P) + quad_form_sym(K, H)
@@ -662,7 +619,7 @@ def update(self, a, P, y, c, d, Z, H, all_nan_flag):
662619
return a_filtered, P_filtered, y_hat, F, ll
663620

664621

665-
class CholeskyFilter(BaseFilter):
622+
class SquareRootFilter(BaseFilter):
666623
"""
667624
Kalman filter with Cholesky factorization
668625
@@ -686,7 +643,7 @@ def update(self, a, P, y, c, d, Z, H, all_nan_flag):
686643

687644
# If everything is missing, K = 0, IKZ = I
688645
K = solve_triangular(F_chol.T, solve_triangular(F_chol, PZT.T)).T
689-
I_KZ = self.eye_states - K.dot(Z)
646+
I_KZ = pt.eye(self.n_states) - K.dot(Z)
690647

691648
a_filtered = a + K.dot(v)
692649
P_filtered = quad_form_sym(I_KZ, P) + quad_form_sym(K, H)
@@ -732,7 +689,7 @@ def update(self, a, P, y, c, d, Z, H, all_nan_flag):
732689
F = stabilize(Z.dot(PZT) + H, self.cov_jitter).ravel()
733690

734691
K = PZT / F
735-
I_KZ = self.eye_states - K.dot(Z)
692+
I_KZ = pt.eye(self.n_states) - K.dot(Z)
736693

737694
a_filtered = a + (K * v).ravel()
738695

@@ -743,123 +700,6 @@ def update(self, a, P, y, c, d, Z, H, all_nan_flag):
743700
return a_filtered, P_filtered, pt.atleast_1d(y_hat), pt.atleast_2d(F), ll
744701

745702

746-
class SteadyStateFilter(BaseFilter):
747-
"""
748-
Kalman Filter using Steady State Covariance
749-
750-
This filter avoids the need to invert the covariance matrix of innovations at each time step by solving the
751-
Discrete Algebraic Riccati Equation associated with the filtering problem once and for all at initialization and
752-
uses the resulting steady-state covariance matrix in each step.
753-
754-
The innovation covariance matrix will always converge to the steady state value as T -> oo, so this filter will
755-
only have differences from the standard approach in the early steps (T < 10?). A process of "learning" is lost.
756-
"""
757-
758-
def build_graph(
759-
self,
760-
data,
761-
a0,
762-
P0,
763-
c,
764-
d,
765-
T,
766-
Z,
767-
R,
768-
H,
769-
Q,
770-
mode=None,
771-
return_updates=False,
772-
missing_fill_value=None,
773-
cov_jitter=None,
774-
) -> list[TensorVariable]:
775-
"""
776-
Need to override the base step to add an argument to self.update, passing F_inv at every step.
777-
"""
778-
if missing_fill_value is None:
779-
missing_fill_value = MISSING_FILL
780-
if cov_jitter is None:
781-
cov_jitter = JITTER_DEFAULT
782-
783-
self.mode = mode
784-
self.missing_fill_value = missing_fill_value
785-
self.cov_jitter = cov_jitter
786-
self.initialize_eyes(R, Z)
787-
788-
data, a0, P0, *params = self.check_params(data, a0, P0, c, d, T, Z, R, H, Q)
789-
sequences, non_sequences, seq_names, non_seq_names = split_vars_into_seq_and_nonseq(
790-
params, PARAM_NAMES
791-
)
792-
self.seq_names = seq_names
793-
self.non_seq_names = non_seq_names
794-
c, d, T, Z, R, H, Q = params
795-
796-
if len(sequences) > 0:
797-
assert ValueError(
798-
"All system matrices must be time-invariant to use the SteadyStateFilter"
799-
)
800-
801-
P_steady = solve_discrete_are(T.T, Z.T, matrix_dot(R, Q, R.T), H)
802-
F = matrix_dot(Z, P_steady, Z.T) + H
803-
F_inv = pt.linalg.solve(F, pt.eye(F.shape[0]), assume_a="pos", check_finite=False)
804-
805-
results, updates = pytensor.scan(
806-
self.kalman_step,
807-
sequences=[data],
808-
outputs_info=[None, a0, None, None, P_steady, None, None],
809-
non_sequences=[c, d, F_inv, T, Z, R, H, Q],
810-
name="forward_kalman_pass",
811-
mode=get_mode(self.mode),
812-
)
813-
814-
return self._postprocess_scan_results(results, a0, P0, n=data.shape[0])
815-
816-
def update(self, a, P, c, d, F_inv, y, Z, H, all_nan_flag):
817-
y_hat = Z.dot(a) + d
818-
v = y - y_hat
819-
820-
PZT = P.dot(Z.T)
821-
822-
F = Z.dot(PZT) + stabilize(H, self.cov_jitter)
823-
K = PZT.dot(F_inv)
824-
825-
I_KZ = self.eye_states - K.dot(Z)
826-
827-
a_filtered = a + K.dot(v)
828-
P_filtered = quad_form_sym(I_KZ, P) + quad_form_sym(K, H)
829-
830-
inner_term = matrix_dot(v.T, F_inv, v)
831-
ll = pt.switch(
832-
all_nan_flag,
833-
0.0,
834-
-0.5 * (MVN_CONST + pt.log(pt.linalg.det(F)) + inner_term).ravel()[0],
835-
)
836-
837-
return a_filtered, P_filtered, y_hat, F, ll
838-
839-
def kalman_step(self, y, a, P, c, d, F_inv, T, Z, R, H, Q):
840-
"""
841-
Need to override the base step to add an argument to self.update, passing F_inv at every step.
842-
"""
843-
844-
y_masked, Z_masked, H_masked, all_nan_flag = self.handle_missing_values(y, Z, H)
845-
a_filtered, P_filtered, obs_mu, obs_cov, ll = self.update(
846-
y=y_masked,
847-
a=a,
848-
P=P,
849-
c=c,
850-
d=d,
851-
F_inv=F_inv,
852-
Z=Z_masked,
853-
H=H_masked,
854-
all_nan_flag=all_nan_flag,
855-
)
856-
857-
P_filtered = stabilize(P_filtered, self.cov_jitter)
858-
a_hat, P_hat = self.predict(a=a_filtered, P=P_filtered, c=c, T=T, R=R, Q=Q)
859-
860-
return a_filtered, a_hat, obs_mu, P_filtered, P_hat, obs_cov, ll
861-
862-
863703
class UnivariateFilter(BaseFilter):
864704
"""
865705
The univariate kalman filter, described in [1], section 6.4.2, avoids inversion of the F matrix, as well as two

tests/statespace/test_kalman_filter.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,10 @@
66
from numpy.testing import assert_allclose, assert_array_less
77

88
from pymc_experimental.statespace.filters import (
9-
CholeskyFilter,
109
KalmanSmoother,
1110
SingleTimeseriesFilter,
11+
SquareRootFilter,
1212
StandardFilter,
13-
SteadyStateFilter,
1413
UnivariateFilter,
1514
)
1615
from pymc_experimental.statespace.filters.kalman_filter import BaseFilter
@@ -33,25 +32,22 @@
3332
RTOL = 1e-6 if floatX.endswith("64") else 1e-3
3433

3534
standard_inout = initialize_filter(StandardFilter())
36-
cholesky_inout = initialize_filter(CholeskyFilter())
35+
cholesky_inout = initialize_filter(SquareRootFilter())
3736
univariate_inout = initialize_filter(UnivariateFilter())
3837
single_inout = initialize_filter(SingleTimeseriesFilter())
39-
steadystate_inout = initialize_filter(SteadyStateFilter())
4038

4139
f_standard = pytensor.function(*standard_inout, on_unused_input="ignore")
4240
f_cholesky = pytensor.function(*cholesky_inout, on_unused_input="ignore")
4341
f_univariate = pytensor.function(*univariate_inout, on_unused_input="ignore")
4442
f_single_ts = pytensor.function(*single_inout, on_unused_input="ignore")
45-
f_steady = pytensor.function(*steadystate_inout, on_unused_input="ignore")
4643

47-
filter_funcs = [f_standard, f_cholesky, f_univariate, f_single_ts, f_steady]
44+
filter_funcs = [f_standard, f_cholesky, f_univariate, f_single_ts]
4845

4946
filter_names = [
5047
"StandardFilter",
5148
"CholeskyFilter",
5249
"UnivariateFilter",
5350
"SingleTimeSeriesFilter",
54-
"SteadyStateFilter",
5551
]
5652

5753
output_names = [
@@ -247,8 +243,7 @@ def test_last_smoother_is_last_filtered(filter_func, output_idx, rng):
247243
assert_allclose(filtered[-1], smoothed[-1])
248244

249245

250-
# TODO: These tests omit the SteadyStateFilter, because it gives different results to StatsModels (reason to dump it?)
251-
@pytest.mark.parametrize("filter_func", filter_funcs[:-1], ids=filter_names[:-1])
246+
@pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names)
252247
@pytest.mark.parametrize("n_missing", [0, 5], ids=["n_missing=0", "n_missing=5"])
253248
@pytest.mark.skipif(floatX == "float32", reason="Tests are too sensitive for float32")
254249
def test_filters_match_statsmodel_output(filter_func, n_missing, rng):
@@ -320,7 +315,7 @@ def test_all_covariance_matrices_are_PSD(filter_func, filter_name, n_missing, ob
320315

321316
@pytest.mark.parametrize(
322317
"filter",
323-
[StandardFilter, SingleTimeseriesFilter, CholeskyFilter],
318+
[StandardFilter, SingleTimeseriesFilter, SquareRootFilter],
324319
ids=["standard", "single_ts", "cholesky"],
325320
)
326321
def test_kalman_filter_jax(filter):

0 commit comments

Comments
 (0)