Skip to content

Commit 51d5f5a

Browse files
committed
Working batched Kalman filter and smoother
1 parent ca71070 commit 51d5f5a

File tree

6 files changed

+1102
-80
lines changed

6 files changed

+1102
-80
lines changed

notebooks/batch-examples.ipynb

Lines changed: 853 additions & 3 deletions
Large diffs are not rendered by default.

pymc_extras/statespace/filters/kalman_filter.py

Lines changed: 89 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from abc import ABC
2+
from functools import partial
23

34
import numpy as np
45
import pytensor
@@ -9,14 +10,13 @@
910
from pytensor.raise_op import Assert
1011
from pytensor.tensor import TensorVariable
1112
from pytensor.tensor.slinalg import solve_triangular
12-
from pytensor.graph.replace import vectorize_graph
1313

1414
from pymc_extras.statespace.filters.utilities import (
1515
quad_form_sym,
1616
split_vars_into_seq_and_nonseq,
1717
stabilize,
1818
)
19-
from pymc_extras.statespace.utils.constants import JITTER_DEFAULT, MISSING_FILL
19+
from pymc_extras.statespace.utils.constants import JITTER_DEFAULT, MISSING_FILL, ALL_KF_OUTPUT_NAMES
2020

2121
MVN_CONST = pt.log(2 * pt.constant(np.pi, dtype="float64"))
2222
PARAM_NAMES = ["c", "d", "T", "Z", "R", "H", "Q"]
@@ -65,22 +65,56 @@ def check_params(self, data, a0, P0, c, d, T, Z, R, H, Q):
6565
"""
6666
return data, a0, P0, c, d, T, Z, R, H, Q
6767

68-
def has_batched_input(self, data, a0, P0, c, d, T, Z, R, H, Q):
69-
"""
70-
Check if any of the inputs are batched.
71-
"""
72-
return any(x.ndim > CORE_NDIM[i] for i, x in enumerate([data, a0, P0, c, d, T, Z, R, H, Q]))
73-
74-
def get_dummy_core_inputs(self, data, a0, P0, c, d, T, Z, R, H, Q):
75-
"""
76-
Get dummy inputs for the core parameters.
77-
"""
78-
out = []
79-
for x, core_ndim in zip([data, a0, P0, c, d, T, Z, R, H, Q], CORE_NDIM):
80-
out.append(
81-
pt.tensor(f"{x.name}_core_case", dtype=x.dtype, shape=x.type.shape[-core_ndim:])
82-
)
83-
return out
68+
def _make_gufunc_signature(self, inputs):
    """
    Build the gufunc signature string used to vectorize the Kalman filter graph.

    Parameters
    ----------
    inputs : sequence
        Filter inputs, in call order. Each item must expose a ``name`` attribute that is a key of
        ``matrix_to_shape`` below (e.g. "data", "a0", "P0", "T").

    Returns
    -------
    str
        Signature of the form "(t,p),(s),... -> (t,s),...", with one parenthesised group of core
        dimension labels per input and one per filter output.

    Raises
    ------
    KeyError
        If an input's name is not a key of ``matrix_to_shape``.
    """
    states = "s"
    obs = "p"
    exog = "r"
    time = "t"

    # Core (unbatched) dimension labels for every named matrix.
    matrix_to_shape = {
        "data": (time, obs),
        "a0": (states,),
        "x0": (states,),
        "P0": (states, states),
        "c": (states,),
        "d": (obs,),
        "T": (states, states),
        "Z": (obs, states),
        "R": (states, exog),
        "H": (obs, obs),
        "Q": (exog, exog),
        "filtered_states": (time, states),
        "filtered_covariances": (time, states, states),
        "predicted_states": (time, states),
        "predicted_covariances": (time, states, states),
        "observed_states": (time, obs),
        "observed_covariances": (time, obs, obs),
        "smoothed_states": (time, states),
        "smoothed_covariances": (time, states, states),
        "loglike_obs": (time,),
    }

    # The filter emits *observed* (not smoothed) states and covariances. Keep this order in sync
    # with the names build_graph assigns to the vectorized outputs.
    output_names = (
        "filtered_states",
        "predicted_states",
        "observed_states",
        "filtered_covariances",
        "predicted_covariances",
        "observed_covariances",
        "loglike_obs",
    )

    input_shapes = []
    for matrix in inputs:
        name = matrix.name
        if name not in matrix_to_shape:
            # Still a KeyError, but with enough context to find the offending input.
            raise KeyError(
                f"Cannot build gufunc signature: unknown input name {name!r}. "
                f"Expected one of {sorted(matrix_to_shape)}."
            )
        input_shapes.append(matrix_to_shape[name])

    output_shapes = [matrix_to_shape[name] for name in output_names]

    input_signature = ",".join(["(" + ",".join(shapes) + ")" for shapes in input_shapes])
    output_signature = ",".join(["(" + ",".join(shapes) + ")" for shapes in output_shapes])

    return f"{input_signature} -> {output_signature}"
84118

85119
@staticmethod
86120
def add_check_on_time_varying_shapes(
@@ -150,7 +184,7 @@ def unpack_args(self, args) -> tuple:
150184

151185
return y, a0, P0, c, d, T, Z, R, H, Q
152186

153-
def build_graph(
187+
def _build_graph(
154188
self,
155189
data,
156190
a0,
@@ -206,18 +240,13 @@ def build_graph(
206240

207241
self.missing_fill_value = missing_fill_value
208242
self.cov_jitter = cov_jitter
209-
is_batched = self.has_batched_input(data, a0, P0, c, d, T, Z, R, H, Q)
210243

211244
[R_shape] = constant_fold([R.shape], raise_not_constant=False)
212245
[Z_shape] = constant_fold([Z.shape], raise_not_constant=False)
213246

214247
self.n_states, self.n_shocks = R_shape[-2:]
215248
self.n_endog = Z_shape[-2]
216249

217-
if is_batched:
218-
batched_inputs = [data, a0, P0, c, d, T, Z, R, H, Q]
219-
data, a0, P0, c, d, T, Z, R, H, Q = self.get_dummy_core_inputs(*batched_inputs)
220-
221250
data, a0, P0, *params = self.check_params(data, a0, P0, c, d, T, Z, R, H, Q)
222251

223252
sequences, non_sequences, seq_names, non_seq_names = split_vars_into_seq_and_nonseq(
@@ -241,15 +270,47 @@ def build_graph(
241270

242271
filter_results = self._postprocess_scan_results(results, a0, P0, n=data.type.shape[0])
243272

244-
if is_batched:
245-
vec_subs = dict(zip([data, a0, P0, c, d, T, Z, R, H, Q], batched_inputs))
246-
filter_results = vectorize_graph(filter_results, vec_subs)
247-
248273
if return_updates:
249274
return filter_results, updates
250275

251276
return filter_results
252277

278+
def build_graph(
    self,
    data,
    a0,
    P0,
    c,
    d,
    T,
    Z,
    R,
    H,
    Q,
    mode=None,
    return_updates=False,
    missing_fill_value=None,
    cov_jitter=None,
) -> list[TensorVariable] | tuple[list[TensorVariable], dict]:
    """
    Construct the Kalman filter graph, vectorized over any leading batch dimensions.

    The core (unbatched) graph is produced by ``_build_graph``; it is wrapped with ``pt.vectorize``
    using a gufunc signature derived from the names of the ten filter inputs. Each returned output
    is then labelled with the matching entry of ``ALL_KF_OUTPUT_NAMES``.
    """
    kf_inputs = [data, a0, P0, c, d, T, Z, R, H, Q]
    gufunc_signature = self._make_gufunc_signature(kf_inputs)

    # Bind the keyword-only options so the vectorized callable takes just the ten tensors.
    # NOTE(review): with return_updates=True, _build_graph returns (filter_results, updates);
    # confirm pt.vectorize accepts that return shape.
    core_filter = partial(
        self._build_graph,
        mode=mode,
        return_updates=return_updates,
        missing_fill_value=missing_fill_value,
        cov_jitter=cov_jitter,
    )
    vectorized_filter = pt.vectorize(core_filter, signature=gufunc_signature)
    filter_outputs = vectorized_filter(*kf_inputs)

    for output_name, output in zip(ALL_KF_OUTPUT_NAMES, filter_outputs):
        output.name = output_name

    return filter_outputs
313+
253314
def _postprocess_scan_results(self, results, a0, P0, n) -> list[TensorVariable]:
254315
"""
255316
Transform the values returned by the Kalman Filter scan into a form expected by users. In particular:

