
Commit ca71070

Working with Filter, not with Smoother

1 parent d477511 · commit ca71070

File tree: 5 files changed (+153, -62 lines)

notebooks/batch-examples.ipynb

Lines changed: 14 additions & 14 deletions
```diff
@@ -189,7 +189,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 20,
+   "execution_count": 11,
    "id": "1262c7d4",
    "metadata": {},
    "outputs": [],
@@ -199,7 +199,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 21,
+   "execution_count": 12,
    "id": "2dcd3958",
    "metadata": {},
    "outputs": [
@@ -228,7 +228,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 22,
+   "execution_count": 13,
    "id": "6f41344f",
    "metadata": {},
    "outputs": [],
@@ -238,7 +238,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 23,
+   "execution_count": 14,
    "id": "44905b8a",
    "metadata": {},
    "outputs": [],
@@ -248,7 +248,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 24,
+   "execution_count": 15,
    "id": "34fe01b8",
    "metadata": {},
    "outputs": [
@@ -258,7 +258,7 @@
       "(3, 10, 5)"
      ]
     },
-    "execution_count": 24,
+    "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
@@ -270,7 +270,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 25,
+   "execution_count": 16,
    "id": "f37efe79",
    "metadata": {},
    "outputs": [
@@ -288,7 +288,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 26,
+   "execution_count": 17,
    "id": "7b45de74",
    "metadata": {},
    "outputs": [
@@ -298,7 +298,7 @@
       "(3, 10)"
      ]
     },
-    "execution_count": 26,
+    "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
@@ -317,7 +317,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 27,
+   "execution_count": 18,
    "id": "69519822",
    "metadata": {},
    "outputs": [],
@@ -327,17 +327,17 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 28,
+   "execution_count": 19,
    "id": "3f745449",
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "633 μs ± 18.9 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n",
-      "1.52 ms ± 35.9 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n",
-      "4.76 ms ± 259 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
+      "675 μs ± 22.3 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n",
+      "1.64 ms ± 37.5 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n",
+      "5.28 ms ± 424 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
     ]
    }
   ],
```
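The three timing lines above are %timeit outputs from the notebook's benchmark cells, presumably run at increasing batch sizes. A rough, self-contained sketch of how numbers like these could be produced for a compiled PyTensor function follows; the toy batched graph and the batch sizes here are stand-ins, not the notebook's actual statespace model.

```python
import timeit

import numpy as np
import pytensor
import pytensor.tensor as pt

k = 5
T = pt.tensor("T", dtype="float64", shape=(None, None, None))  # (batch, k, k) stand-in matrices
x = pt.tensor("x", dtype="float64", shape=(None, None))        # (batch, k) stand-in vectors
y = (T @ x[..., None])[..., 0]                                  # batched matrix-vector product

fn = pytensor.function([T, x], y)

for batch in (1, 10, 100):
    T_val = np.broadcast_to(np.eye(k), (batch, k, k)).copy()
    x_val = np.ones((batch, k))
    seconds = timeit.timeit(lambda: fn(T_val, x_val), number=1_000)
    print(f"batch={batch}: {seconds / 1_000 * 1e6:.1f} μs per call")
```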

pymc_extras/statespace/filters/kalman_filter.py

Lines changed: 29 additions & 0 deletions
```diff
@@ -9,6 +9,7 @@
 from pytensor.raise_op import Assert
 from pytensor.tensor import TensorVariable
 from pytensor.tensor.slinalg import solve_triangular
+from pytensor.graph.replace import vectorize_graph
 
 from pymc_extras.statespace.filters.utilities import (
     quad_form_sym,
@@ -19,6 +20,7 @@
 
 MVN_CONST = pt.log(2 * pt.constant(np.pi, dtype="float64"))
 PARAM_NAMES = ["c", "d", "T", "Z", "R", "H", "Q"]
+CORE_NDIM = (2, 1, 2, 1, 1, 2, 2, 2, 2, 2)
 
 assert_time_varying_dim_correct = Assert(
     "The first dimension of a time varying matrix (the time dimension) must be "
@@ -63,6 +65,23 @@ def check_params(self, data, a0, P0, c, d, T, Z, R, H, Q):
         """
         return data, a0, P0, c, d, T, Z, R, H, Q
 
+    def has_batched_input(self, data, a0, P0, c, d, T, Z, R, H, Q):
+        """
+        Check if any of the inputs are batched.
+        """
+        return any(x.ndim > CORE_NDIM[i] for i, x in enumerate([data, a0, P0, c, d, T, Z, R, H, Q]))
+
+    def get_dummy_core_inputs(self, data, a0, P0, c, d, T, Z, R, H, Q):
+        """
+        Get dummy inputs for the core parameters.
+        """
+        out = []
+        for x, core_ndim in zip([data, a0, P0, c, d, T, Z, R, H, Q], CORE_NDIM):
+            out.append(
+                pt.tensor(f"{x.name}_core_case", dtype=x.dtype, shape=x.type.shape[-core_ndim:])
+            )
+        return out
+
     @staticmethod
     def add_check_on_time_varying_shapes(
         data: TensorVariable, sequence_params: list[TensorVariable]
@@ -187,13 +206,18 @@ def build_graph(
 
         self.missing_fill_value = missing_fill_value
         self.cov_jitter = cov_jitter
+        is_batched = self.has_batched_input(data, a0, P0, c, d, T, Z, R, H, Q)
 
         [R_shape] = constant_fold([R.shape], raise_not_constant=False)
         [Z_shape] = constant_fold([Z.shape], raise_not_constant=False)
 
         self.n_states, self.n_shocks = R_shape[-2:]
         self.n_endog = Z_shape[-2]
 
+        if is_batched:
+            batched_inputs = [data, a0, P0, c, d, T, Z, R, H, Q]
+            data, a0, P0, c, d, T, Z, R, H, Q = self.get_dummy_core_inputs(*batched_inputs)
+
         data, a0, P0, *params = self.check_params(data, a0, P0, c, d, T, Z, R, H, Q)
 
         sequences, non_sequences, seq_names, non_seq_names = split_vars_into_seq_and_nonseq(
@@ -217,8 +241,13 @@ def build_graph(
 
         filter_results = self._postprocess_scan_results(results, a0, P0, n=data.type.shape[0])
 
+        if is_batched:
+            vec_subs = dict(zip([data, a0, P0, c, d, T, Z, R, H, Q], batched_inputs))
+            filter_results = vectorize_graph(filter_results, vec_subs)
+
         if return_updates:
             return filter_results, updates
+
         return filter_results
 
     def _postprocess_scan_results(self, results, a0, P0, n) -> list[TensorVariable]:
```
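The change above follows one pattern: compare each input's ndim against its expected core ndim (CORE_NDIM), build the scan graph on dummy core-shaped placeholders, then substitute the original batched inputs back in with PyTensor's vectorize_graph so the whole filter is mapped over the leading batch dimensions. A minimal standalone sketch of that mechanism is shown below, using a toy matrix-vector graph rather than the filter itself.

```python
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.graph.replace import vectorize_graph

# Core graph, written only for unbatched ("core") shapes
T_core = pt.tensor("T_core", dtype="float64", shape=(None, None))  # (k, k) transition matrix
x_core = pt.tensor("x_core", dtype="float64", shape=(None,))       # (k,) state vector
y_core = pt.dot(T_core, x_core)                                     # core computation: one step

# Batched inputs carry an extra leading dimension
T_batched = pt.tensor("T_batched", dtype="float64", shape=(None, None, None))  # (batch, k, k)
x_batched = pt.tensor("x_batched", dtype="float64", shape=(None, None))        # (batch, k)

# vectorize_graph rebuilds y_core with the batched inputs substituted for the
# core placeholders, mapping the computation over the leading batch axis
y_batched = vectorize_graph(y_core, replace={T_core: T_batched, x_core: x_batched})

fn = pytensor.function([T_batched, x_batched], y_batched)
out = fn(np.ones((3, 2, 2)), np.ones((3, 2)))
print(out.shape)  # (3, 2)
```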

pymc_extras/statespace/filters/kalman_smoother.py

Lines changed: 38 additions & 1 deletion
```diff
@@ -2,14 +2,16 @@
 import pytensor.tensor as pt
 
 from pytensor.tensor.nlinalg import matrix_dot
-
+from pytensor.graph.replace import vectorize_graph
 from pymc_extras.statespace.filters.utilities import (
     quad_form_sym,
     split_vars_into_seq_and_nonseq,
     stabilize,
 )
 from pymc_extras.statespace.utils.constants import JITTER_DEFAULT
 
+SMOOTHER_CORE_NDIM = (2, 2, 2, 2, 3)
+
 
 class KalmanSmoother:
     """
@@ -61,11 +63,40 @@ def unpack_args(self, args):
 
         return a, P, a_smooth, P_smooth, T, R, Q
 
+    def has_batched_input(self, T, R, Q, filtered_states, filtered_covariances):
+        """
+        Check if any of the inputs are batched.
+        """
+        return any(
+            x.ndim > SMOOTHER_CORE_NDIM[i]
+            for i, x in enumerate([T, R, Q, filtered_states, filtered_covariances])
+        )
+
+    def get_dummy_core_inputs(self, T, R, Q, filtered_states, filtered_covariances):
+        """
+        Get dummy inputs for the core parameters.
+        """
+        out = []
+        for x, core_ndim in zip(
+            [T, R, Q, filtered_states, filtered_covariances], SMOOTHER_CORE_NDIM
+        ):
+            out.append(
+                pt.tensor(f"{x.name}_core_case", dtype=x.dtype, shape=x.type.shape[-core_ndim:])
+            )
+        return out
+
     def build_graph(
         self, T, R, Q, filtered_states, filtered_covariances, cov_jitter=JITTER_DEFAULT
     ):
         self.cov_jitter = cov_jitter
 
+        is_batched = self.has_batched_input(T, R, Q, filtered_states, filtered_covariances)
+        if is_batched:
+            batched_inputs = [T, R, Q, filtered_states, filtered_covariances]
+            T, R, Q, filtered_states, filtered_covariances = self.get_dummy_core_inputs(
+                *batched_inputs
+            )
+
         n, k = filtered_states.type.shape
 
         a_last = pt.specify_shape(filtered_states[-1], (k,))
@@ -94,6 +125,12 @@ def build_graph(
         smoothed_covariances = pt.concatenate(
             [smoothed_covariances[::-1], pt.expand_dims(P_last, axis=(0,))], axis=0
         )
+        smoothed_states.dprint()
+        if is_batched:
+            vec_subs = dict(zip([T, R, Q, filtered_states, filtered_covariances], batched_inputs))
+            smoothed_states, smoothed_covariances = vectorize_graph(
+                [smoothed_states, smoothed_covariances], vec_subs
+            )
 
         smoothed_states.name = "smoothed_states"
         smoothed_covariances.name = "smoothed_covariances"
```
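The smoother applies the same dummy-core-input trick. One detail worth noting is that the dummies are built from x.type.shape[-core_ndim:], so static core shape information survives and lines like n, k = filtered_states.type.shape still see the core dimensions. A small sketch of that behaviour, with made-up shapes:

```python
import pytensor.tensor as pt

# Same ordering as the smoother's inputs; the last entry (filtered covariances)
# has three core dimensions: (n, k, k)
SMOOTHER_CORE_NDIM = (2, 2, 2, 2, 3)

# Made-up batched inputs: a leading batch dimension of 4 in front of the core shapes
T = pt.tensor("T", shape=(4, 5, 5))                                             # (batch, k, k)
filtered_covariances = pt.tensor("filtered_covariances", shape=(4, 10, 5, 5))   # (batch, n, k, k)

for x, core_ndim in zip([T, filtered_covariances], (2, 3)):
    # Batch detection: more dimensions than the core case means the input is batched
    print(x.name, x.ndim > core_ndim)  # True, True

    # Dummy core input: drop the batch dims but keep the trailing static shape
    dummy = pt.tensor(f"{x.name}_core_case", dtype=x.dtype, shape=x.type.shape[-core_ndim:])
    print(dummy.type.shape)  # (5, 5) for T, (10, 5, 5) for filtered_covariances
```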

tests/statespace/filters/test_kalman_filter.py

Lines changed: 36 additions & 29 deletions
```diff
@@ -31,19 +31,22 @@
 RTOL = 1e-6 if floatX.endswith("64") else 1e-3
 
 standard_inout = initialize_filter(StandardFilter())
+standard_inout_batched = initialize_filter(StandardFilter(), batched=True)
 cholesky_inout = initialize_filter(SquareRootFilter())
 univariate_inout = initialize_filter(UnivariateFilter())
 
 f_standard = pytensor.function(*standard_inout, on_unused_input="ignore")
+f_standard_batched = pytensor.function(*standard_inout_batched, on_unused_input="ignore")
 f_cholesky = pytensor.function(*cholesky_inout, on_unused_input="ignore")
 f_univariate = pytensor.function(*univariate_inout, on_unused_input="ignore")
 
-filter_funcs = [f_standard, f_cholesky, f_univariate]
+filter_funcs = [f_standard, f_standard_batched]  # , f_cholesky, f_univariate]
 
 filter_names = [
     "StandardFilter",
-    "CholeskyFilter",
-    "UnivariateFilter",
+    "StandardFilterBatched",
+    # "CholeskyFilter",
+    # "UnivariateFilter",
 ]
 
 output_names = [
@@ -65,17 +68,21 @@ def test_base_class_update_raises():
         filter.update(*inputs)
 
 
-@pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names)
-def test_output_shapes_one_state_one_observed(filter_func, rng):
+@pytest.mark.parametrize(
+    "filter_func, filter_name", zip(filter_funcs, filter_names), ids=filter_names
+)
+def test_output_shapes_one_state_one_observed(filter_func, filter_name, rng):
+    batch_size = 3 if "batched" in filter_name.lower() else 0
     p, m, r, n = 1, 1, 1, 10
-    inputs = make_test_inputs(p, m, r, n, rng)
-    outputs = filter_func(*inputs)
+    inputs = make_test_inputs(p, m, r, n, rng, batch_size=batch_size)
+    assert 0
+    # outputs = filter_func(*inputs)
 
     for output_idx, name in enumerate(output_names):
-        expected_output = get_expected_shape(name, p, m, r, n)
-        assert (
-            outputs[output_idx].shape == expected_output
-        ), f"Shape of {name} does not match expected"
+        expected_shape = get_expected_shape(name, p, m, r, n, batch_size)
+        # assert outputs[output_idx].shape == expected_shape, (
+        #     f"Shape of {name} does not match expected"
+        # )
 
 
 @pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names)
@@ -86,9 +93,9 @@ def test_output_shapes_when_all_states_are_stochastic(filter_func, rng):
     outputs = filter_func(*inputs)
     for output_idx, name in enumerate(output_names):
         expected_output = get_expected_shape(name, p, m, r, n)
-        assert (
-            outputs[output_idx].shape == expected_output
-        ), f"Shape of {name} does not match expected"
+        assert outputs[output_idx].shape == expected_output, (
+            f"Shape of {name} does not match expected"
+        )
 
 
 @pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names)
@@ -99,9 +106,9 @@ def test_output_shapes_when_some_states_are_deterministic(filter_func, rng):
     outputs = filter_func(*inputs)
     for output_idx, name in enumerate(output_names):
         expected_output = get_expected_shape(name, p, m, r, n)
-        assert (
-            outputs[output_idx].shape == expected_output
-        ), f"Shape of {name} does not match expected"
+        assert outputs[output_idx].shape == expected_output, (
+            f"Shape of {name} does not match expected"
+        )
 
 
 @pytest.fixture
@@ -161,9 +168,9 @@ def test_output_shapes_with_time_varying_matrices(f_standard_nd, rng):
 
     for output_idx, name in enumerate(output_names):
         expected_output = get_expected_shape(name, p, m, r, n)
-        assert (
-            outputs[output_idx].shape == expected_output
-        ), f"Shape of {name} does not match expected"
+        assert outputs[output_idx].shape == expected_output, (
+            f"Shape of {name} does not match expected"
+        )
 
 
 @pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names)
@@ -175,9 +182,9 @@ def test_output_with_deterministic_observation_equation(filter_func, rng):
 
     for output_idx, name in enumerate(output_names):
         expected_output = get_expected_shape(name, p, m, r, n)
-        assert (
-            outputs[output_idx].shape == expected_output
-        ), f"Shape of {name} does not match expected"
+        assert outputs[output_idx].shape == expected_output, (
+            f"Shape of {name} does not match expected"
+        )
 
 
 @pytest.mark.parametrize(
@@ -190,9 +197,9 @@ def test_output_with_multiple_observed(filter_func, filter_name, rng):
     outputs = filter_func(*inputs)
     for output_idx, name in enumerate(output_names):
         expected_output = get_expected_shape(name, p, m, r, n)
-        assert (
-            outputs[output_idx].shape == expected_output
-        ), f"Shape of {name} does not match expected"
+        assert outputs[output_idx].shape == expected_output, (
+            f"Shape of {name} does not match expected"
+        )
 
 
 @pytest.mark.parametrize(
@@ -206,9 +213,9 @@ def test_missing_data(filter_func, filter_name, p, rng):
     outputs = filter_func(*inputs)
     for output_idx, name in enumerate(output_names):
         expected_output = get_expected_shape(name, p, m, r, n)
-        assert (
-            outputs[output_idx].shape == expected_output
-        ), f"Shape of {name} does not match expected"
+        assert outputs[output_idx].shape == expected_output, (
+            f"Shape of {name} does not match expected"
+        )
 
 
 @pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names)
```
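The updated test assumes batch-aware versions of the test-suite helpers make_test_inputs (with a batch_size keyword) and get_expected_shape (with an extra batch_size argument); those helpers are not part of this diff. A purely hypothetical sketch of how such a shape helper might prepend the batch dimension:

```python
# Hypothetical helper, for illustration only: the real get_expected_shape is
# defined in the test suite's utilities and may use different names and shapes.
def get_expected_shape(name, p, m, r, n, batch_size=0):
    # Placeholder core (unbatched) shapes for a run of length n with m states,
    # p observed series, and r shocks; the actual output names may differ.
    core_shapes = {
        "filtered_states": (n, m),
        "predicted_states": (n, m),
        "filtered_covariances": (n, m, m),
        "predicted_covariances": (n, m, m),
        "log_likelihood": (n,),
    }
    shape = core_shapes[name]
    # A batched run would simply prepend the batch dimension to every output.
    return (batch_size, *shape) if batch_size else shape


print(get_expected_shape("filtered_covariances", p=1, m=1, r=1, n=10, batch_size=3))
# (3, 10, 1, 1)
```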
