@@ -63,6 +63,27 @@ function initialize_forward(
6363 return ForwardStorage (α, logL, B, c)
6464end
6565
66+ function _forward_digest_observation! (
67+ current_state_marginals:: AbstractVector{<:Real} ,
68+ current_obs_likelihoods:: AbstractVector{<:Real} ,
69+ hmm:: AbstractHMM ,
70+ obs,
71+ control,
72+ )
73+ a, b = current_state_marginals, current_obs_likelihoods
74+
75+ obs_logdensities! (b, hmm, obs, control)
76+ logm = maximum (b)
77+ b .= exp .(b .- logm)
78+
79+ a .*= b
80+ c = inv (sum (a))
81+ lmul! (c, a)
82+
83+ logL = - log (c) + logm
84+ return c, logL
85+ end
86+
6687function _forward! (
6788 storage:: ForwardOrForwardBackwardStorage ,
6889 hmm:: AbstractHMM ,
@@ -73,36 +94,19 @@ function _forward!(
7394)
7495 (; α, B, c, logL) = storage
7596 t1, t2 = seq_limits (seq_ends, k)
76-
77- # Initialization
78- Bₜ₁ = view (B, :, t1)
79- obs_logdensities! (Bₜ₁, hmm, obs_seq[t1], control_seq[t1])
80- logm = maximum (Bₜ₁)
81- Bₜ₁ .= exp .(Bₜ₁ .- logm)
82-
83- init = initialization (hmm)
84- αₜ₁ = view (α, :, t1)
85- αₜ₁ .= init .* Bₜ₁
86- c[t1] = inv (sum (αₜ₁))
87- lmul! (c[t1], αₜ₁)
88-
89- logL[k] = - log (c[t1]) + logm
90-
91- # Loop
92- for t in t1: (t2 - 1 )
93- Bₜ₊₁ = view (B, :, t + 1 )
94- obs_logdensities! (Bₜ₊₁, hmm, obs_seq[t + 1 ], control_seq[t + 1 ])
95- logm = maximum (Bₜ₊₁)
96- Bₜ₊₁ .= exp .(Bₜ₊₁ .- logm)
97-
98- trans = transition_matrix (hmm, control_seq[t])
99- αₜ, αₜ₊₁ = view (α, :, t), view (α, :, t + 1 )
100- mul! (αₜ₊₁, transpose (trans), αₜ)
101- αₜ₊₁ .*= Bₜ₊₁
102- c[t + 1 ] = inv (sum (αₜ₊₁))
103- lmul! (c[t + 1 ], αₜ₊₁)
104-
105- logL[k] += - log (c[t + 1 ]) + logm
97+ logL[k] = zero (eltype (logL))
98+ for t in t1: t2
99+ αₜ = view (α, :, t)
100+ Bₜ = view (B, :, t)
101+ if t == t1
102+ copyto! (αₜ, initialization (hmm))
103+ else
104+ αₜ₋₁ = view (α, :, t - 1 )
105+ predict_next_state! (αₜ, hmm, αₜ₋₁, control_seq[t - 1 ])
106+ end
107+ cₜ, logLₜ = _forward_digest_observation! (αₜ, Bₜ, hmm, obs_seq[t], control_seq[t])
108+ c[t] = cₜ
109+ logL[k] += logLₜ
106110 end
107111
108112 @argcheck isfinite (logL[k])
0 commit comments