Skip to content

Commit 9802bad

Browse files
authored
Disable multithreading when seq_ends is passed as a tuple (#113)
* Fix loglikelihood increase check in Baum-Welch * Disable multithreading when `seq_ends` is given as a tuple * Remove seq_ends typing in examples
1 parent 3276e3c commit 9802bad

File tree

16 files changed

+93
-64
lines changed

16 files changed

+93
-64
lines changed

examples/controlled.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ function StatsAPI.fit!(
9494
fb_storage::HMMs.ForwardBackwardStorage,
9595
obs_seq::AbstractVector,
9696
control_seq::AbstractVector;
97-
seq_ends::AbstractVector{Int},
97+
seq_ends,
9898
) where {T}
9999
(; γ, ξ) = fb_storage
100100
N = length(hmm)

examples/interfaces.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ function StatsAPI.fit!(
186186
hmm::PriorHMM,
187187
fb_storage::HiddenMarkovModels.ForwardBackwardStorage,
188188
obs_seq::AbstractVector;
189-
seq_ends::AbstractVector{Int},
189+
seq_ends,
190190
)
191191
## initialize to defaults without observations
192192
hmm.init .= 0

examples/temporal.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ function StatsAPI.fit!(
109109
fb_storage::HMMs.ForwardBackwardStorage,
110110
obs_seq::AbstractVector,
111111
control_seq::AbstractVector;
112-
seq_ends::AbstractVector{Int},
112+
seq_ends,
113113
) where {T}
114114
(; γ, ξ) = fb_storage
115115
L, N = period(hmm), length(hmm)

libs/HMMTest/src/HMMTest.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module HMMTest
22

33
using BenchmarkTools: @ballocated
44
using HiddenMarkovModels
5+
using HiddenMarkovModels: AbstractVectorOrNTuple
56
import HiddenMarkovModels as HMMs
67
using HMMBase: HMMBase
78
using JET: @test_opt, @test_call

libs/HMMTest/src/allocations.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ function test_allocations(
33
rng::AbstractRNG,
44
hmm::AbstractHMM,
55
control_seq::AbstractVector;
6-
seq_ends::AbstractVector{Int},
6+
seq_ends::AbstractVectorOrNTuple{Int},
77
hmm_guess::Union{Nothing,AbstractHMM}=nothing,
88
)
99
@testset "Allocations" begin

libs/HMMTest/src/coherence.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ function test_coherent_algorithms(
5555
rng::AbstractRNG,
5656
hmm::AbstractHMM,
5757
control_seq::AbstractVector;
58-
seq_ends::AbstractVector{Int},
58+
seq_ends::AbstractVectorOrNTuple{Int},
5959
hmm_guess::Union{Nothing,AbstractHMM}=nothing,
6060
atol::Real=0.05,
6161
init::Bool=true,

libs/HMMTest/src/jet.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ function test_type_stability(
33
rng::AbstractRNG,
44
hmm::AbstractHMM,
55
control_seq::AbstractVector;
6-
seq_ends::AbstractVector{Int},
6+
seq_ends::AbstractVectorOrNTuple{Int},
77
hmm_guess::Union{Nothing,AbstractHMM}=nothing,
88
)
99
@testset "Type stability" begin

src/inference/baum_welch.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ function baum_welch!(
2222
hmm::AbstractHMM,
2323
obs_seq::AbstractVector,
2424
control_seq::AbstractVector;
25-
seq_ends::AbstractVector{Int},
25+
seq_ends::AbstractVectorOrNTuple{Int},
2626
atol::Real,
2727
max_iterations::Integer,
2828
loglikelihood_increasing::Bool,
@@ -55,7 +55,7 @@ function baum_welch(
5555
hmm_guess::AbstractHMM,
5656
obs_seq::AbstractVector,
5757
control_seq::AbstractVector=Fill(nothing, length(obs_seq));
58-
seq_ends::AbstractVector{Int}=Fill(length(obs_seq), 1),
58+
seq_ends::AbstractVectorOrNTuple{Int}=(length(obs_seq),),
5959
atol=1e-5,
6060
max_iterations=100,
6161
loglikelihood_increasing=true,
@@ -85,7 +85,7 @@ function StatsAPI.fit!(
8585
fb_storage::ForwardBackwardStorage,
8686
obs_seq::AbstractVector,
8787
control_seq::AbstractVector;
88-
seq_ends::AbstractVector{Int},
88+
seq_ends::AbstractVectorOrNTuple{Int},
8989
)
9090
return fit!(hmm, fb_storage, obs_seq; seq_ends)
9191
end

src/inference/chainrules.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ function _params_and_loglikelihoods(
44
hmm::AbstractHMM,
55
obs_seq::Vector,
66
control_seq::AbstractVector=Fill(nothing, length(obs_seq));
7-
seq_ends::AbstractVector{Int}=Fill(length(obs_seq), 1),
7+
seq_ends::AbstractVectorOrNTuple{Int}=(length(obs_seq),),
88
)
99
init = initialization(hmm)
1010
trans_by_time = mapreduce(_dcat, eachindex(control_seq)) do t
@@ -22,7 +22,7 @@ function ChainRulesCore.rrule(
2222
hmm::AbstractHMM,
2323
obs_seq::AbstractVector,
2424
control_seq::AbstractVector=Fill(nothing, length(obs_seq));
25-
seq_ends::AbstractVector{Int}=Fill(length(obs_seq), 1),
25+
seq_ends::AbstractVectorOrNTuple{Int}=(length(obs_seq),),
2626
)
2727
_, pullback = rrule_via_ad(
2828
rc, _params_and_loglikelihoods, hmm, obs_seq, control_seq; seq_ends

src/inference/forward.jl

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,35 @@ struct ForwardStorage{R}
1616
c::Vector{R}
1717
end
1818

19+
"""
20+
$(TYPEDEF)
21+
22+
# Fields
23+
24+
Only the fields with a description are part of the public API.
25+
26+
$(TYPEDFIELDS)
27+
"""
28+
struct ForwardBackwardStorage{R,M<:AbstractMatrix{R}}
29+
"posterior state marginals `γ[i,t] = ℙ(X[t]=i | Y[1:T])`"
30+
γ::Matrix{R}
31+
"posterior transition marginals `ξ[t][i,j] = ℙ(X[t]=i, X[t+1]=j | Y[1:T])`"
32+
ξ::Vector{M}
33+
"one loglikelihood per observation sequence"
34+
logL::Vector{R}
35+
B::Matrix{R}
36+
α::Matrix{R}
37+
c::Vector{R}
38+
β::Matrix{R}
39+
::Matrix{R}
40+
end
41+
1942
Base.eltype(::ForwardStorage{R}) where {R} = R
43+
Base.eltype(::ForwardBackwardStorage{R}) where {R} = R
44+
45+
const ForwardOrForwardBackwardStorage{R} = Union{
46+
ForwardStorage{R},ForwardBackwardStorage{R}
47+
}
2048

2149
"""
2250
$(SIGNATURES)
@@ -25,7 +53,7 @@ function initialize_forward(
2553
hmm::AbstractHMM,
2654
obs_seq::AbstractVector,
2755
control_seq::AbstractVector;
28-
seq_ends::AbstractVector{Int},
56+
seq_ends::AbstractVectorOrNTuple{Int},
2957
)
3058
N, T, K = length(hmm), length(obs_seq), length(seq_ends)
3159
R = eltype(hmm, obs_seq[1], control_seq[1])
@@ -40,7 +68,7 @@ end
4068
$(SIGNATURES)
4169
"""
4270
function forward!(
43-
storage,
71+
storage::ForwardOrForwardBackwardStorage,
4472
hmm::AbstractHMM,
4573
obs_seq::AbstractVector,
4674
control_seq::AbstractVector,
@@ -88,16 +116,23 @@ end
88116
$(SIGNATURES)
89117
"""
90118
function forward!(
91-
storage,
119+
storage::ForwardOrForwardBackwardStorage,
92120
hmm::AbstractHMM,
93121
obs_seq::AbstractVector,
94122
control_seq::AbstractVector;
95-
seq_ends::AbstractVector{Int},
123+
seq_ends::AbstractVectorOrNTuple{Int},
96124
)
97125
(; logL) = storage
98-
@threads for k in eachindex(seq_ends)
99-
t1, t2 = seq_limits(seq_ends, k)
100-
logL[k] = forward!(storage, hmm, obs_seq, control_seq, t1, t2;)
126+
if seq_ends isa NTuple
127+
for k in eachindex(seq_ends)
128+
t1, t2 = seq_limits(seq_ends, k)
129+
logL[k] = forward!(storage, hmm, obs_seq, control_seq, t1, t2;)
130+
end
131+
else
132+
@threads for k in eachindex(seq_ends)
133+
t1, t2 = seq_limits(seq_ends, k)
134+
logL[k] = forward!(storage, hmm, obs_seq, control_seq, t1, t2;)
135+
end
101136
end
102137
return nothing
103138
end
@@ -113,7 +148,7 @@ function forward(
113148
hmm::AbstractHMM,
114149
obs_seq::AbstractVector,
115150
control_seq::AbstractVector=Fill(nothing, length(obs_seq));
116-
seq_ends::AbstractVector{Int}=Fill(length(obs_seq), 1),
151+
seq_ends::AbstractVectorOrNTuple{Int}=(length(obs_seq),),
117152
)
118153
storage = initialize_forward(hmm, obs_seq, control_seq; seq_ends)
119154
forward!(storage, hmm, obs_seq, control_seq; seq_ends)

0 commit comments

Comments
 (0)