Skip to content

Commit c333b69

Browse files
committed
Working batched Kalman filter and smoother
1 parent cc9f750 commit c333b69

File tree

6 files changed

+1100
-79
lines changed

6 files changed

+1100
-79
lines changed

notebooks/batch-examples.ipynb

Lines changed: 853 additions & 3 deletions
Large diffs are not rendered by default.

pymc_extras/statespace/filters/kalman_filter.py

Lines changed: 89 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from abc import ABC
2+
from functools import partial
23

34
import numpy as np
45
import pytensor
@@ -10,14 +11,13 @@
1011
from pytensor.raise_op import Assert
1112
from pytensor.tensor import TensorVariable
1213
from pytensor.tensor.slinalg import solve_triangular
13-
from pytensor.graph.replace import vectorize_graph
1414

1515
from pymc_extras.statespace.filters.utilities import (
1616
quad_form_sym,
1717
split_vars_into_seq_and_nonseq,
1818
stabilize,
1919
)
20-
from pymc_extras.statespace.utils.constants import JITTER_DEFAULT, MISSING_FILL
20+
from pymc_extras.statespace.utils.constants import JITTER_DEFAULT, MISSING_FILL, ALL_KF_OUTPUT_NAMES
2121

2222
MVN_CONST = pt.log(2 * pt.constant(np.pi, dtype="float64"))
2323
PARAM_NAMES = ["c", "d", "T", "Z", "R", "H", "Q"]
@@ -75,22 +75,56 @@ def check_params(self, data, a0, P0, c, d, T, Z, R, H, Q):
7575
"""
7676
return data, a0, P0, c, d, T, Z, R, H, Q
7777

78-
def has_batched_input(self, data, a0, P0, c, d, T, Z, R, H, Q):
79-
"""
80-
Check if any of the inputs are batched.
81-
"""
82-
return any(x.ndim > CORE_NDIM[i] for i, x in enumerate([data, a0, P0, c, d, T, Z, R, H, Q]))
83-
84-
def get_dummy_core_inputs(self, data, a0, P0, c, d, T, Z, R, H, Q):
85-
"""
86-
Get dummy inputs for the core parameters.
87-
"""
88-
out = []
89-
for x, core_ndim in zip([data, a0, P0, c, d, T, Z, R, H, Q], CORE_NDIM):
90-
out.append(
91-
pt.tensor(f"{x.name}_core_case", dtype=x.dtype, shape=x.type.shape[-core_ndim:])
92-
)
93-
return out
78+
def _make_gufunc_signature(self, inputs):
79+
states = "s"
80+
obs = "p"
81+
exog = "r"
82+
time = "t"
83+
84+
matrix_to_shape = {
85+
"data": (time, obs),
86+
"a0": (states,),
87+
"x0": (states,),
88+
"P0": (states, states),
89+
"c": (states,),
90+
"d": (obs,),
91+
"T": (states, states),
92+
"Z": (obs, states),
93+
"R": (states, exog),
94+
"H": (obs, obs),
95+
"Q": (exog, exog),
96+
"filtered_states": (time, states),
97+
"filtered_covariances": (time, states, states),
98+
"predicted_states": (time, states),
99+
"predicted_covariances": (time, states, states),
100+
"observed_states": (time, obs),
101+
"observed_covariances": (time, obs, obs),
102+
"smoothed_states": (time, states),
103+
"smoothed_covariances": (time, states, states),
104+
"loglike_obs": (time,),
105+
}
106+
input_shapes = []
107+
output_shapes = []
108+
109+
for matrix in inputs:
110+
name = matrix.name
111+
input_shapes.append(matrix_to_shape[name])
112+
113+
for name in [
114+
"filtered_states",
115+
"predicted_states",
116+
"smoothed_states",
117+
"filtered_covariances",
118+
"predicted_covariances",
119+
"smoothed_covariances",
120+
"loglike_obs",
121+
]:
122+
output_shapes.append(matrix_to_shape[name])
123+
124+
input_signature = ",".join(["(" + ",".join(shapes) + ")" for shapes in input_shapes])
125+
output_signature = ",".join(["(" + ",".join(shapes) + ")" for shapes in output_shapes])
126+
127+
return f"{input_signature} -> {output_signature}"
94128

95129
@staticmethod
96130
def add_check_on_time_varying_shapes(
@@ -160,7 +194,7 @@ def unpack_args(self, args) -> tuple:
160194

161195
return y, a0, P0, c, d, T, Z, R, H, Q
162196

163-
def build_graph(
197+
def _build_graph(
164198
self,
165199
data,
166200
a0,
@@ -221,18 +255,13 @@ def build_graph(
221255
self.mode = mode
222256
self.missing_fill_value = missing_fill_value
223257
self.cov_jitter = cov_jitter
224-
is_batched = self.has_batched_input(data, a0, P0, c, d, T, Z, R, H, Q)
225258

226259
[R_shape] = constant_fold([R.shape], raise_not_constant=False)
227260
[Z_shape] = constant_fold([Z.shape], raise_not_constant=False)
228261

229262
self.n_states, self.n_shocks = R_shape[-2:]
230263
self.n_endog = Z_shape[-2]
231264

232-
if is_batched:
233-
batched_inputs = [data, a0, P0, c, d, T, Z, R, H, Q]
234-
data, a0, P0, c, d, T, Z, R, H, Q = self.get_dummy_core_inputs(*batched_inputs)
235-
236265
data, a0, P0, *params = self.check_params(data, a0, P0, c, d, T, Z, R, H, Q)
237266

238267
sequences, non_sequences, seq_names, non_seq_names = split_vars_into_seq_and_nonseq(
@@ -257,15 +286,47 @@ def build_graph(
257286

258287
filter_results = self._postprocess_scan_results(results, a0, P0, n=data.type.shape[0])
259288

260-
if is_batched:
261-
vec_subs = dict(zip([data, a0, P0, c, d, T, Z, R, H, Q], batched_inputs))
262-
filter_results = vectorize_graph(filter_results, vec_subs)
263-
264289
if return_updates:
265290
return filter_results, updates
266291

267292
return filter_results
268293

294+
def build_graph(
295+
self,
296+
data,
297+
a0,
298+
P0,
299+
c,
300+
d,
301+
T,
302+
Z,
303+
R,
304+
H,
305+
Q,
306+
mode=None,
307+
return_updates=False,
308+
missing_fill_value=None,
309+
cov_jitter=None,
310+
) -> list[TensorVariable] | tuple[list[TensorVariable], dict]:
311+
"""
312+
Build the vectorized computation graph for the Kalman filter.
313+
"""
314+
signature = self._make_gufunc_signature(
315+
[data, a0, P0, c, d, T, Z, R, H, Q],
316+
)
317+
fn = partial(
318+
self._build_graph,
319+
mode=mode,
320+
return_updates=return_updates,
321+
missing_fill_value=missing_fill_value,
322+
cov_jitter=cov_jitter,
323+
)
324+
filter_outputs = pt.vectorize(fn, signature=signature)(data, a0, P0, c, d, T, Z, R, H, Q)
325+
for output, name in zip(filter_outputs, ALL_KF_OUTPUT_NAMES):
326+
output.name = name
327+
328+
return filter_outputs
329+
269330
def _postprocess_scan_results(self, results, a0, P0, n) -> list[TensorVariable]:
270331
"""
271332
Transform the values returned by the Kalman Filter scan into a form expected by users. In particular:

