Skip to content

Commit 6309245

Browse files
Merge pull request #290 from LAMPSPUC/sparse_filter
Add sparse kalman filter
2 parents 2ca5237 + cc6a454 commit 6309245

File tree

5 files changed

+337
-4
lines changed

5 files changed

+337
-4
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "StateSpaceModels"
22
uuid = "99342f36-827c-5390-97c9-d7f9ee765c78"
33
authors = ["raphaelsaavedra <[email protected]>, guilhermebodin <[email protected]>, mariohsouto"]
4-
version = "0.5.18"
4+
version = "0.5.19"
55

66
[deps]
77
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
@@ -14,6 +14,7 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1414
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
1515
SeasonalTrendLoess = "42fb36cb-998a-4034-bf40-4eee476c43a1"
1616
ShiftedArrays = "1277b4bf-5013-50f5-be3d-901d8477a67a"
17+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1718
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1819
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1920

docs/src/manual.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,9 @@ StateSpaceModels.jl lets users define tailor-made filters in an easy manner. TOD
128128
```@docs
129129
UnivariateKalmanFilter
130130
ScalarKalmanFilter
131-
StateSpaceModels.FilterOutput
131+
SparseUnivariateKalmanFilter
132+
FilterOutput
133+
SmootherOutput
132134
get_innovations
133135
get_innovations_variance
134136
get_filtered_state

src/StateSpaceModels.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@ using Printf
1212
using Optim
1313
using OrderedCollections
1414
using RecipesBase
15-
using StatsBase
1615
using SeasonalTrendLoess
16+
using SparseArrays
17+
using StatsBase
1718

1819
abstract type StateSpaceModel end
1920

@@ -28,6 +29,7 @@ include("filters/univariate_kalman_filter.jl")
2829
include("filters/multivariate_kalman_filter.jl")
2930
include("filters/scalar_kalman_filter.jl")
3031
include("filters/regression_kalman_filter.jl")
32+
include("filters/sparse_univariate_kalman_filter.jl")
3133

3234
include("smoothers/kalman_smoother.jl")
3335

@@ -63,6 +65,7 @@ export ExperimentalSeasonalNaive
6365
export BasicStructuralExplanatory
6466
export DAR
6567
export ExponentialSmoothing
68+
export FilterOutput
6669
export LinearMultivariateTimeInvariant
6770
export LinearMultivariateTimeVariant
6871
export LinearRegression
@@ -78,6 +81,8 @@ export Optimizer
7881
export SARIMA
7982
export ScalarKalmanFilter
8083
export SeasonalNaive
84+
export SmootherOutput
85+
export SparseUnivariateKalmanFilter
8186
export StateSpaceModel
8287
export UnivariateKalmanFilter
8388
export UnobservedComponents
@@ -110,6 +115,7 @@ export isunivariate
110115
export kalman_filter
111116
export kalman_smoother
112117
export loglike
118+
export num_states
113119
export number_hyperparameters
114120
export results
115121
export set_initial_hyperparameters!
Lines changed: 307 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,307 @@
1+
mutable struct SparseUnivariateKalmanState{Fl<:AbstractFloat}
2+
v::Fl
3+
F::Fl
4+
att::Vector{Fl}
5+
a::Vector{Fl}
6+
Ptt::Matrix{Fl}
7+
P::Matrix{Fl}
8+
llk::Fl
9+
steady_state::Bool
10+
# Auxiliary matrices
11+
P_to_check_steady_state::Matrix{Fl}
12+
ZP::Vector{Fl}
13+
TPtt::Matrix{Fl}
14+
function SparseUnivariateKalmanState(a1::Vector{Fl}, P1::Matrix{Fl}) where Fl
15+
m = length(a1)
16+
P_to_check_steady_state = zeros(Fl, m, m)
17+
ZP = zeros(Fl, m)
18+
TPtt = zeros(Fl, m, m)
19+
20+
return new{Fl}(
21+
zero(Fl),
22+
zero(Fl),
23+
zeros(Fl, m),
24+
a1,
25+
zeros(Fl, m, m),
26+
P1,
27+
zero(Fl),
28+
false,
29+
P_to_check_steady_state,
30+
ZP,
31+
TPtt,
32+
)
33+
end
34+
end
35+
36+
function save_a1_P1_in_filter_output!(
37+
filter_output::FilterOutput{Fl}, kalman_state::SparseUnivariateKalmanState{Fl}
38+
) where Fl
39+
filter_output.a[1] = deepcopy(kalman_state.a)
40+
filter_output.P[1] = deepcopy(kalman_state.P)
41+
return filter_output
42+
end
43+
44+
function save_kalman_state_in_filter_output!(
45+
filter_output::FilterOutput{Fl}, kalman_state::SparseUnivariateKalmanState{Fl}, t::Int
46+
) where Fl
47+
filter_output.v[t] = copy(fill(kalman_state.v, 1))
48+
filter_output.F[t] = copy(fill(kalman_state.F, 1, 1))
49+
filter_output.a[t + 1] = copy(kalman_state.a)
50+
filter_output.att[t] = copy(kalman_state.att)
51+
filter_output.P[t + 1] = copy(kalman_state.P)
52+
filter_output.Ptt[t] = copy(kalman_state.Ptt)
53+
# There is no diffuse part in SparseUnivariateKalmanState
54+
filter_output.Pinf[t] = copy(fill(zero(Fl), size(kalman_state.Ptt)))
55+
return filter_output
56+
end
57+
58+
# Univariate Kalman filter with the recursions as described
59+
# in Koopman's book TODO
60+
"""
61+
SparseUnivariateKalmanFilter{Fl <: AbstractFloat}
62+
63+
A Kalman filter that is tailored to sparse univariate systems, exploiting the fact that the
64+
dimension of the observations at any time period is 1 and that Z, T and R are sparse.
65+
66+
# TODO equations and descriptions of a1 and P1
67+
"""
68+
mutable struct SparseUnivariateKalmanFilter{Fl<:AbstractFloat} <: KalmanFilter
69+
steadystate_tol::Fl
70+
a1::Vector{Fl}
71+
P1::Matrix{Fl}
72+
skip_llk_instants::Int
73+
kalman_state::SparseUnivariateKalmanState
74+
75+
function SparseUnivariateKalmanFilter(
76+
a1::Vector{Fl},
77+
P1::Matrix{Fl},
78+
skip_llk_instants::Int=length(a1),
79+
steadystate_tol::Fl=Fl(1e-5),
80+
) where Fl
81+
kalman_state = SparseUnivariateKalmanState(copy(a1), copy(P1))
82+
return new{Fl}(steadystate_tol, a1, P1, skip_llk_instants, kalman_state)
83+
end
84+
end
85+
86+
function reset_filter!(kf::SparseUnivariateKalmanFilter{Fl}) where Fl
87+
copyto!(kf.kalman_state.a, kf.a1)
88+
copyto!(kf.kalman_state.P, kf.P1)
89+
set_state_llk_to_zero!(kf.kalman_state)
90+
fill!(kf.kalman_state.P_to_check_steady_state, zero(Fl))
91+
fill!(kf.kalman_state.ZP, zero(Fl))
92+
fill!(kf.kalman_state.TPtt, zero(Fl))
93+
kf.kalman_state.steady_state = false
94+
return kf
95+
end
96+
97+
function set_state_llk_to_zero!(kalman_state::SparseUnivariateKalmanState{Fl}) where Fl
98+
kalman_state.llk = zero(Fl)
99+
return kalman_state
100+
end
101+
102+
# TODO this should either have ! or return something
103+
function check_steady_state(kalman_state::SparseUnivariateKalmanState{Fl}, tol::Fl) where Fl
104+
@inbounds for j in axes(kalman_state.P, 2), i in axes(kalman_state.P, 1)
105+
if abs(
106+
(kalman_state.P[i, j] - kalman_state.P_to_check_steady_state[i, j]) /
107+
kalman_state.P[i, j],
108+
) > tol
109+
# Update the P_to_check_steady_state matrix
110+
copyto!(kalman_state.P_to_check_steady_state, kalman_state.P)
111+
return nothing
112+
end
113+
end
114+
kalman_state.steady_state = true
115+
return nothing
116+
end
117+
118+
function update_v!(
119+
kalman_state::SparseUnivariateKalmanState{Fl}, y::Fl, Z::SparseVector{Fl, Int}, d::Fl
120+
) where Fl
121+
kalman_state.v = y - dot(Z, kalman_state.a) - d
122+
return kalman_state
123+
end
124+
125+
function update_F!(kalman_state::SparseUnivariateKalmanState{Fl}, Z::SparseVector{Fl, Int}, H::Fl) where Fl
126+
kalman_state.ZP = kalman_state.P * Z
127+
kalman_state.F = H
128+
kalman_state.F += dot(kalman_state.ZP, Z)
129+
return kalman_state
130+
end
131+
132+
function update_att!(kalman_state::SparseUnivariateKalmanState{Fl}, Z::SparseVector{Fl, Int}) where Fl
133+
copyto!(kalman_state.att, kalman_state.a)
134+
# Here we can simplify the j iterator # TODO
135+
@inbounds for i in axes(kalman_state.P, 1), j in axes(kalman_state.P, 2)
136+
kalman_state.att[i] +=
137+
(kalman_state.v / kalman_state.F) * kalman_state.P[i, j] * Z[j]
138+
end
139+
return kalman_state
140+
end
141+
142+
function repeat_a_in_att!(kalman_state::SparseUnivariateKalmanState{Fl}) where Fl
143+
for i in eachindex(kalman_state.att)
144+
kalman_state.att[i] = kalman_state.a[i]
145+
end
146+
return kalman_state
147+
end
148+
149+
function update_a!(
150+
kalman_state::SparseUnivariateKalmanState{Fl}, T::SparseMatrixCSC{Fl, Int}, c::Vector{Fl}
151+
) where Fl
152+
kalman_state.a = T * kalman_state.att
153+
kalman_state.a .+= c
154+
return kalman_state
155+
end
156+
157+
function update_Ptt!(kalman_state::SparseUnivariateKalmanState{Fl}) where Fl
158+
LinearAlgebra.BLAS.gemm!(
159+
'N', 'T', one(Fl), kalman_state.ZP, kalman_state.ZP, zero(Fl), kalman_state.Ptt
160+
)
161+
@. kalman_state.Ptt = kalman_state.P - (kalman_state.Ptt / kalman_state.F)
162+
return kalman_state
163+
end
164+
165+
function repeat_P_in_Ptt!(kalman_state::SparseUnivariateKalmanState{Fl}) where Fl
166+
for i in axes(kalman_state.P, 1), j in axes(kalman_state.P, 2)
167+
kalman_state.Ptt[i, j] = kalman_state.P[i, j]
168+
end
169+
return kalman_state
170+
end
171+
172+
function update_P!(
173+
kalman_state::SparseUnivariateKalmanState{Fl}, T::SparseMatrixCSC{Fl, Int}, RQR::Matrix{Fl}
174+
) where Fl
175+
mul!(kalman_state.TPtt, T, kalman_state.Ptt)
176+
mul!(kalman_state.P, kalman_state.TPtt, T')
177+
kalman_state.P .+= RQR
178+
return kalman_state
179+
end
180+
181+
function update_llk!(kalman_state::SparseUnivariateKalmanState{Fl}) where Fl
182+
kalman_state.llk -= (
183+
HALF_LOG_2_PI + (log(kalman_state.F) + kalman_state.v^2 / kalman_state.F) / 2
184+
)
185+
return kalman_state
186+
end
187+
188+
function update_kalman_state!(
189+
kalman_state::SparseUnivariateKalmanState{Fl},
190+
y::Fl,
191+
Z::SparseVector{Fl, Int},
192+
T::SparseMatrixCSC{Fl, Int},
193+
H::Fl,
194+
RQR::Matrix{Fl},
195+
d::Fl,
196+
c::Vector{Fl},
197+
skip_llk_instants::Int,
198+
tol::Fl,
199+
t::Int,
200+
) where Fl
201+
if isnan(y)
202+
kalman_state.v = NaN
203+
update_F!(kalman_state, Z, H)
204+
repeat_a_in_att!(kalman_state)
205+
update_a!(kalman_state, T, c)
206+
repeat_P_in_Ptt!(kalman_state)
207+
update_P!(kalman_state, T, RQR)
208+
kalman_state.steady_state = false # Not on steadystate anymore
209+
elseif kalman_state.steady_state
210+
update_v!(kalman_state, y, Z, d)
211+
update_att!(kalman_state, Z)
212+
update_a!(kalman_state, T, c)
213+
if t > skip_llk_instants
214+
update_llk!(kalman_state)
215+
end
216+
else
217+
update_v!(kalman_state, y, Z, d)
218+
update_F!(kalman_state, Z, H)
219+
update_att!(kalman_state, Z)
220+
update_a!(kalman_state, T, c)
221+
update_Ptt!(kalman_state)
222+
update_P!(kalman_state, T, RQR)
223+
check_steady_state(kalman_state, tol)
224+
if t > skip_llk_instants
225+
update_llk!(kalman_state)
226+
end
227+
end
228+
return kalman_state
229+
end
230+
231+
function optim_kalman_filter(
232+
sys::StateSpaceSystem, filter::SparseUnivariateKalmanFilter{Fl}
233+
) where Fl
234+
return filter_recursions!(
235+
filter.kalman_state, sys, filter.steadystate_tol, filter.skip_llk_instants
236+
)
237+
end
238+
239+
function kalman_filter!(
240+
filter_output::FilterOutput, sys::StateSpaceSystem, filter::SparseUnivariateKalmanFilter{Fl}
241+
) where Fl
242+
filter_recursions!(
243+
filter_output,
244+
filter.kalman_state,
245+
sys,
246+
filter.steadystate_tol,
247+
filter.skip_llk_instants,
248+
)
249+
return filter_output
250+
end
251+
252+
function filter_recursions!(
253+
kalman_state::SparseUnivariateKalmanState{Fl},
254+
sys::LinearUnivariateTimeInvariant,
255+
steadystate_tol::Fl,
256+
skip_llk_instants::Int,
257+
) where Fl
258+
RQR = sys.R * sys.Q * sys.R'
259+
T_sparse = sparse(sys.T)
260+
Z_sparse = sparse(sys.Z)
261+
@inbounds for t in eachindex(sys.y)
262+
update_kalman_state!(
263+
kalman_state,
264+
sys.y[t],
265+
Z_sparse,
266+
T_sparse,
267+
sys.H,
268+
RQR,
269+
sys.d,
270+
sys.c,
271+
skip_llk_instants,
272+
steadystate_tol,
273+
t,
274+
)
275+
end
276+
return kalman_state.llk
277+
end
278+
279+
function filter_recursions!(
280+
filter_output::FilterOutput,
281+
kalman_state::SparseUnivariateKalmanState{Fl},
282+
sys::LinearUnivariateTimeInvariant,
283+
steadystate_tol::Fl,
284+
skip_llk_instants::Int,
285+
) where Fl
286+
RQR = sys.R * sys.Q * sys.R'
287+
T_sparse = sparse(sys.T)
288+
Z_sparse = sparse(sys.Z)
289+
save_a1_P1_in_filter_output!(filter_output, kalman_state)
290+
@inbounds for t in eachindex(sys.y)
291+
update_kalman_state!(
292+
kalman_state,
293+
sys.y[t],
294+
Z_sparse,
295+
T_sparse,
296+
sys.H,
297+
RQR,
298+
sys.d,
299+
sys.c,
300+
skip_llk_instants,
301+
steadystate_tol,
302+
t,
303+
)
304+
save_kalman_state_in_filter_output!(filter_output, kalman_state, t)
305+
end
306+
return filter_output
307+
end

0 commit comments

Comments
 (0)