|
| 1 | +mutable struct SparseUnivariateKalmanState{Fl<:AbstractFloat} |
| 2 | + v::Fl |
| 3 | + F::Fl |
| 4 | + att::Vector{Fl} |
| 5 | + a::Vector{Fl} |
| 6 | + Ptt::Matrix{Fl} |
| 7 | + P::Matrix{Fl} |
| 8 | + llk::Fl |
| 9 | + steady_state::Bool |
| 10 | + # Auxiliary matrices |
| 11 | + P_to_check_steady_state::Matrix{Fl} |
| 12 | + ZP::Vector{Fl} |
| 13 | + TPtt::Matrix{Fl} |
| 14 | + function SparseUnivariateKalmanState(a1::Vector{Fl}, P1::Matrix{Fl}) where Fl |
| 15 | + m = length(a1) |
| 16 | + P_to_check_steady_state = zeros(Fl, m, m) |
| 17 | + ZP = zeros(Fl, m) |
| 18 | + TPtt = zeros(Fl, m, m) |
| 19 | + |
| 20 | + return new{Fl}( |
| 21 | + zero(Fl), |
| 22 | + zero(Fl), |
| 23 | + zeros(Fl, m), |
| 24 | + a1, |
| 25 | + zeros(Fl, m, m), |
| 26 | + P1, |
| 27 | + zero(Fl), |
| 28 | + false, |
| 29 | + P_to_check_steady_state, |
| 30 | + ZP, |
| 31 | + TPtt, |
| 32 | + ) |
| 33 | + end |
| 34 | +end |
| 35 | + |
| 36 | +function save_a1_P1_in_filter_output!( |
| 37 | + filter_output::FilterOutput{Fl}, kalman_state::SparseUnivariateKalmanState{Fl} |
| 38 | +) where Fl |
| 39 | + filter_output.a[1] = deepcopy(kalman_state.a) |
| 40 | + filter_output.P[1] = deepcopy(kalman_state.P) |
| 41 | + return filter_output |
| 42 | +end |
| 43 | + |
| 44 | +function save_kalman_state_in_filter_output!( |
| 45 | + filter_output::FilterOutput{Fl}, kalman_state::SparseUnivariateKalmanState{Fl}, t::Int |
| 46 | +) where Fl |
| 47 | + filter_output.v[t] = copy(fill(kalman_state.v, 1)) |
| 48 | + filter_output.F[t] = copy(fill(kalman_state.F, 1, 1)) |
| 49 | + filter_output.a[t + 1] = copy(kalman_state.a) |
| 50 | + filter_output.att[t] = copy(kalman_state.att) |
| 51 | + filter_output.P[t + 1] = copy(kalman_state.P) |
| 52 | + filter_output.Ptt[t] = copy(kalman_state.Ptt) |
| 53 | + # There is no diffuse part in SparseUnivariateKalmanState |
| 54 | + filter_output.Pinf[t] = copy(fill(zero(Fl), size(kalman_state.Ptt))) |
| 55 | + return filter_output |
| 56 | +end |
| 57 | + |
| 58 | +# Univariate Kalman filter with the recursions as described |
| 59 | +# in Koopman's book TODO |
| 60 | +""" |
| 61 | + SparseUnivariateKalmanFilter{Fl <: AbstractFloat} |
| 62 | +
|
| 63 | +A Kalman filter that is tailored to sparse univariate systems, exploiting the fact that the |
| 64 | +dimension of the observations at any time period is 1 and that Z, T and R are sparse. |
| 65 | +
|
| 66 | +# TODO equations and descriptions of a1 and P1 |
| 67 | +""" |
| 68 | +mutable struct SparseUnivariateKalmanFilter{Fl<:AbstractFloat} <: KalmanFilter |
| 69 | + steadystate_tol::Fl |
| 70 | + a1::Vector{Fl} |
| 71 | + P1::Matrix{Fl} |
| 72 | + skip_llk_instants::Int |
| 73 | + kalman_state::SparseUnivariateKalmanState |
| 74 | + |
| 75 | + function SparseUnivariateKalmanFilter( |
| 76 | + a1::Vector{Fl}, |
| 77 | + P1::Matrix{Fl}, |
| 78 | + skip_llk_instants::Int=length(a1), |
| 79 | + steadystate_tol::Fl=Fl(1e-5), |
| 80 | + ) where Fl |
| 81 | + kalman_state = SparseUnivariateKalmanState(copy(a1), copy(P1)) |
| 82 | + return new{Fl}(steadystate_tol, a1, P1, skip_llk_instants, kalman_state) |
| 83 | + end |
| 84 | +end |
| 85 | + |
| 86 | +function reset_filter!(kf::SparseUnivariateKalmanFilter{Fl}) where Fl |
| 87 | + copyto!(kf.kalman_state.a, kf.a1) |
| 88 | + copyto!(kf.kalman_state.P, kf.P1) |
| 89 | + set_state_llk_to_zero!(kf.kalman_state) |
| 90 | + fill!(kf.kalman_state.P_to_check_steady_state, zero(Fl)) |
| 91 | + fill!(kf.kalman_state.ZP, zero(Fl)) |
| 92 | + fill!(kf.kalman_state.TPtt, zero(Fl)) |
| 93 | + kf.kalman_state.steady_state = false |
| 94 | + return kf |
| 95 | +end |
| 96 | + |
| 97 | +function set_state_llk_to_zero!(kalman_state::SparseUnivariateKalmanState{Fl}) where Fl |
| 98 | + kalman_state.llk = zero(Fl) |
| 99 | + return kalman_state |
| 100 | +end |
| 101 | + |
| 102 | +# TODO this should either have ! or return something |
| 103 | +function check_steady_state(kalman_state::SparseUnivariateKalmanState{Fl}, tol::Fl) where Fl |
| 104 | + @inbounds for j in axes(kalman_state.P, 2), i in axes(kalman_state.P, 1) |
| 105 | + if abs( |
| 106 | + (kalman_state.P[i, j] - kalman_state.P_to_check_steady_state[i, j]) / |
| 107 | + kalman_state.P[i, j], |
| 108 | + ) > tol |
| 109 | + # Update the P_to_check_steady_state matrix |
| 110 | + copyto!(kalman_state.P_to_check_steady_state, kalman_state.P) |
| 111 | + return nothing |
| 112 | + end |
| 113 | + end |
| 114 | + kalman_state.steady_state = true |
| 115 | + return nothing |
| 116 | +end |
| 117 | + |
| 118 | +function update_v!( |
| 119 | + kalman_state::SparseUnivariateKalmanState{Fl}, y::Fl, Z::SparseVector{Fl, Int}, d::Fl |
| 120 | +) where Fl |
| 121 | + kalman_state.v = y - dot(Z, kalman_state.a) - d |
| 122 | + return kalman_state |
| 123 | +end |
| 124 | + |
| 125 | +function update_F!(kalman_state::SparseUnivariateKalmanState{Fl}, Z::SparseVector{Fl, Int}, H::Fl) where Fl |
| 126 | + kalman_state.ZP = kalman_state.P * Z |
| 127 | + kalman_state.F = H |
| 128 | + kalman_state.F += dot(kalman_state.ZP, Z) |
| 129 | + return kalman_state |
| 130 | +end |
| 131 | + |
| 132 | +function update_att!(kalman_state::SparseUnivariateKalmanState{Fl}, Z::SparseVector{Fl, Int}) where Fl |
| 133 | + copyto!(kalman_state.att, kalman_state.a) |
| 134 | + # Here we can simplify the j iterator # TODO |
| 135 | + @inbounds for i in axes(kalman_state.P, 1), j in axes(kalman_state.P, 2) |
| 136 | + kalman_state.att[i] += |
| 137 | + (kalman_state.v / kalman_state.F) * kalman_state.P[i, j] * Z[j] |
| 138 | + end |
| 139 | + return kalman_state |
| 140 | +end |
| 141 | + |
| 142 | +function repeat_a_in_att!(kalman_state::SparseUnivariateKalmanState{Fl}) where Fl |
| 143 | + for i in eachindex(kalman_state.att) |
| 144 | + kalman_state.att[i] = kalman_state.a[i] |
| 145 | + end |
| 146 | + return kalman_state |
| 147 | +end |
| 148 | + |
| 149 | +function update_a!( |
| 150 | + kalman_state::SparseUnivariateKalmanState{Fl}, T::SparseMatrixCSC{Fl, Int}, c::Vector{Fl} |
| 151 | +) where Fl |
| 152 | + kalman_state.a = T * kalman_state.att |
| 153 | + kalman_state.a .+= c |
| 154 | + return kalman_state |
| 155 | +end |
| 156 | + |
| 157 | +function update_Ptt!(kalman_state::SparseUnivariateKalmanState{Fl}) where Fl |
| 158 | + LinearAlgebra.BLAS.gemm!( |
| 159 | + 'N', 'T', one(Fl), kalman_state.ZP, kalman_state.ZP, zero(Fl), kalman_state.Ptt |
| 160 | + ) |
| 161 | + @. kalman_state.Ptt = kalman_state.P - (kalman_state.Ptt / kalman_state.F) |
| 162 | + return kalman_state |
| 163 | +end |
| 164 | + |
| 165 | +function repeat_P_in_Ptt!(kalman_state::SparseUnivariateKalmanState{Fl}) where Fl |
| 166 | + for i in axes(kalman_state.P, 1), j in axes(kalman_state.P, 2) |
| 167 | + kalman_state.Ptt[i, j] = kalman_state.P[i, j] |
| 168 | + end |
| 169 | + return kalman_state |
| 170 | +end |
| 171 | + |
| 172 | +function update_P!( |
| 173 | + kalman_state::SparseUnivariateKalmanState{Fl}, T::SparseMatrixCSC{Fl, Int}, RQR::Matrix{Fl} |
| 174 | +) where Fl |
| 175 | + mul!(kalman_state.TPtt, T, kalman_state.Ptt) |
| 176 | + mul!(kalman_state.P, kalman_state.TPtt, T') |
| 177 | + kalman_state.P .+= RQR |
| 178 | + return kalman_state |
| 179 | +end |
| 180 | + |
| 181 | +function update_llk!(kalman_state::SparseUnivariateKalmanState{Fl}) where Fl |
| 182 | + kalman_state.llk -= ( |
| 183 | + HALF_LOG_2_PI + (log(kalman_state.F) + kalman_state.v^2 / kalman_state.F) / 2 |
| 184 | + ) |
| 185 | + return kalman_state |
| 186 | +end |
| 187 | + |
| 188 | +function update_kalman_state!( |
| 189 | + kalman_state::SparseUnivariateKalmanState{Fl}, |
| 190 | + y::Fl, |
| 191 | + Z::SparseVector{Fl, Int}, |
| 192 | + T::SparseMatrixCSC{Fl, Int}, |
| 193 | + H::Fl, |
| 194 | + RQR::Matrix{Fl}, |
| 195 | + d::Fl, |
| 196 | + c::Vector{Fl}, |
| 197 | + skip_llk_instants::Int, |
| 198 | + tol::Fl, |
| 199 | + t::Int, |
| 200 | +) where Fl |
| 201 | + if isnan(y) |
| 202 | + kalman_state.v = NaN |
| 203 | + update_F!(kalman_state, Z, H) |
| 204 | + repeat_a_in_att!(kalman_state) |
| 205 | + update_a!(kalman_state, T, c) |
| 206 | + repeat_P_in_Ptt!(kalman_state) |
| 207 | + update_P!(kalman_state, T, RQR) |
| 208 | + kalman_state.steady_state = false # Not on steadystate anymore |
| 209 | + elseif kalman_state.steady_state |
| 210 | + update_v!(kalman_state, y, Z, d) |
| 211 | + update_att!(kalman_state, Z) |
| 212 | + update_a!(kalman_state, T, c) |
| 213 | + if t > skip_llk_instants |
| 214 | + update_llk!(kalman_state) |
| 215 | + end |
| 216 | + else |
| 217 | + update_v!(kalman_state, y, Z, d) |
| 218 | + update_F!(kalman_state, Z, H) |
| 219 | + update_att!(kalman_state, Z) |
| 220 | + update_a!(kalman_state, T, c) |
| 221 | + update_Ptt!(kalman_state) |
| 222 | + update_P!(kalman_state, T, RQR) |
| 223 | + check_steady_state(kalman_state, tol) |
| 224 | + if t > skip_llk_instants |
| 225 | + update_llk!(kalman_state) |
| 226 | + end |
| 227 | + end |
| 228 | + return kalman_state |
| 229 | +end |
| 230 | + |
| 231 | +function optim_kalman_filter( |
| 232 | + sys::StateSpaceSystem, filter::SparseUnivariateKalmanFilter{Fl} |
| 233 | +) where Fl |
| 234 | + return filter_recursions!( |
| 235 | + filter.kalman_state, sys, filter.steadystate_tol, filter.skip_llk_instants |
| 236 | + ) |
| 237 | +end |
| 238 | + |
| 239 | +function kalman_filter!( |
| 240 | + filter_output::FilterOutput, sys::StateSpaceSystem, filter::SparseUnivariateKalmanFilter{Fl} |
| 241 | +) where Fl |
| 242 | + filter_recursions!( |
| 243 | + filter_output, |
| 244 | + filter.kalman_state, |
| 245 | + sys, |
| 246 | + filter.steadystate_tol, |
| 247 | + filter.skip_llk_instants, |
| 248 | + ) |
| 249 | + return filter_output |
| 250 | +end |
| 251 | + |
| 252 | +function filter_recursions!( |
| 253 | + kalman_state::SparseUnivariateKalmanState{Fl}, |
| 254 | + sys::LinearUnivariateTimeInvariant, |
| 255 | + steadystate_tol::Fl, |
| 256 | + skip_llk_instants::Int, |
| 257 | +) where Fl |
| 258 | + RQR = sys.R * sys.Q * sys.R' |
| 259 | + T_sparse = sparse(sys.T) |
| 260 | + Z_sparse = sparse(sys.Z) |
| 261 | + @inbounds for t in eachindex(sys.y) |
| 262 | + update_kalman_state!( |
| 263 | + kalman_state, |
| 264 | + sys.y[t], |
| 265 | + Z_sparse, |
| 266 | + T_sparse, |
| 267 | + sys.H, |
| 268 | + RQR, |
| 269 | + sys.d, |
| 270 | + sys.c, |
| 271 | + skip_llk_instants, |
| 272 | + steadystate_tol, |
| 273 | + t, |
| 274 | + ) |
| 275 | + end |
| 276 | + return kalman_state.llk |
| 277 | +end |
| 278 | + |
| 279 | +function filter_recursions!( |
| 280 | + filter_output::FilterOutput, |
| 281 | + kalman_state::SparseUnivariateKalmanState{Fl}, |
| 282 | + sys::LinearUnivariateTimeInvariant, |
| 283 | + steadystate_tol::Fl, |
| 284 | + skip_llk_instants::Int, |
| 285 | +) where Fl |
| 286 | + RQR = sys.R * sys.Q * sys.R' |
| 287 | + T_sparse = sparse(sys.T) |
| 288 | + Z_sparse = sparse(sys.Z) |
| 289 | + save_a1_P1_in_filter_output!(filter_output, kalman_state) |
| 290 | + @inbounds for t in eachindex(sys.y) |
| 291 | + update_kalman_state!( |
| 292 | + kalman_state, |
| 293 | + sys.y[t], |
| 294 | + Z_sparse, |
| 295 | + T_sparse, |
| 296 | + sys.H, |
| 297 | + RQR, |
| 298 | + sys.d, |
| 299 | + sys.c, |
| 300 | + skip_llk_instants, |
| 301 | + steadystate_tol, |
| 302 | + t, |
| 303 | + ) |
| 304 | + save_kalman_state_in_filter_output!(filter_output, kalman_state, t) |
| 305 | + end |
| 306 | + return filter_output |
| 307 | +end |
0 commit comments