@@ -16,7 +16,35 @@ struct ForwardStorage{R}
1616 c:: Vector{R}
1717end
1818
19+ """
20+ $(TYPEDEF)
21+
22+ # Fields
23+
24+ Only the fields with a description are part of the public API.
25+
26+ $(TYPEDFIELDS)
27+ """
28+ struct ForwardBackwardStorage{R,M<: AbstractMatrix{R} }
29+ " posterior state marginals `γ[i,t] = ℙ(X[t]=i | Y[1:T])`"
30+ γ:: Matrix{R}
31+ " posterior transition marginals `ξ[t][i,j] = ℙ(X[t]=i, X[t+1]=j | Y[1:T])`"
32+ ξ:: Vector{M}
33+ " one loglikelihood per observation sequence"
34+ logL:: Vector{R}
35+ B:: Matrix{R}
36+ α:: Matrix{R}
37+ c:: Vector{R}
38+ β:: Matrix{R}
39+ Bβ:: Matrix{R}
40+ end
41+
1942Base. eltype (:: ForwardStorage{R} ) where {R} = R
43+ Base. eltype (:: ForwardBackwardStorage{R} ) where {R} = R
44+
45+ const ForwardOrForwardBackwardStorage{R} = Union{
46+ ForwardStorage{R},ForwardBackwardStorage{R}
47+ }
2048
2149"""
2250$(SIGNATURES)
@@ -25,7 +53,7 @@ function initialize_forward(
2553 hmm:: AbstractHMM ,
2654 obs_seq:: AbstractVector ,
2755 control_seq:: AbstractVector ;
28- seq_ends:: AbstractVector {Int} ,
56+ seq_ends:: AbstractVectorOrNTuple {Int} ,
2957)
3058 N, T, K = length (hmm), length (obs_seq), length (seq_ends)
3159 R = eltype (hmm, obs_seq[1 ], control_seq[1 ])
4068$(SIGNATURES)
4169"""
4270function forward! (
43- storage,
71+ storage:: ForwardOrForwardBackwardStorage ,
4472 hmm:: AbstractHMM ,
4573 obs_seq:: AbstractVector ,
4674 control_seq:: AbstractVector ,
88116$(SIGNATURES)
89117"""
90118function forward! (
91- storage,
119+ storage:: ForwardOrForwardBackwardStorage ,
92120 hmm:: AbstractHMM ,
93121 obs_seq:: AbstractVector ,
94122 control_seq:: AbstractVector ;
95- seq_ends:: AbstractVector {Int} ,
123+ seq_ends:: AbstractVectorOrNTuple {Int} ,
96124)
97125 (; logL) = storage
98- @threads for k in eachindex (seq_ends)
99- t1, t2 = seq_limits (seq_ends, k)
100- logL[k] = forward! (storage, hmm, obs_seq, control_seq, t1, t2;)
126+ if seq_ends isa NTuple
127+ for k in eachindex (seq_ends)
128+ t1, t2 = seq_limits (seq_ends, k)
129+ logL[k] = forward! (storage, hmm, obs_seq, control_seq, t1, t2;)
130+ end
131+ else
132+ @threads for k in eachindex (seq_ends)
133+ t1, t2 = seq_limits (seq_ends, k)
134+ logL[k] = forward! (storage, hmm, obs_seq, control_seq, t1, t2;)
135+ end
101136 end
102137 return nothing
103138end
@@ -113,7 +148,7 @@ function forward(
113148 hmm:: AbstractHMM ,
114149 obs_seq:: AbstractVector ,
115150 control_seq:: AbstractVector = Fill (nothing , length (obs_seq));
116- seq_ends:: AbstractVector {Int}= Fill (length (obs_seq), 1 ),
151+ seq_ends:: AbstractVectorOrNTuple {Int}= (length (obs_seq),),
117152)
118153 storage = initialize_forward (hmm, obs_seq, control_seq; seq_ends)
119154 forward! (storage, hmm, obs_seq, control_seq; seq_ends)
0 commit comments