pymc_extras/statespace/filters/kalman_smoother.py

Lines changed: 64 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
import pytensor
22
import pytensor.tensor as pt
3-
3+
from functools import partial
44
from pytensor.compile import get_mode
55
from pytensor.tensor.nlinalg import matrix_dot
6-
from pytensor.graph.replace import vectorize_graph
76
from pymc_extras.statespace.filters.utilities import (
87
quad_form_sym,
98
split_vars_into_seq_and_nonseq,
@@ -65,41 +64,58 @@ def unpack_args(self, args):
6564

6665
return a, P, a_smooth, P_smooth, T, R, Q
6766

68-
def has_batched_input(self, T, R, Q, filtered_states, filtered_covariances):
69-
"""
70-
Check if any of the inputs are batched.
71-
"""
72-
return any(
73-
x.ndim > SMOOTHER_CORE_NDIM[i]
74-
for i, x in enumerate([T, R, Q, filtered_states, filtered_covariances])
75-
)
76-
77-
def get_dummy_core_inputs(self, T, R, Q, filtered_states, filtered_covariances):
78-
"""
79-
Get dummy inputs for the core parameters.
80-
"""
81-
out = []
82-
for x, core_ndim in zip(
83-
[T, R, Q, filtered_states, filtered_covariances], SMOOTHER_CORE_NDIM
84-
):
85-
out.append(
86-
pt.tensor(f"{x.name}_core_case", dtype=x.dtype, shape=x.type.shape[-core_ndim:])
87-
)
88-
return out
89-
90-
def build_graph(
67+
def _make_gufunc_signature(self, inputs):
68+
states = "s"
69+
obs = "p"
70+
exog = "r"
71+
time = "t"
72+
73+
matrix_to_shape = {
74+
"data": (time, obs),
75+
"a0": (states,),
76+
"x0": (states,),
77+
"P0": (states, states),
78+
"c": (states,),
79+
"d": (obs,),
80+
"T": (states, states),
81+
"Z": (obs, states),
82+
"R": (states, exog),
83+
"H": (obs, obs),
84+
"Q": (exog, exog),
85+
"filtered_states": (time, states),
86+
"filtered_covariances": (time, states, states),
87+
"predicted_states": (time, states),
88+
"predicted_covariances": (time, states, states),
89+
"observed_states": (time, obs),
90+
"observed_covariances": (time, obs, obs),
91+
"smoothed_states": (time, states),
92+
"smoothed_covariances": (time, states, states),
93+
"loglike_obs": (time,),
94+
}
95+
input_shapes = []
96+
output_shapes = []
97+
98+
for matrix in inputs:
99+
name = matrix.name
100+
input_shapes.append(matrix_to_shape[name])
101+
102+
for name in [
103+
"smoothed_states",
104+
"smoothed_covariances",
105+
]:
106+
output_shapes.append(matrix_to_shape[name])
107+
108+
input_signature = ",".join(["(" + ",".join(shapes) + ")" for shapes in input_shapes])
109+
output_signature = ",".join(["(" + ",".join(shapes) + ")" for shapes in output_shapes])
110+
111+
return f"{input_signature} -> {output_signature}"
112+
113+
def _build_graph(
91114
self, T, R, Q, filtered_states, filtered_covariances, mode=None, cov_jitter=JITTER_DEFAULT
92115
):
93116
self.mode = mode
94117
self.cov_jitter = cov_jitter
95118

96-
is_batched = self.has_batched_input(T, R, Q, filtered_states, filtered_covariances)
97-
if is_batched:
98-
batched_inputs = [T, R, Q, filtered_states, filtered_covariances]
99-
T, R, Q, filtered_states, filtered_covariances = self.get_dummy_core_inputs(
100-
*batched_inputs
101-
)
102-
103119
n, k = filtered_states.type.shape
104120

105121
a_last = pt.specify_shape(filtered_states[-1], (k,))
@@ -129,18 +145,28 @@ def build_graph(
129145
smoothed_covariances = pt.concatenate(
130146
[smoothed_covariances[::-1], pt.expand_dims(P_last, axis=(0,))], axis=0
131147
)
132-
smoothed_states.dprint()
133-
if is_batched:
134-
vec_subs = dict(zip([T, R, Q, filtered_states, filtered_covariances], batched_inputs))
135-
smoothed_states, smoothed_covariances = vectorize_graph(
136-
[smoothed_states, smoothed_covariances], vec_subs
137-
)
138148

139149
smoothed_states.name = "smoothed_states"
140150
smoothed_covariances.name = "smoothed_covariances"
141151

142152
return smoothed_states, smoothed_covariances
143153

154+
def build_graph(
155+
self, T, R, Q, filtered_states, filtered_covariances, mode=None, cov_jitter=JITTER_DEFAULT
156+
):
157+
"""
158+
Build the vectorized computation graph for the Kalman smoother.
159+
"""
160+
signature = self._make_gufunc_signature(
161+
[T, R, Q, filtered_states, filtered_covariances],
162+
)
163+
fn = partial(
164+
self._build_graph,
165+
mode=mode,
166+
cov_jitter=cov_jitter,
167+
)
168+
return pt.vectorize(fn, signature=signature)(T, R, Q, filtered_states, filtered_covariances)
169+
144170
def smoother_step(self, *args):
145171
a, P, a_smooth, P_smooth, T, R, Q = self.unpack_args(args)
146172
a_hat, P_hat = self.predict(a, P, T, R, Q)

pymc_extras/statespace/filters/utilities.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,14 @@
22

33
from pytensor.tensor.nlinalg import matrix_dot
44

5-
from pymc_extras.statespace.utils.constants import JITTER_DEFAULT, NEVER_TIME_VARYING, VECTOR_VALUED
5+
from pymc_extras.statespace.utils.constants import (
6+
JITTER_DEFAULT,
7+
NEVER_TIME_VARYING,
8+
VECTOR_VALUED,
9+
)
10+
11+
CORE_NDIM = (2, 1, 2, 1, 1, 2, 2, 2, 2, 2)
12+
SMOOTHER_CORE_NDIM = (2, 2, 2, 2, 3)
613

714

815
def decide_if_x_time_varies(x, name):
@@ -57,3 +64,40 @@ def stabilize(cov, jitter=JITTER_DEFAULT):
5764
def quad_form_sym(A, B):
5865
out = matrix_dot(A, B, A.T)
5966
return 0.5 * (out + out.T)
67+
68+
69+
def has_batched_input_smoother(T, R, Q, filtered_states, filtered_covariances):
70+
"""
71+
Check if any of the inputs are batched.
72+
"""
73+
return any(
74+
x.ndim > SMOOTHER_CORE_NDIM[i]
75+
for i, x in enumerate([T, R, Q, filtered_states, filtered_covariances])
76+
)
77+
78+
79+
def get_dummy_core_inputs_smoother(T, R, Q, filtered_states, filtered_covariances):
80+
"""
81+
Get dummy inputs for the core parameters.
82+
"""
83+
out = []
84+
for x, core_ndim in zip([T, R, Q, filtered_states, filtered_covariances], SMOOTHER_CORE_NDIM):
85+
out.append(pt.tensor(f"{x.name}_core_case", dtype=x.dtype, shape=x.type.shape[-core_ndim:]))
86+
return out
87+
88+
89+
def has_batched_input_filter(data, a0, P0, c, d, T, Z, R, H, Q):
90+
"""
91+
Check if any of the inputs are batched.
92+
"""
93+
return any(x.ndim > CORE_NDIM[i] for i, x in enumerate([data, a0, P0, c, d, T, Z, R, H, Q]))
94+
95+
96+
def get_dummy_core_inputs_filter(data, a0, P0, c, d, T, Z, R, H, Q):
97+
"""
98+
Get dummy inputs for the core parameters.
99+
"""
100+
out = []
101+
for x, core_ndim in zip([data, a0, P0, c, d, T, Z, R, H, Q], CORE_NDIM):
102+
out.append(pt.tensor(f"{x.name}_core_case", dtype=x.dtype, shape=x.type.shape[-core_ndim:]))
103+
return out

pymc_extras/statespace/utils/constants.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,16 @@
4747
SMOOTHER_OUTPUT_NAMES = ["smoothed_state", "smoothed_covariance"]
4848
OBSERVED_OUTPUT_NAMES = ["predicted_observed_state", "predicted_observed_covariance"]
4949

50+
ALL_KF_OUTPUT_NAMES = [
51+
"filtered_states",
52+
"predicted_states",
53+
"observed_states",
54+
"filtered_covariances",
55+
"predicted_covariances",
56+
"observed_covariances",
57+
"loglike_obs",
58+
]
59+
5060
MATRIX_DIMS = {
5161
"x0": (ALL_STATE_DIM,),
5262
"P0": (ALL_STATE_DIM, ALL_STATE_AUX_DIM),

0 commit comments

Comments
 (0)