
Commit cc9f750

Working with Filter, not with Smoother
1 parent 7f6845e commit cc9f750

5 files changed: +153 -62 lines

notebooks/batch-examples.ipynb

Lines changed: 14 additions & 14 deletions
@@ -189,7 +189,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 20,
+    "execution_count": 11,
     "id": "1262c7d4",
     "metadata": {},
     "outputs": [],
@@ -199,7 +199,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 21,
+    "execution_count": 12,
     "id": "2dcd3958",
     "metadata": {},
     "outputs": [
@@ -228,7 +228,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 22,
+    "execution_count": 13,
     "id": "6f41344f",
     "metadata": {},
     "outputs": [],
@@ -238,7 +238,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 23,
+    "execution_count": 14,
     "id": "44905b8a",
     "metadata": {},
     "outputs": [],
@@ -248,7 +248,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 24,
+    "execution_count": 15,
     "id": "34fe01b8",
     "metadata": {},
     "outputs": [
@@ -258,7 +258,7 @@
       "(3, 10, 5)"
      ]
     },
-    "execution_count": 24,
+    "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
@@ -270,7 +270,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 25,
+    "execution_count": 16,
     "id": "f37efe79",
     "metadata": {},
     "outputs": [
@@ -288,7 +288,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 26,
+    "execution_count": 17,
     "id": "7b45de74",
     "metadata": {},
     "outputs": [
@@ -298,7 +298,7 @@
       "(3, 10)"
      ]
     },
-    "execution_count": 26,
+    "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
@@ -317,7 +317,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 27,
+    "execution_count": 18,
     "id": "69519822",
     "metadata": {},
     "outputs": [],
@@ -327,17 +327,17 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 28,
+    "execution_count": 19,
     "id": "3f745449",
     "metadata": {},
     "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "633 μs ± 18.9 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n",
-      "1.52 ms ± 35.9 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n",
-      "4.76 ms ± 259 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
+      "675 μs ± 22.3 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n",
+      "1.64 ms ± 37.5 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n",
+      "5.28 ms ± 424 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
      ]
     }
    ],
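
The notebook changes above are re-executed cells: the execution counts shift and the %timeit figures move slightly, while the displayed output shapes (3, 10, 5) and (3, 10) stay the same. A small numpy sanity check of those shapes, assuming (purely as an interpretation, since the notebook cells themselves are not shown here) that the axes mean (batch, time, state) and (batch, time):

import numpy as np

# Assumed axis meanings for the shapes printed in the notebook output above.
batch, time, n_states = 3, 10, 5
filtered_states = np.zeros((batch, time, n_states))  # hypothetical batched filter output
loglike_obs = np.zeros((batch, time))                # hypothetical per-step log-likelihoods

assert filtered_states.shape == (3, 10, 5)
assert loglike_obs.shape == (3, 10)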

pymc_extras/statespace/filters/kalman_filter.py

Lines changed: 29 additions & 0 deletions
@@ -10,6 +10,7 @@
 from pytensor.raise_op import Assert
 from pytensor.tensor import TensorVariable
 from pytensor.tensor.slinalg import solve_triangular
+from pytensor.graph.replace import vectorize_graph

 from pymc_extras.statespace.filters.utilities import (
     quad_form_sym,
@@ -20,6 +21,7 @@

 MVN_CONST = pt.log(2 * pt.constant(np.pi, dtype="float64"))
 PARAM_NAMES = ["c", "d", "T", "Z", "R", "H", "Q"]
+CORE_NDIM = (2, 1, 2, 1, 1, 2, 2, 2, 2, 2)

 assert_time_varying_dim_correct = Assert(
     "The first dimension of a time varying matrix (the time dimension) must be "
@@ -73,6 +75,23 @@ def check_params(self, data, a0, P0, c, d, T, Z, R, H, Q):
         """
         return data, a0, P0, c, d, T, Z, R, H, Q

+    def has_batched_input(self, data, a0, P0, c, d, T, Z, R, H, Q):
+        """
+        Check if any of the inputs are batched.
+        """
+        return any(x.ndim > CORE_NDIM[i] for i, x in enumerate([data, a0, P0, c, d, T, Z, R, H, Q]))
+
+    def get_dummy_core_inputs(self, data, a0, P0, c, d, T, Z, R, H, Q):
+        """
+        Get dummy inputs for the core parameters.
+        """
+        out = []
+        for x, core_ndim in zip([data, a0, P0, c, d, T, Z, R, H, Q], CORE_NDIM):
+            out.append(
+                pt.tensor(f"{x.name}_core_case", dtype=x.dtype, shape=x.type.shape[-core_ndim:])
+            )
+        return out
+
     @staticmethod
     def add_check_on_time_varying_shapes(
         data: TensorVariable, sequence_params: list[TensorVariable]
@@ -202,13 +221,18 @@ def build_graph(
         self.mode = mode
         self.missing_fill_value = missing_fill_value
         self.cov_jitter = cov_jitter
+        is_batched = self.has_batched_input(data, a0, P0, c, d, T, Z, R, H, Q)

         [R_shape] = constant_fold([R.shape], raise_not_constant=False)
         [Z_shape] = constant_fold([Z.shape], raise_not_constant=False)

         self.n_states, self.n_shocks = R_shape[-2:]
         self.n_endog = Z_shape[-2]

+        if is_batched:
+            batched_inputs = [data, a0, P0, c, d, T, Z, R, H, Q]
+            data, a0, P0, c, d, T, Z, R, H, Q = self.get_dummy_core_inputs(*batched_inputs)
+
         data, a0, P0, *params = self.check_params(data, a0, P0, c, d, T, Z, R, H, Q)

         sequences, non_sequences, seq_names, non_seq_names = split_vars_into_seq_and_nonseq(
@@ -233,8 +257,13 @@ def build_graph(

         filter_results = self._postprocess_scan_results(results, a0, P0, n=data.type.shape[0])

+        if is_batched:
+            vec_subs = dict(zip([data, a0, P0, c, d, T, Z, R, H, Q], batched_inputs))
+            filter_results = vectorize_graph(filter_results, vec_subs)
+
         if return_updates:
             return filter_results, updates
+
         return filter_results

     def _postprocess_scan_results(self, results, a0, P0, n) -> list[TensorVariable]:
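
The pattern added to build_graph above is: detect whether any input has more dimensions than its core signature, swap the batched inputs for unbatched dummy tensors, build the usual scan-based filter on the dummies, then call pytensor's vectorize_graph to rewrite the finished graph against the original batched inputs. A minimal sketch of that substitute-then-vectorize idea, using stand-in elementwise math rather than the real filter graph:

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.graph.replace import vectorize_graph

# 1. Build the graph on dummy "core" (unbatched) inputs.
a_core = pt.vector("a_core")          # (n_states,)
c_core = pt.vector("c_core")          # (n_states,)
out_core = pt.tanh(a_core + c_core)   # stand-in for the real filter computation

# 2. The inputs we actually have carry an extra leading batch dimension.
a_batched = pt.matrix("a_batched")    # (batch, n_states)
c_batched = pt.matrix("c_batched")    # (batch, n_states)

# 3. Substitute the dummies with the batched inputs; the graph is rewritten so
#    the computation broadcasts over the new leading dimension.
out_batched = vectorize_graph(out_core, {a_core: a_batched, c_core: c_batched})

fn = pytensor.function([a_batched, c_batched], out_batched)
result = fn(
    np.ones((3, 5), dtype=a_batched.dtype),
    np.zeros((3, 5), dtype=c_batched.dtype),
)
print(result.shape)  # (3, 5)

The substitution dictionary plays the same role as vec_subs in the diff above: it maps each dummy core input back to the user-supplied batched tensor once the core graph has been built.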

pymc_extras/statespace/filters/kalman_smoother.py

Lines changed: 38 additions & 1 deletion
@@ -3,14 +3,16 @@

 from pytensor.compile import get_mode
 from pytensor.tensor.nlinalg import matrix_dot
-
+from pytensor.graph.replace import vectorize_graph
 from pymc_extras.statespace.filters.utilities import (
     quad_form_sym,
     split_vars_into_seq_and_nonseq,
     stabilize,
 )
 from pymc_extras.statespace.utils.constants import JITTER_DEFAULT

+SMOOTHER_CORE_NDIM = (2, 2, 2, 2, 3)
+

 class KalmanSmoother:
     """
@@ -63,12 +65,41 @@ def unpack_args(self, args):

         return a, P, a_smooth, P_smooth, T, R, Q

+    def has_batched_input(self, T, R, Q, filtered_states, filtered_covariances):
+        """
+        Check if any of the inputs are batched.
+        """
+        return any(
+            x.ndim > SMOOTHER_CORE_NDIM[i]
+            for i, x in enumerate([T, R, Q, filtered_states, filtered_covariances])
+        )
+
+    def get_dummy_core_inputs(self, T, R, Q, filtered_states, filtered_covariances):
+        """
+        Get dummy inputs for the core parameters.
+        """
+        out = []
+        for x, core_ndim in zip(
+            [T, R, Q, filtered_states, filtered_covariances], SMOOTHER_CORE_NDIM
+        ):
+            out.append(
+                pt.tensor(f"{x.name}_core_case", dtype=x.dtype, shape=x.type.shape[-core_ndim:])
+            )
+        return out
+
     def build_graph(
         self, T, R, Q, filtered_states, filtered_covariances, mode=None, cov_jitter=JITTER_DEFAULT
     ):
         self.mode = mode
         self.cov_jitter = cov_jitter

+        is_batched = self.has_batched_input(T, R, Q, filtered_states, filtered_covariances)
+        if is_batched:
+            batched_inputs = [T, R, Q, filtered_states, filtered_covariances]
+            T, R, Q, filtered_states, filtered_covariances = self.get_dummy_core_inputs(
+                *batched_inputs
+            )
+
         n, k = filtered_states.type.shape

         a_last = pt.specify_shape(filtered_states[-1], (k,))
@@ -98,6 +129,12 @@ def build_graph(
         smoothed_covariances = pt.concatenate(
             [smoothed_covariances[::-1], pt.expand_dims(P_last, axis=(0,))], axis=0
         )
+        smoothed_states.dprint()
+        if is_batched:
+            vec_subs = dict(zip([T, R, Q, filtered_states, filtered_covariances], batched_inputs))
+            smoothed_states, smoothed_covariances = vectorize_graph(
+                [smoothed_states, smoothed_covariances], vec_subs
+            )

         smoothed_states.name = "smoothed_states"
         smoothed_covariances.name = "smoothed_covariances"
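
The smoother follows the same recipe as the filter, with its own core rank per input recorded in SMOOTHER_CORE_NDIM. A small illustration of the ndim-based batch detection; the symbolic inputs and their shapes below are hypothetical examples, not taken from the library:

import pytensor.tensor as pt

SMOOTHER_CORE_NDIM = (2, 2, 2, 2, 3)  # T, R, Q, filtered_states, filtered_covariances

def has_batched_input(*inputs):
    # An input counts as batched when it has more dims than its core signature.
    return any(x.ndim > core for x, core in zip(inputs, SMOOTHER_CORE_NDIM))

T = pt.tensor3("T")                # (batch, k, k): one extra leading dim, so batched
R = pt.matrix("R")                 # (k, r)
Q = pt.matrix("Q")                 # (r, r)
filtered_states = pt.matrix("mu")  # (n, k)
filtered_covs = pt.tensor3("cov")  # (n, k, k)

print(has_batched_input(T, R, Q, filtered_states, filtered_covs))  # True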

tests/statespace/test_kalman_filter.py

Lines changed: 36 additions & 29 deletions
@@ -31,19 +31,22 @@
 RTOL = 1e-6 if floatX.endswith("64") else 1e-3

 standard_inout = initialize_filter(StandardFilter())
+standard_inout_batched = initialize_filter(StandardFilter(), batched=True)
 cholesky_inout = initialize_filter(SquareRootFilter())
 univariate_inout = initialize_filter(UnivariateFilter())

 f_standard = pytensor.function(*standard_inout, on_unused_input="ignore")
+f_standard_batched = pytensor.function(*standard_inout_batched, on_unused_input="ignore")
 f_cholesky = pytensor.function(*cholesky_inout, on_unused_input="ignore")
 f_univariate = pytensor.function(*univariate_inout, on_unused_input="ignore")

-filter_funcs = [f_standard, f_cholesky, f_univariate]
+filter_funcs = [f_standard, f_standard_batched]  # , f_cholesky, f_univariate]

 filter_names = [
     "StandardFilter",
-    "CholeskyFilter",
-    "UnivariateFilter",
+    "StandardFilterBatched",
+    # "CholeskyFilter",
+    # "UnivariateFilter",
 ]

 output_names = [
@@ -65,17 +68,21 @@ def test_base_class_update_raises():
         filter.update(*inputs)


-@pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names)
-def test_output_shapes_one_state_one_observed(filter_func, rng):
+@pytest.mark.parametrize(
+    "filter_func, filter_name", zip(filter_funcs, filter_names), ids=filter_names
+)
+def test_output_shapes_one_state_one_observed(filter_func, filter_name, rng):
+    batch_size = 3 if "batched" in filter_name.lower() else 0
     p, m, r, n = 1, 1, 1, 10
-    inputs = make_test_inputs(p, m, r, n, rng)
-    outputs = filter_func(*inputs)
+    inputs = make_test_inputs(p, m, r, n, rng, batch_size=batch_size)
+    assert 0
+    # outputs = filter_func(*inputs)

     for output_idx, name in enumerate(output_names):
-        expected_output = get_expected_shape(name, p, m, r, n)
-        assert (
-            outputs[output_idx].shape == expected_output
-        ), f"Shape of {name} does not match expected"
+        expected_shape = get_expected_shape(name, p, m, r, n, batch_size)
+        # assert outputs[output_idx].shape == expected_shape, (
+        #     f"Shape of {name} does not match expected"
+        # )


 @pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names)
@@ -86,9 +93,9 @@ def test_output_shapes_when_all_states_are_stochastic(filter_func, rng):
     outputs = filter_func(*inputs)
     for output_idx, name in enumerate(output_names):
         expected_output = get_expected_shape(name, p, m, r, n)
-        assert (
-            outputs[output_idx].shape == expected_output
-        ), f"Shape of {name} does not match expected"
+        assert outputs[output_idx].shape == expected_output, (
+            f"Shape of {name} does not match expected"
+        )


 @pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names)
@@ -99,9 +106,9 @@ def test_output_shapes_when_some_states_are_deterministic(filter_func, rng):
     outputs = filter_func(*inputs)
     for output_idx, name in enumerate(output_names):
         expected_output = get_expected_shape(name, p, m, r, n)
-        assert (
-            outputs[output_idx].shape == expected_output
-        ), f"Shape of {name} does not match expected"
+        assert outputs[output_idx].shape == expected_output, (
+            f"Shape of {name} does not match expected"
+        )


 @pytest.fixture
@@ -161,9 +168,9 @@ def test_output_shapes_with_time_varying_matrices(f_standard_nd, rng):

     for output_idx, name in enumerate(output_names):
         expected_output = get_expected_shape(name, p, m, r, n)
-        assert (
-            outputs[output_idx].shape == expected_output
-        ), f"Shape of {name} does not match expected"
+        assert outputs[output_idx].shape == expected_output, (
+            f"Shape of {name} does not match expected"
+        )


 @pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names)
@@ -175,9 +182,9 @@ def test_output_with_deterministic_observation_equation(filter_func, rng):

     for output_idx, name in enumerate(output_names):
         expected_output = get_expected_shape(name, p, m, r, n)
-        assert (
-            outputs[output_idx].shape == expected_output
-        ), f"Shape of {name} does not match expected"
+        assert outputs[output_idx].shape == expected_output, (
+            f"Shape of {name} does not match expected"
+        )


 @pytest.mark.parametrize(
@@ -190,9 +197,9 @@ def test_output_with_multiple_observed(filter_func, filter_name, rng):
     outputs = filter_func(*inputs)
     for output_idx, name in enumerate(output_names):
         expected_output = get_expected_shape(name, p, m, r, n)
-        assert (
-            outputs[output_idx].shape == expected_output
-        ), f"Shape of {name} does not match expected"
+        assert outputs[output_idx].shape == expected_output, (
+            f"Shape of {name} does not match expected"
+        )


 @pytest.mark.parametrize(
@@ -206,9 +213,9 @@ def test_missing_data(filter_func, filter_name, p, rng):
     outputs = filter_func(*inputs)
     for output_idx, name in enumerate(output_names):
         expected_output = get_expected_shape(name, p, m, r, n)
-        assert (
-            outputs[output_idx].shape == expected_output
-        ), f"Shape of {name} does not match expected"
+        assert outputs[output_idx].shape == expected_output, (
+            f"Shape of {name} does not match expected"
+        )


 @pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names)
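
The batched test path relies on make_test_inputs and get_expected_shape accepting a batch_size argument, and those helper changes are not part of this diff. A hypothetical sketch of the kind of leading batch dimension they imply; the helper name and shapes below are illustrative only:

import numpy as np

def batch_stack(x, batch_size):
    # Hypothetical helper: tile an unbatched array along a new leading batch axis.
    if batch_size == 0:
        return x
    return np.broadcast_to(x, (batch_size, *x.shape)).copy()

p, m, r, n, batch_size = 1, 1, 1, 10, 3
data = batch_stack(np.zeros((n, p)), batch_size)  # (3, 10, 1)
T = batch_stack(np.eye(m), batch_size)            # (3, 1, 1)
print(data.shape, T.shape)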

0 commit comments