Commit c3bc365 (parent: aedd3a7)

Use square root filter equations in SquareRootFilter

3 files changed: 147 additions, 59 deletions

pymc_experimental/statespace/filters/kalman_filter.py (123 additions, 34 deletions)
@@ -4,6 +4,7 @@
 import pytensor
 import pytensor.tensor as pt
 
+from pymc.pytensorf import constant_fold
 from pytensor.compile.mode import get_mode
 from pytensor.graph.basic import Variable
 from pytensor.raise_op import Assert
@@ -203,8 +204,11 @@ def build_graph(
         self.missing_fill_value = missing_fill_value
         self.cov_jitter = cov_jitter
 
-        self.n_states, self.n_shocks = R.shape[-2:]
-        self.n_endog = Z.shape[-2]
+        [R_shape] = constant_fold([R.shape], raise_not_constant=False)
+        [Z_shape] = constant_fold([Z.shape], raise_not_constant=False)
+
+        self.n_states, self.n_shocks = R_shape[-2:]
+        self.n_endog = Z_shape[-2]
 
         data, a0, P0, *params = self.check_params(data, a0, P0, c, d, T, Z, R, H, Q)
 
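Why fold here: when a model is built with statically shaped design matrices, `R.shape` and `Z.shape` can be evaluated at graph-construction time, so `n_states`, `n_shocks`, and `n_endog` become concrete integers, which the SquareRootFilter further down needs for slicing QR results. A minimal sketch of the behaviour, assuming a recent pymc (the `R` here is a stand-in, not the statespace graph):

```python
import pytensor.tensor as pt
from pymc.pytensorf import constant_fold

# Static shape: the symbolic .shape folds to constants at graph-build time
R = pt.tensor(name="R", dtype="float64", shape=(5, 2))
[R_shape] = constant_fold([R.shape], raise_not_constant=False)
n_states, n_shocks = R_shape[-2:]  # concrete 5 and 2, not graph nodes

# raise_not_constant=False degrades gracefully for dynamic inputs:
# the fold returns the still-symbolic shape instead of raising
R_dyn = pt.matrix("R_dyn")
[R_dyn_shape] = constant_fold([R_dyn.shape], raise_not_constant=False)
```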
@@ -408,7 +412,7 @@ def predict(a, P, c, T, R, Q) -> tuple[TensorVariable, TensorVariable]:
 
     @staticmethod
     def update(
-        a, P, y, c, d, Z, H, all_nan_flag
+        a, P, y, d, Z, H, all_nan_flag
     ) -> tuple[TensorVariable, TensorVariable, TensorVariable, TensorVariable, TensorVariable]:
         """
         Perform the update step of the Kalman filter.
@@ -419,7 +423,7 @@ def update(
 
         .. math::
 
            \begin{align}
-                \\hat{y}_t &= Z_t a_{t | t-1} \\
+                \\hat{y}_t &= Z_t a_{t | t-1} + d_t \\
                 v_t &= y_t - \\hat{y}_t \\
                 F_t &= Z_t P_{t | t-1} Z_t^T + H_t \\
                 a_{t|t} &= a_{t | t-1} + P_{t | t-1} Z_t^T F_t^{-1} v_t \\
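For reference, the docstring's align block is describing the standard Kalman update; with the gain written out explicitly (and the `d_t` bias this commit adds), the usual form of the full recursion is:

```latex
\begin{align}
\hat{y}_t &= Z_t a_{t \mid t-1} + d_t \\
v_t &= y_t - \hat{y}_t \\
F_t &= Z_t P_{t \mid t-1} Z_t^T + H_t \\
K_t &= P_{t \mid t-1} Z_t^T F_t^{-1} \\
a_{t \mid t} &= a_{t \mid t-1} + K_t v_t \\
P_{t \mid t} &= P_{t \mid t-1} - K_t F_t K_t^T
\end{align}
```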
@@ -435,8 +439,6 @@ def update(
             The current covariance matrix estimate, conditioned on information up to time t-1.
         y : TensorVariable
             The observation data at time t.
-        c : TensorVariable
-            The matrix c.
         d : TensorVariable
             The matrix d.
         Z : TensorVariable
@@ -529,7 +531,7 @@ def kalman_step(self, *args) -> tuple:
         y_masked, Z_masked, H_masked, all_nan_flag = self.handle_missing_values(y, Z, H)
 
         a_filtered, P_filtered, obs_mu, obs_cov, ll = self.update(
-            y=y_masked, a=a, c=c, d=d, P=P, Z=Z_masked, H=H_masked, all_nan_flag=all_nan_flag
+            y=y_masked, a=a, d=d, P=P, Z=Z_masked, H=H_masked, all_nan_flag=all_nan_flag
         )
 
         P_filtered = stabilize(P_filtered, self.cov_jitter)
@@ -545,7 +547,7 @@ class StandardFilter(BaseFilter):
     Basic Kalman Filter
     """
 
-    def update(self, a, P, y, c, d, Z, H, all_nan_flag):
+    def update(self, a, P, y, d, Z, H, all_nan_flag):
         """
         Compute one-step forecasts for observed states conditioned on information up to, but not including, the current
         timestep, `y_hat`, along with the forecast covariance matrix, `F`. Marginalize over observed states to obtain
@@ -566,9 +568,6 @@ def update(self, a, P, y, c, d, Z, H, all_nan_flag):
         y : TensorVariable
             Observations at time t.
 
-        c : TensorVariable
-            Latent state bias term.
-
         d : TensorVariable
             Observed state bias term.
 
@@ -628,38 +627,128 @@ class SquareRootFilter(BaseFilter):
 
     """
 
-    # TODO: Can the entire Kalman filter process be re-written, starting from P0_chol, so it's not necessary to compute
-    # cholesky(F) at every iteration?
+    def predict(self, a, P, c, T, R, Q):
+        """
+        Compute one-step forecasts for the hidden states conditioned on information up to, but not including,
+        the current timestep, `a_hat`, along with the forecast covariance matrix, `P_hat`.
+
+        .. warning::
+            Very important -- In this function, $P$ is the **cholesky factor** of the covariance matrix, not the
+            covariance matrix itself. The name `P` is kept for consistency with the superclass.
+        """
+        # Rename P to P_chol for clarity
+        P_chol = P
+
+        a_hat = T.dot(a) + c
+        Q_chol = pt.linalg.cholesky(Q, lower=True)
+
+        M = pt.horizontal_stack(T @ P_chol, R @ Q_chol).T
+        R_decomp = pt.linalg.qr(M, mode="r")
+        P_chol_hat = R_decomp[: self.n_states, : self.n_states].T
+
+        return a_hat, P_chol_hat
+
+    def update(self, a, P, y, d, Z, H, all_nan_flag):
+        """
+        Compute posterior estimates of the hidden state distributions conditioned on the observed data, up to and
+        including the present timestep. Also compute the log-likelihood of the data given the one-step forecasts.
+
+        .. warning::
+            Very important -- In this function, $P$ is the **cholesky factor** of the covariance matrix, not the
+            covariance matrix itself. The name `P` is kept for consistency with the superclass.
+        """
+
+        # Rename P to P_chol for clarity
+        P_chol = P
 
-    def update(self, a, P, y, c, d, Z, H, all_nan_flag):
         y_hat = Z.dot(a) + d
         v = y - y_hat
 
-        PZT = P.dot(Z.T)
+        H_chol = pytensor.ifelse(pt.all(pt.eq(H, 0.0)), H, pt.linalg.cholesky(H, lower=True))
+
+        # The following notation comes from https://ipnpr.jpl.nasa.gov/progress_report/42-233/42-233A.pdf
+        # Construct the block matrix A^T = [[chol(H), Z @ L_pred],
+        #                                   [0,       L_pred    ]]
+        # The R factor of the QR decomposition of A is upper triangular; we are more
+        # interested in its transpose B:
+        # Structure of B = [[chol(F),     0               ],
+        #                   [K @ chol(F), chol(P_filtered)]]
+        zeros = pt.zeros((self.n_states, self.n_endog))
+        upper = pt.horizontal_stack(H_chol, Z @ P_chol)
+        lower = pt.horizontal_stack(zeros, P_chol)
+        A_T = pt.vertical_stack(upper, lower)
+        B = pt.linalg.qr(A_T.T, mode="r").T
+
+        F_chol = B[: self.n_endog, : self.n_endog]
+        K_F_chol = B[self.n_endog :, : self.n_endog]
+        P_chol_filtered = B[self.n_endog :, self.n_endog :]
+
+        def compute_non_degenerate(P_chol_filtered, F_chol, K_F_chol, v):
+            a_filtered = a + K_F_chol @ solve_triangular(F_chol, v, lower=True)
+
+            # F^{-1} v = L^{-T} (L^{-1} v), with L = chol(F)
+            inner_term = solve_triangular(
+                F_chol.T, solve_triangular(F_chol, v, lower=True), lower=False
+            )
+            loss = (v.T @ inner_term).ravel()
+
+            # abs is necessary because the QR decomposition does not guarantee a positive diagonal
+            logdet = 2 * pt.log(pt.abs(pt.diag(F_chol))).sum()
+
+            ll = -0.5 * (self.n_endog * MVN_CONST + logdet + loss)[0]
+
+            return [a_filtered, P_chol_filtered, ll]
+
+        def compute_degenerate(P_chol_filtered, F_chol, K_F_chol, v):
+            """
+            If F is zero (usually because there were no observations this period), then we want:
+            K = 0, a = a, P = P, ll = 0
+            """
+            return [a, P_chol, pt.zeros(())]
+
+        [a_filtered, P_chol_filtered, ll] = pytensor.ifelse(
+            pt.eq(all_nan_flag, 1.0),
+            compute_degenerate(P_chol_filtered, F_chol, K_F_chol, v),
+            compute_non_degenerate(P_chol_filtered, F_chol, K_F_chol, v),
+        )
 
-        # If everything is missing, F will be [[0]] and F_chol will raise an error, so add identity to avoid the error
-        F = Z.dot(PZT) + stabilize(H, self.cov_jitter)
-        F_chol = pt.linalg.cholesky(F)
+        a_filtered = pt.specify_shape(a_filtered, (self.n_states,))
+        P_chol_filtered = pt.specify_shape(P_chol_filtered, (self.n_states, self.n_states))
 
-        # If everything is missing, K = 0, IKZ = I
-        K = solve_triangular(F_chol.T, solve_triangular(F_chol, PZT.T)).T
-        I_KZ = pt.eye(self.n_states) - K.dot(Z)
+        return a_filtered, P_chol_filtered, y_hat, F_chol, ll
 
-        a_filtered = a + K.dot(v)
-        P_filtered = quad_form_sym(I_KZ, P) + quad_form_sym(K, H)
+    def _postprocess_scan_results(self, results, a0, P0, n) -> list[TensorVariable]:
+        """
+        Convert the Cholesky factors of the covariance matrices back to the covariance matrices themselves.
+        """
+        results = super()._postprocess_scan_results(results, a0, P0, n)
+        (
+            filtered_states,
+            predicted_states,
+            observed_states,
+            filtered_covariances_cholesky,
+            predicted_covariances_cholesky,
+            observed_covariances_cholesky,
+            loglike_obs,
+        ) = results
 
-        inner_term = solve_triangular(F_chol.T, solve_triangular(F_chol, v))
-        n = y.shape[0]
+        def square_sequence(L):
+            X = pt.einsum("...ij,...kj->...ik", L, L.copy())
+            X = pt.specify_shape(X, (n, self.n_states, self.n_states))
+            return X
 
-        ll = pt.switch(
-            all_nan_flag,
-            0.0,
-            (
-                -0.5 * (n * MVN_CONST + (v.T @ inner_term).ravel()) - pt.log(pt.diag(F_chol)).sum()
-            ).ravel()[0],
-        )
+        filtered_covariances = square_sequence(filtered_covariances_cholesky)
+        predicted_covariances = square_sequence(predicted_covariances_cholesky)
+        observed_covariances = square_sequence(observed_covariances_cholesky)
 
-        return a_filtered, P_filtered, y_hat, F, ll
+        return [
+            filtered_states,
+            predicted_states,
+            observed_states,
+            filtered_covariances,
+            predicted_covariances,
+            observed_covariances,
+            loglike_obs,
+        ]
 
 
 class SingleTimeseriesFilter(BaseFilter):
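The `predict` method above leans on a standard square-root identity: if `M` stacks the transposed, propagated factors `T @ L_P` and `R @ L_Q`, then `M.T @ M = T P T^T + R Q R^T`, so the transposed R factor of a QR decomposition of `M` is a valid Cholesky factor of the predicted covariance, and `P` itself is never formed. A NumPy check of that identity (names are illustrative, not from the codebase):

```python
import numpy as np

rng = np.random.default_rng(0)
n_states, n_shocks = 4, 2

def random_spd(k):
    # A well-conditioned symmetric positive-definite matrix
    A = rng.normal(size=(k, k))
    return A @ A.T + k * np.eye(k)

T = rng.normal(size=(n_states, n_states))
R = rng.normal(size=(n_states, n_shocks))
P, Q = random_spd(n_states), random_spd(n_shocks)
L_P, L_Q = np.linalg.cholesky(P), np.linalg.cholesky(Q)

# Stack the propagated factors and triangularize with one QR call
M = np.hstack([T @ L_P, R @ L_Q]).T        # (n_states + n_shocks, n_states)
R_decomp = np.linalg.qr(M, mode="r")       # upper triangular
L_pred = R_decomp[:n_states, :n_states].T  # candidate chol(P_pred)

# L_pred squares back to the textbook predicted covariance
assert np.allclose(L_pred @ L_pred.T, T @ P @ T.T + R @ Q @ R.T)
```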
@@ -679,7 +768,7 @@ def check_params(self, data, a0, P0, c, d, T, Z, R, H, Q):
 
         return data, a0, P0, c, d, T, Z, R, H, Q
 
-    def update(self, a, P, y, c, d, Z, H, all_nan_flag):
+    def update(self, a, P, y, d, Z, H, all_nan_flag):
         y_hat = d + Z.dot(a)
         v = y - y_hat.ravel()
 
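The same trick drives `SquareRootFilter.update` above: a single QR call on the block pre-array produces `chol(F)`, the Kalman gain premultiplied by `chol(F)`, and `chol(P_filtered)` all at once. A NumPy sketch verifying the block structure the code comments describe (again with illustrative names):

```python
import numpy as np

rng = np.random.default_rng(1)
n_states, n_endog = 4, 2

def random_spd(k):
    A = rng.normal(size=(k, k))
    return A @ A.T + k * np.eye(k)

Z = rng.normal(size=(n_endog, n_states))
P, H = random_spd(n_states), random_spd(n_endog)
L_pred, H_chol = np.linalg.cholesky(P), np.linalg.cholesky(H)

# Pre-array: one QR factorization triangularizes the whole measurement update
A_T = np.block([
    [H_chol, Z @ L_pred],
    [np.zeros((n_states, n_endog)), L_pred],
])
B = np.linalg.qr(A_T.T, mode="r").T  # lower-triangular post-array

F_chol = B[:n_endog, :n_endog]
K_F_chol = B[n_endog:, :n_endog]
L_filtered = B[n_endog:, n_endog:]

# Compare against the covariance-form quantities
F = Z @ P @ Z.T + H
K = P @ Z.T @ np.linalg.inv(F)
assert np.allclose(F_chol @ F_chol.T, F)
assert np.allclose(K_F_chol, K @ F_chol)
assert np.allclose(L_filtered @ L_filtered.T, P - K @ F @ K.T)
```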
tests/statespace/test_kalman_filter.py (13 additions, 14 deletions)
@@ -64,7 +64,7 @@
 
 def test_base_class_update_raises():
     filter = BaseFilter()
-    inputs = [None] * 8
+    inputs = [None] * 7
     with pytest.raises(NotImplementedError):
         filter.update(*inputs)
 
@@ -214,6 +214,7 @@ def test_output_with_multiple_observed(filter_func, filter_name, rng):
 def test_missing_data(filter_func, filter_name, p, rng):
     m, r, n = 5, 1, 10
     inputs = make_test_inputs(p, m, r, n, rng, missing_data=1)
+
     if p > 1 and filter_name == "SingleTimeSeriesFilter":
         with pytest.raises(
             AssertionError,
@@ -243,11 +244,16 @@ def test_last_smoother_is_last_filtered(filter_func, output_idx, rng):
     assert_allclose(filtered[-1], smoothed[-1])
 
 
-@pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names)
+@pytest.mark.parametrize(
+    "filter_func, filter_name", zip(filter_funcs, filter_names), ids=filter_names
+)
 @pytest.mark.parametrize("n_missing", [0, 5], ids=["n_missing=0", "n_missing=5"])
 @pytest.mark.skipif(floatX == "float32", reason="Tests are too sensitive for float32")
-def test_filters_match_statsmodel_output(filter_func, n_missing, rng):
-    fit_sm_mod, inputs = nile_test_test_helper(rng, n_missing)
+def test_filters_match_statsmodel_output(filter_func, filter_name, n_missing, rng):
+    fit_sm_mod, [data, a0, P0, c, d, T, Z, R, H, Q] = nile_test_test_helper(rng, n_missing)
+    if filter_name == "CholeskyFilter":
+        P0 = np.linalg.cholesky(P0)
+    inputs = [data, a0, P0, c, d, T, Z, R, H, Q]
     outputs = filter_func(*inputs)
 
     for output_idx, name in enumerate(output_names):
@@ -294,6 +300,8 @@ def test_all_covariance_matrices_are_PSD(filter_func, filter_name, n_missing, ob
         pytest.skip("Univariate filter not stable at half precision without measurement error")
 
     fit_sm_mod, [data, a0, P0, c, d, T, Z, R, H, Q] = nile_test_test_helper(rng, n_missing)
+    if filter_name == "CholeskyFilter":
+        P0 = np.linalg.cholesky(P0)
 
     H *= int(obs_noise)
     inputs = [data, a0, P0, c, d, T, Z, R, H, Q]
@@ -325,16 +333,7 @@ def test_kalman_filter_jax(filter):
     # TODO: Add UnivariateFilter to test; need to figure out the broadcasting issue when 2nd data dim is defined
 
     p, m, r, n = 1, 5, 1, 10
-    inputs, outputs = initialize_filter(filter(), mode="JAX")
-
-    # Shape of the data must be static for jax to know how long the scan is
-    data = inputs.pop(0)
-    data_specified = pt.specify_shape(data, (n, None))
-    data_specified.name = "data"
-    inputs = [data, *inputs]
-
-    outputs = pytensor.graph.clone_replace(outputs, {data: data_specified})
-
+    inputs, outputs = initialize_filter(filter(), mode="JAX", p=p, m=m, r=r, n=n)
     inputs_np = make_test_inputs(p, m, r, n, rng)
 
     f_jax = get_jaxified_graph(inputs, outputs)
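The deleted lines remain a useful pattern whenever a helper cannot take static shapes: JAX needs the scan length to be known at trace time, so the data variable's shape is pinned with `specify_shape` and the graph is rebuilt against the pinned variable with `clone_replace`. A rough standalone sketch, assuming JAX is installed (the toy scan here stands in for the filter graph):

```python
import pytensor
import pytensor.tensor as pt

# A scan whose length is the leading dimension of `data`
data = pt.matrix("data")
rows_summed, _ = pytensor.scan(
    lambda row, acc: acc + row.sum(),
    sequences=[data],
    outputs_info=[pt.zeros(())],
)
total = rows_summed[-1]

# Pin the number of rows so the JAX backend knows how long the scan is
data_static = pt.specify_shape(data, (10, None))
[total_static] = pytensor.graph.clone_replace([total], {data: data_static})

fn = pytensor.function([data], total_static, mode="JAX")
```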

tests/statespace/utilities/test_helpers.py (11 additions, 11 deletions)
@@ -34,18 +34,18 @@ def load_nile_test_data():
     return nile
 
 
-def initialize_filter(kfilter, mode=None):
+def initialize_filter(kfilter, mode=None, p=None, m=None, r=None, n=None):
     ksmoother = KalmanSmoother()
-    data = pt.matrix(name="data", dtype=floatX)
-    a0 = pt.vector(name="a0", dtype=floatX)
-    P0 = pt.matrix(name="P0", dtype=floatX)
-    c = pt.vector(name="c", dtype=floatX)
-    d = pt.vector(name="d", dtype=floatX)
-    Q = pt.matrix(name="Q", dtype=floatX)
-    H = pt.matrix(name="H", dtype=floatX)
-    T = pt.matrix(name="T", dtype=floatX)
-    R = pt.matrix(name="R", dtype=floatX)
-    Z = pt.matrix(name="Z", dtype=floatX)
+    data = pt.tensor(name="data", dtype=floatX, shape=(n, p))
+    a0 = pt.tensor(name="x0", dtype=floatX, shape=(m,))
+    P0 = pt.tensor(name="P0", dtype=floatX, shape=(m, m))
+    c = pt.tensor(name="c", dtype=floatX, shape=(m,))
+    d = pt.tensor(name="d", dtype=floatX, shape=(p,))
+    Q = pt.tensor(name="Q", dtype=floatX, shape=(r, r))
+    H = pt.tensor(name="H", dtype=floatX, shape=(p, p))
+    T = pt.tensor(name="T", dtype=floatX, shape=(m, m))
+    R = pt.tensor(name="R", dtype=floatX, shape=(m, r))
+    Z = pt.tensor(name="Z", dtype=floatX, shape=(p, m))
 
     inputs = [data, a0, P0, c, d, T, Z, R, H, Q]
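What the switch from `pt.matrix`/`pt.vector` to `pt.tensor(..., shape=...)` buys: the static shape lives on the variable's type, where `constant_fold` and the JAX backend can see it. Entries of `None` (the default for the new keyword arguments) keep a dimension dynamic, so callers of `initialize_filter` that pass nothing get the old fully dynamic behaviour. A small sketch:

```python
import pytensor.tensor as pt

Z_static = pt.tensor(name="Z", dtype="float64", shape=(1, 5))
print(Z_static.type.shape)   # (1, 5) -- known at graph-construction time

Z_dynamic = pt.matrix("Z", dtype="float64")
print(Z_dynamic.type.shape)  # (None, None) -- nothing known statically

# None entries stay dynamic and mix freely with static ones
data = pt.tensor(name="data", dtype="float64", shape=(None, 1))
print(data.type.shape)       # (None, 1)
```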