pymc_extras/statespace/filters/kalman_smoother.py

Lines changed: 66 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import pytensor
22
import pytensor.tensor as pt
3-
3+
from functools import partial
4+
from pytensor.compile import get_mode
45
from pytensor.tensor.nlinalg import matrix_dot
5-
from pytensor.graph.replace import vectorize_graph
66
from pymc_extras.statespace.filters.utilities import (
77
quad_form_sym,
88
split_vars_into_seq_and_nonseq,
@@ -63,40 +63,57 @@ def unpack_args(self, args):
6363

6464
return a, P, a_smooth, P_smooth, T, R, Q
6565

66-
def has_batched_input(self, T, R, Q, filtered_states, filtered_covariances):
67-
"""
68-
Check if any of the inputs are batched.
69-
"""
70-
return any(
71-
x.ndim > SMOOTHER_CORE_NDIM[i]
72-
for i, x in enumerate([T, R, Q, filtered_states, filtered_covariances])
73-
)
74-
75-
def get_dummy_core_inputs(self, T, R, Q, filtered_states, filtered_covariances):
76-
"""
77-
Get dummy inputs for the core parameters.
78-
"""
79-
out = []
80-
for x, core_ndim in zip(
81-
[T, R, Q, filtered_states, filtered_covariances], SMOOTHER_CORE_NDIM
82-
):
83-
out.append(
84-
pt.tensor(f"{x.name}_core_case", dtype=x.dtype, shape=x.type.shape[-core_ndim:])
85-
)
86-
return out
87-
88-
def build_graph(
89-
self, T, R, Q, filtered_states, filtered_covariances, cov_jitter=JITTER_DEFAULT
66+
def _make_gufunc_signature(self, inputs):
    """
    Return the gufunc signature string used to vectorize the Kalman smoother graph.

    Every item of ``inputs`` is looked up by its ``name`` attribute in a table of core dimension
    labels; an unknown name raises ``KeyError``. The output side always describes
    ``smoothed_states`` followed by ``smoothed_covariances``.
    """
    s, p, r, t = "s", "p", "r", "t"  # states, observed, shocks, time

    core_dims = {
        "data": (t, p),
        "a0": (s,),
        "x0": (s,),
        "P0": (s, s),
        "c": (s,),
        "d": (p,),
        "T": (s, s),
        "Z": (p, s),
        "R": (s, r),
        "H": (p, p),
        "Q": (r, r),
        "filtered_states": (t, s),
        "filtered_covariances": (t, s, s),
        "predicted_states": (t, s),
        "predicted_covariances": (t, s, s),
        "observed_states": (t, p),
        "observed_covariances": (t, p, p),
        "smoothed_states": (t, s),
        "smoothed_covariances": (t, s, s),
        "loglike_obs": (t,),
    }

    def render(names):
        # One "(a,b,...)" group per name, comma separated.
        groups = [core_dims[name] for name in names]
        return ",".join(f"({','.join(dims)})" for dims in groups)

    lhs = render([matrix.name for matrix in inputs])
    rhs = render(["smoothed_states", "smoothed_covariances"])

    return lhs + " -> " + rhs
111+
112+
def _build_graph(
113+
self, T, R, Q, filtered_states, filtered_covariances, mode=None, cov_jitter=JITTER_DEFAULT
90114
):
91115
self.cov_jitter = cov_jitter
92116

93-
is_batched = self.has_batched_input(T, R, Q, filtered_states, filtered_covariances)
94-
if is_batched:
95-
batched_inputs = [T, R, Q, filtered_states, filtered_covariances]
96-
T, R, Q, filtered_states, filtered_covariances = self.get_dummy_core_inputs(
97-
*batched_inputs
98-
)
99-
100117
n, k = filtered_states.type.shape
101118

102119
a_last = pt.specify_shape(filtered_states[-1], (k,))
@@ -125,18 +142,28 @@ def build_graph(
125142
smoothed_covariances = pt.concatenate(
126143
[smoothed_covariances[::-1], pt.expand_dims(P_last, axis=(0,))], axis=0
127144
)
128-
smoothed_states.dprint()
129-
if is_batched:
130-
vec_subs = dict(zip([T, R, Q, filtered_states, filtered_covariances], batched_inputs))
131-
smoothed_states, smoothed_covariances = vectorize_graph(
132-
[smoothed_states, smoothed_covariances], vec_subs
133-
)
134145

135146
smoothed_states.name = "smoothed_states"
136147
smoothed_covariances.name = "smoothed_covariances"
137148

138149
return smoothed_states, smoothed_covariances
139150

151+
def build_graph(
    self, T, R, Q, filtered_states, filtered_covariances, mode=None, cov_jitter=JITTER_DEFAULT
):
    """
    Construct the Kalman smoother graph, vectorized over any leading batch dimensions.

    The core graph comes from ``_build_graph`` with ``mode`` and ``cov_jitter`` bound as keywords;
    ``pt.vectorize`` applies it using a gufunc signature derived from the input names.
    """
    smoother_inputs = [T, R, Q, filtered_states, filtered_covariances]
    gufunc_signature = self._make_gufunc_signature(smoother_inputs)

    core_smoother = partial(self._build_graph, mode=mode, cov_jitter=cov_jitter)
    vectorized_smoother = pt.vectorize(core_smoother, signature=gufunc_signature)

    return vectorized_smoother(*smoother_inputs)
166+
140167
def smoother_step(self, *args):
141168
a, P, a_smooth, P_smooth, T, R, Q = self.unpack_args(args)
142169
a_hat, P_hat = self.predict(a, P, T, R, Q)

pymc_extras/statespace/filters/utilities.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,14 @@
22

33
from pytensor.tensor.nlinalg import matrix_dot
44

5-
from pymc_extras.statespace.utils.constants import JITTER_DEFAULT, NEVER_TIME_VARYING, VECTOR_VALUED
5+
from pymc_extras.statespace.utils.constants import (
6+
JITTER_DEFAULT,
7+
NEVER_TIME_VARYING,
8+
VECTOR_VALUED,
9+
)
10+
11+
CORE_NDIM = (2, 1, 2, 1, 1, 2, 2, 2, 2, 2)
12+
SMOOTHER_CORE_NDIM = (2, 2, 2, 2, 3)
613

714

815
def decide_if_x_time_varies(x, name):
@@ -57,3 +64,40 @@ def stabilize(cov, jitter=JITTER_DEFAULT):
5764
def quad_form_sym(A, B):
    """
    Compute ``A @ B @ A.T`` and return the average of that product and its transpose.
    """
    quad_form = matrix_dot(A, B, A.T)
    symmetric_sum = quad_form + quad_form.T
    return 0.5 * symmetric_sum
67+
68+
69+
def has_batched_input_smoother(T, R, Q, filtered_states, filtered_covariances):
    """
    Return True if any smoother input has more dimensions than its entry in SMOOTHER_CORE_NDIM.
    """
    smoother_inputs = (T, R, Q, filtered_states, filtered_covariances)
    for core_ndim, x in zip(SMOOTHER_CORE_NDIM, smoother_inputs):
        if x.ndim > core_ndim:
            return True
    return False
77+
78+
79+
def get_dummy_core_inputs_smoother(T, R, Q, filtered_states, filtered_covariances):
    """
    Create one placeholder tensor per smoother input, keeping only the trailing core dimensions.

    Each placeholder is named "<input name>_core_case", shares the input's dtype, and takes its
    static shape from the last ``core_ndim`` entries of the input's type shape.
    """
    smoother_inputs = (T, R, Q, filtered_states, filtered_covariances)
    return [
        pt.tensor(f"{x.name}_core_case", dtype=x.dtype, shape=x.type.shape[-core_ndim:])
        for x, core_ndim in zip(smoother_inputs, SMOOTHER_CORE_NDIM)
    ]
87+
88+
89+
def has_batched_input_filter(data, a0, P0, c, d, T, Z, R, H, Q):
    """
    Return True if any filter input has more dimensions than its entry in CORE_NDIM.
    """
    filter_inputs = (data, a0, P0, c, d, T, Z, R, H, Q)
    for core_ndim, x in zip(CORE_NDIM, filter_inputs):
        if x.ndim > core_ndim:
            return True
    return False
94+
95+
96+
def get_dummy_core_inputs_filter(data, a0, P0, c, d, T, Z, R, H, Q):
    """
    Create one placeholder tensor per filter input, keeping only the trailing core dimensions.

    Each placeholder is named "<input name>_core_case", shares the input's dtype, and takes its
    static shape from the last ``core_ndim`` entries of the input's type shape.
    """
    filter_inputs = (data, a0, P0, c, d, T, Z, R, H, Q)
    return [
        pt.tensor(f"{x.name}_core_case", dtype=x.dtype, shape=x.type.shape[-core_ndim:])
        for x, core_ndim in zip(filter_inputs, CORE_NDIM)
    ]

pymc_extras/statespace/utils/constants.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,16 @@
4747
SMOOTHER_OUTPUT_NAMES = ["smoothed_state", "smoothed_covariance"]
4848
OBSERVED_OUTPUT_NAMES = ["predicted_observed_state", "predicted_observed_covariance"]
4949

50+
ALL_KF_OUTPUT_NAMES = [
51+
"filtered_states",
52+
"predicted_states",
53+
"observed_states",
54+
"filtered_covariances",
55+
"predicted_covariances",
56+
"observed_covariances",
57+
"loglike_obs",
58+
]
59+
5060
MATRIX_DIMS = {
5161
"x0": (ALL_STATE_DIM,),
5262
"P0": (ALL_STATE_DIM, ALL_STATE_AUX_DIM),

0 commit comments

Comments
 (0)