Skip to content

Commit 08995be

Browse files
Remove SingleTimeSeriesFilter
1 parent c3bc365 commit 08995be

File tree

4 files changed

+17
-80
lines changed

4 files changed

+17
-80
lines changed

pymc_experimental/statespace/core/statespace.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from pymc_experimental.statespace.core.representation import PytensorRepresentation
2020
from pymc_experimental.statespace.filters import (
2121
KalmanSmoother,
22-
SingleTimeseriesFilter,
2322
SquareRootFilter,
2423
StandardFilter,
2524
UnivariateFilter,
@@ -52,7 +51,6 @@
5251
FILTER_FACTORY = {
5352
"standard": StandardFilter,
5453
"univariate": UnivariateFilter,
55-
"single": SingleTimeseriesFilter,
5654
"cholesky": SquareRootFilter,
5755
}
5856

pymc_experimental/statespace/filters/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from pymc_experimental.statespace.filters.distributions import LinearGaussianStateSpace
22
from pymc_experimental.statespace.filters.kalman_filter import (
3-
SingleTimeseriesFilter,
43
SquareRootFilter,
54
StandardFilter,
65
UnivariateFilter,
@@ -11,7 +10,6 @@
1110
"StandardFilter",
1211
"UnivariateFilter",
1312
"KalmanSmoother",
14-
"SingleTimeseriesFilter",
1513
"SquareRootFilter",
1614
"LinearGaussianStateSpace",
1715
]

pymc_experimental/statespace/filters/kalman_filter.py

Lines changed: 2 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
MVN_CONST = pt.log(2 * pt.constant(np.pi, dtype="float64"))
2222
PARAM_NAMES = ["c", "d", "T", "Z", "R", "H", "Q"]
2323

24-
assert_data_is_1d = Assert("UnivariateTimeSeries filter requires data be at most 1-dimensional")
2524
assert_time_varying_dim_correct = Assert(
2625
"The first dimension of a time varying matrix (the time dimension) must be "
2726
"equal to the first dimension of the data (the time dimension)."
@@ -751,50 +750,12 @@ def square_sequnece(L):
751750
]
752751

753752

754-
class SingleTimeseriesFilter(BaseFilter):
755-
"""
756-
Kalman filter optimized for univariate timeseries
757-
758-
If there is only a single observed timeseries, regardless of the number of hidden states, there is no need to
759-
perform a matrix inversion anywhere in the filter.
760-
"""
761-
762-
# TODO: This class should eventually be made irrelevant by pytensor re-writes.
763-
def check_params(self, data, a0, P0, c, d, T, Z, R, H, Q):
764-
"""
765-
Wrap the data in an `Assert` `Op` to ensure there is only one observed state.
766-
"""
767-
data = assert_data_is_1d(data, pt.eq(data.shape[1], 1))
768-
769-
return data, a0, P0, c, d, T, Z, R, H, Q
770-
771-
def update(self, a, P, y, d, Z, H, all_nan_flag):
772-
y_hat = d + Z.dot(a)
773-
v = y - y_hat.ravel()
774-
775-
PZT = P.dot(Z.T)
776-
777-
# F is scalar, K is a column vector
778-
F = stabilize(Z.dot(PZT) + H, self.cov_jitter).ravel()
779-
780-
K = PZT / F
781-
I_KZ = pt.eye(self.n_states) - K.dot(Z)
782-
783-
a_filtered = a + (K * v).ravel()
784-
785-
P_filtered = quad_form_sym(I_KZ, P) + quad_form_sym(K, H)
786-
787-
ll = pt.switch(all_nan_flag, 0.0, -0.5 * (MVN_CONST + pt.log(F) + v**2 / F)).ravel()[0]
788-
789-
return a_filtered, P_filtered, pt.atleast_1d(y_hat), pt.atleast_2d(F), ll
790-
791-
792753
class UnivariateFilter(BaseFilter):
793754
"""
794755
The univariate kalman filter, described in [1], section 6.4.2, avoids inversion of the F matrix, as well as two
795756
matrix multiplications, at the cost of an additional loop. Note that the name doesn't mean there's only one
796-
observed time series, that's the SingleTimeSeries filter. This is called univariate because it updates the state
797-
mean and covariance matrices one variable at a time, using an inner-inner loop.
757+
observed time series. This is called univariate because it updates the state mean and covariance matrices one
758+
variable at a time, using an inner-inner loop.
798759
799760
This is useful when states are perfectly observed, because the F matrix can easily become degenerate in these cases.
800761

tests/statespace/test_kalman_filter.py

Lines changed: 15 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
from pymc_experimental.statespace.filters import (
99
KalmanSmoother,
10-
SingleTimeseriesFilter,
1110
SquareRootFilter,
1211
StandardFilter,
1312
UnivariateFilter,
@@ -34,20 +33,17 @@
3433
standard_inout = initialize_filter(StandardFilter())
3534
cholesky_inout = initialize_filter(SquareRootFilter())
3635
univariate_inout = initialize_filter(UnivariateFilter())
37-
single_inout = initialize_filter(SingleTimeseriesFilter())
3836

3937
f_standard = pytensor.function(*standard_inout, on_unused_input="ignore")
4038
f_cholesky = pytensor.function(*cholesky_inout, on_unused_input="ignore")
4139
f_univariate = pytensor.function(*univariate_inout, on_unused_input="ignore")
42-
f_single_ts = pytensor.function(*single_inout, on_unused_input="ignore")
4340

44-
filter_funcs = [f_standard, f_cholesky, f_univariate, f_single_ts]
41+
filter_funcs = [f_standard, f_cholesky, f_univariate]
4542

4643
filter_names = [
4744
"StandardFilter",
4845
"CholeskyFilter",
4946
"UnivariateFilter",
50-
"SingleTimeSeriesFilter",
5147
]
5248

5349
output_names = [
@@ -191,20 +187,12 @@ def test_output_with_multiple_observed(filter_func, filter_name, rng):
191187
p, m, r, n = 5, 5, 1, 10
192188
inputs = make_test_inputs(p, m, r, n, rng)
193189

194-
if filter_name == "SingleTimeSeriesFilter":
195-
with pytest.raises(
196-
AssertionError,
197-
match="UnivariateTimeSeries filter requires data be at most 1-dimensional",
198-
):
199-
filter_func(*inputs)
200-
201-
else:
202-
outputs = filter_func(*inputs)
203-
for output_idx, name in enumerate(output_names):
204-
expected_output = get_expected_shape(name, p, m, r, n)
205-
assert (
206-
outputs[output_idx].shape == expected_output
207-
), f"Shape of {name} does not match expected"
190+
outputs = filter_func(*inputs)
191+
for output_idx, name in enumerate(output_names):
192+
expected_output = get_expected_shape(name, p, m, r, n)
193+
assert (
194+
outputs[output_idx].shape == expected_output
195+
), f"Shape of {name} does not match expected"
208196

209197

210198
@pytest.mark.parametrize(
@@ -215,20 +203,12 @@ def test_missing_data(filter_func, filter_name, p, rng):
215203
m, r, n = 5, 1, 10
216204
inputs = make_test_inputs(p, m, r, n, rng, missing_data=1)
217205

218-
if p > 1 and filter_name == "SingleTimeSeriesFilter":
219-
with pytest.raises(
220-
AssertionError,
221-
match="UnivariateTimeSeries filter requires data be at most 1-dimensional",
222-
):
223-
filter_func(*inputs)
224-
225-
else:
226-
outputs = filter_func(*inputs)
227-
for output_idx, name in enumerate(output_names):
228-
expected_output = get_expected_shape(name, p, m, r, n)
229-
assert (
230-
outputs[output_idx].shape == expected_output
231-
), f"Shape of {name} does not match expected"
206+
outputs = filter_func(*inputs)
207+
for output_idx, name in enumerate(output_names):
208+
expected_output = get_expected_shape(name, p, m, r, n)
209+
assert (
210+
outputs[output_idx].shape == expected_output
211+
), f"Shape of {name} does not match expected"
232212

233213

234214
@pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names)
@@ -323,8 +303,8 @@ def test_all_covariance_matrices_are_PSD(filter_func, filter_name, n_missing, ob
323303

324304
@pytest.mark.parametrize(
325305
"filter",
326-
[StandardFilter, SingleTimeseriesFilter, SquareRootFilter],
327-
ids=["standard", "single_ts", "cholesky"],
306+
[StandardFilter, SquareRootFilter],
307+
ids=["standard", "cholesky"],
328308
)
329309
def test_kalman_filter_jax(filter):
330310
pytest.importorskip("jax")

0 commit comments

Comments
 (0)