Skip to content

Commit a8b048a

Browse files
authored
Allow heterogeneous distributions (#102)
* Allow heterogeneous distributions * Use Exponential * No type stability checks
1 parent 67934b1 commit a8b048a

File tree

4 files changed

+37
-11
lines changed

4 files changed

+37
-11
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "HiddenMarkovModels"
22
uuid = "84ca31d5-effc-45e0-bfda-5a68cd981f47"
33
authors = ["Guillaume Dalle"]
4-
version = "0.5.2"
4+
version = "0.5.3"
55

66
[deps]
77
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"

examples/basics.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,4 +261,3 @@ control_seq = fill(nothing, last(seq_ends)); #src
261261
test_identical_hmmbase(rng, hmm, 100; hmm_guess) #src
262262
test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess) #src
263263
test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) #src
264-
test_identical_hmmbase(rng, transpose_hmm(hmm), 100; hmm_guess=transpose_hmm(hmm_guess)) #src

ext/HiddenMarkovModelsDistributionsExt.jl

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,27 +10,30 @@ using Distributions:
1010
fit
1111

1212
function HiddenMarkovModels.fit_in_sequence!(
13-
dists::AbstractVector{D}, i::Integer, x_nums::AbstractVector, w::AbstractVector
14-
) where {D<:UnivariateDistribution}
15-
return dists[i] = fit(D, x_nums, w)
13+
dists::AbstractVector{<:UnivariateDistribution},
14+
i::Integer,
15+
x_nums::AbstractVector,
16+
w::AbstractVector,
17+
)
18+
return dists[i] = fit(typeof(dists[i]), x_nums, w)
1619
end
1720

1821
function HiddenMarkovModels.fit_in_sequence!(
19-
dists::AbstractVector{D},
22+
dists::AbstractVector{<:MultivariateDistribution},
2023
i::Integer,
2124
x_vecs::AbstractVector{<:AbstractVector},
2225
w::AbstractVector,
23-
) where {D<:MultivariateDistribution}
24-
return dists[i] = fit(D, reduce(hcat, x_vecs), w)
26+
)
27+
return dists[i] = fit(typeof(dists[i]), reduce(hcat, x_vecs), w)
2528
end
2629

2730
function HiddenMarkovModels.fit_in_sequence!(
28-
dists::AbstractVector{D},
31+
dists::AbstractVector{<:MatrixDistribution},
2932
i::Integer,
3033
x_mats::AbstractVector{<:AbstractMatrix},
3134
w::AbstractVector,
32-
) where {D<:MatrixDistribution}
33-
return dists[i] = fit(D, reduce(dcat, x_mats), w)
35+
)
36+
return dists[i] = fit(typeof(dists[i]), reduce(dcat, x_mats), w)
3437
end
3538

3639
dcat(M1, M2) = cat(M1, M2; dims=3)

test/correctness.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,3 +98,27 @@ end
9898
test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess)
9999
@test_skip test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess)
100100
end
101+
102+
@testset "Normal transposed" begin # issue 99
103+
dists = [Normal(μ[1][1]), Normal(μ[2][1])]
104+
dists_guess = [Normal(μ_guess[1][1]), Normal(μ_guess[2][1])]
105+
106+
hmm = transpose_hmm(HMM(init, trans, dists))
107+
hmm_guess = transpose_hmm(HMM(init_guess, trans_guess, dists_guess))
108+
109+
test_identical_hmmbase(rng, hmm, T; hmm_guess)
110+
test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false)
111+
test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess)
112+
test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess)
113+
end
114+
115+
@testset "Normal and Exponential" begin # issue 101
116+
dists = [Normal(μ[1][1]), Exponential(1.0)]
117+
dists_guess = [Normal(μ_guess[1][1]), Exponential(0.8)]
118+
119+
hmm = HMM(init, trans, dists)
120+
hmm_guess = HMM(init_guess, trans_guess, dists_guess)
121+
122+
test_identical_hmmbase(rng, hmm, T; hmm_guess)
123+
test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false)
124+
end

0 commit comments

Comments
 (0)