Skip to content

Commit 67934b1

Browse files
authored
Allow different types for elementwise log (#100)
* Allow different type for elementwise log * Bump * Fix kwarg * Fix init * Fix test and CI
1 parent 6bfb23a commit 67934b1

File tree

6 files changed

+35
-15
lines changed

6 files changed

+35
-15
lines changed

.github/workflows/test.yml

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,24 +12,23 @@ concurrency:
1212
cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}
1313
jobs:
1414
test:
15-
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }}
16-
runs-on: ${{ matrix.os }}
15+
name: Julia ${{ matrix.version }} - ${{ github.event_name }}
16+
runs-on: ubuntu-latest
1717
strategy:
1818
fail-fast: false
1919
matrix:
2020
version:
21+
- '1.9'
2122
- '1'
2223
os:
2324
- ubuntu-latest
24-
arch:
25-
- x64
2625
steps:
27-
- uses: actions/checkout@v2
28-
- uses: julia-actions/setup-julia@v1
26+
- uses: actions/checkout@v4
27+
- uses: julia-actions/setup-julia@v2
2928
with:
3029
version: ${{ matrix.version }}
31-
arch: ${{ matrix.arch }}
32-
- uses: julia-actions/cache@v1
30+
arch: x64
31+
- uses: julia-actions/cache@v2
3332
- uses: julia-actions/julia-buildpkg@v1
3433
- uses: julia-actions/julia-runtest@v1
3534
- uses: julia-actions/julia-processcoverage@v1

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "HiddenMarkovModels"
22
uuid = "84ca31d5-effc-45e0-bfda-5a68cd981f47"
33
authors = ["Guillaume Dalle"]
4-
version = "0.5.1"
4+
version = "0.5.2"
55

66
[deps]
77
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"

examples/basics.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ This is important to keep in mind when testing new models.
189189
In many applications, we have access to various observation sequences of different lengths.
190190
=#
191191

192-
nb_seqs = 100
192+
nb_seqs = 300
193193
long_obs_seqs = [last(rand(rng, hmm, rand(rng, 100:200))) for k in 1:nb_seqs];
194194
typeof(long_obs_seqs)
195195

@@ -261,3 +261,4 @@ control_seq = fill(nothing, last(seq_ends)); #src
261261
test_identical_hmmbase(rng, hmm, 100; hmm_guess) #src
262262
test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess) #src
263263
test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) #src
264+
test_identical_hmmbase(rng, transpose_hmm(hmm), 100; hmm_guess=transpose_hmm(hmm_guess)) #src

libs/HMMTest/src/HMMTest.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,13 @@ using Random: AbstractRNG
99
using Statistics: mean
1010
using Test: @test, @testset, @test_broken
1111

12+
export transpose_hmm
1213
export test_equal_hmms, test_coherent_algorithms
1314
export test_identical_hmmbase
1415
export test_allocations
1516
export test_type_stability
1617

18+
include("utils.jl")
1719
include("coherence.jl")
1820
include("allocations.jl")
1921
include("hmmbase.jl")

libs/HMMTest/src/utils.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
function transpose_hmm(hmm::HMM)
2+
init = initialization(hmm)
3+
trans = transition_matrix(hmm)
4+
dists = obs_distributions(hmm)
5+
trans_transpose = transpose(convert(typeof(trans), transpose(trans)))
6+
@assert trans_transpose == trans
7+
return HMM(init, trans_transpose, dists)
8+
end

src/types/hmm.jl

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,31 @@ Basic implementation of an HMM.
77
88
$(TYPEDFIELDS)
99
"""
10-
struct HMM{V<:AbstractVector,M<:AbstractMatrix,VD<:AbstractVector} <: AbstractHMM
10+
struct HMM{
11+
V<:AbstractVector,
12+
M<:AbstractMatrix,
13+
VD<:AbstractVector,
14+
Vl<:AbstractVector,
15+
Ml<:AbstractMatrix,
16+
} <: AbstractHMM
1117
"initial state probabilities"
1218
init::V
1319
"state transition probabilities"
1420
trans::M
1521
"observation distributions"
1622
dists::VD
1723
"logarithms of initial state probabilities"
18-
loginit::V
24+
loginit::Vl
1925
"logarithms of state transition probabilities"
20-
logtrans::M
26+
logtrans::Ml
2127

2228
function HMM(init::AbstractVector, trans::AbstractMatrix, dists::AbstractVector)
23-
hmm = new{typeof(init),typeof(trans),typeof(dists)}(
24-
init, trans, dists, elementwise_log(init), elementwise_log(trans)
29+
log_init = elementwise_log(init)
30+
log_trans = elementwise_log(trans)
31+
hmm = new{
32+
typeof(init),typeof(trans),typeof(dists),typeof(log_init),typeof(log_trans)
33+
}(
34+
init, trans, dists, log_init, log_trans
2535
)
2636
@argcheck valid_hmm(hmm)
2737
return hmm

0 commit comments

Comments
 (0)