|
| 1 | +using Test |
| 2 | +using JuliaBUGS |
| 3 | +using JuliaBUGS: @bugs, compile, @varname |
| 4 | +using JuliaBUGS.Model: |
| 5 | + _precompute_minimal_cache_keys, _marginalize_recursive, smart_copy_evaluation_env |
| 6 | + |
| 7 | +@testset "Frontier cache for HMM under different orders" begin |
| 8 | + # Simple HMM with fixed emission parameters (no continuous params) |
| 9 | + hmm_def = @bugs begin |
| 10 | + mu[1] = 0.0 |
| 11 | + mu[2] = 5.0 |
| 12 | + sigma = 1.0 |
| 13 | + |
| 14 | + trans[1, 1] = 0.7 |
| 15 | + trans[1, 2] = 0.3 |
| 16 | + trans[2, 1] = 0.4 |
| 17 | + trans[2, 2] = 0.6 |
| 18 | + |
| 19 | + pi[1] = 0.5 |
| 20 | + pi[2] = 0.5 |
| 21 | + |
| 22 | + z[1] ~ Categorical(pi[1:2]) |
| 23 | + for t in 2:T |
| 24 | + p[t, 1] = trans[z[t - 1], 1] |
| 25 | + p[t, 2] = trans[z[t - 1], 2] |
| 26 | + z[t] ~ Categorical(p[t, :]) |
| 27 | + end |
| 28 | + |
| 29 | + for t in 1:T |
| 30 | + y[t] ~ Normal(mu[z[t]], sigma) |
| 31 | + end |
| 32 | + end |
| 33 | + |
| 34 | + T = 3 |
| 35 | + data = (T=T, y=[0.1, 4.9, 5.1]) |
| 36 | + model = compile(hmm_def, data) |
| 37 | + |
| 38 | + gd = model.graph_evaluation_data |
| 39 | + n = length(gd.sorted_nodes) |
| 40 | + |
| 41 | + # Helper: index lookup for variables of interest |
| 42 | + vn = Dict( |
| 43 | + :z1 => @varname(z[1]), |
| 44 | + :z2 => @varname(z[2]), |
| 45 | + :z3 => @varname(z[3]), |
| 46 | + :y1 => @varname(y[1]), |
| 47 | + :y2 => @varname(y[2]), |
| 48 | + :y3 => @varname(y[3]), |
| 49 | + ) |
| 50 | + idx = Dict{Symbol,Int}() |
| 51 | + for (k, v) in vn |
| 52 | + i = findfirst(==(v), gd.sorted_nodes) |
| 53 | + @test i !== nothing # ensure nodes exist |
| 54 | + idx[k] = i |
| 55 | + end |
| 56 | + |
| 57 | + # Construct two evaluation orders as permutations of 1:n |
| 58 | + # Interleaved: z1, y1, z2, y2, z3, y3, then the rest |
| 59 | + priority_interleaved = [idx[:z1], idx[:y1], idx[:z2], idx[:y2], idx[:z3], idx[:y3]] |
| 60 | + rest_interleaved = [i for i in 1:n if i ∉ priority_interleaved] |
| 61 | + order_interleaved = vcat(priority_interleaved, rest_interleaved) |
| 62 | + |
| 63 | + # States-first: z1, z2, z3, y1, y2, y3, then the rest |
| 64 | + priority_states_first = [idx[:z1], idx[:z2], idx[:z3], idx[:y1], idx[:y2], idx[:y3]] |
| 65 | + rest_states_first = [i for i in 1:n if i ∉ priority_states_first] |
| 66 | + order_states_first = vcat(priority_states_first, rest_states_first) |
| 67 | + |
| 68 | + # Precompute minimal keys for both orders |
| 69 | + keys_interleaved = _precompute_minimal_cache_keys(model, order_interleaved) |
| 70 | + keys_states_first = _precompute_minimal_cache_keys(model, order_states_first) |
| 71 | + |
| 72 | + # Helper to map frontier indices back to a set of variable symbols we care about |
| 73 | + function frontier_syms(keys, key_idx) |
| 74 | + frontier = get(keys, key_idx, Int[]) |
| 75 | + syms = Set{Symbol}() |
| 76 | + for (name, i) in idx |
| 77 | + if i in frontier |
| 78 | + push!(syms, name) |
| 79 | + end |
| 80 | + end |
| 81 | + return syms |
| 82 | + end |
| 83 | + |
| 84 | + # Interleaved expectations: frontier size stays 1; y[t] depends on z[t] |
| 85 | + @test frontier_syms(keys_interleaved, idx[:z1]) == Set{Symbol}() |
| 86 | + @test frontier_syms(keys_interleaved, idx[:y1]) == Set([:z1]) |
| 87 | + @test frontier_syms(keys_interleaved, idx[:z2]) == Set([:z1]) |
| 88 | + @test frontier_syms(keys_interleaved, idx[:y2]) == Set([:z2]) |
| 89 | + @test frontier_syms(keys_interleaved, idx[:z3]) == Set([:z2]) |
| 90 | + @test frontier_syms(keys_interleaved, idx[:y3]) == Set([:z3]) |
| 91 | + |
| 92 | + # States-first expectations: frontier grows across z's, peaks at y1 |
| 93 | + @test frontier_syms(keys_states_first, idx[:z1]) == Set{Symbol}() |
| 94 | + @test frontier_syms(keys_states_first, idx[:z2]) == Set([:z1]) |
| 95 | + @test frontier_syms(keys_states_first, idx[:z3]) == Set([:z1, :z2]) |
| 96 | + @test frontier_syms(keys_states_first, idx[:y1]) == Set([:z1, :z2, :z3]) |
| 97 | + @test frontier_syms(keys_states_first, idx[:y2]) == Set([:z2, :z3]) |
| 98 | + @test frontier_syms(keys_states_first, idx[:y3]) == Set([:z3]) |
| 99 | + |
| 100 | + # Sanity: different orders should not change marginalized log-density |
| 101 | + env = smart_copy_evaluation_env(model.evaluation_env, model.mutable_symbols) |
| 102 | + params = Float64[] |
| 103 | + memo1 = Dict{Tuple{Int,Int,UInt64},Any}() |
| 104 | + logp1 = _marginalize_recursive( |
| 105 | + model, env, order_interleaved, params, 1, Dict{Any,Int}(), memo1, keys_interleaved |
| 106 | + ) |
| 107 | + |
| 108 | + env2 = smart_copy_evaluation_env(model.evaluation_env, model.mutable_symbols) |
| 109 | + memo2 = Dict{Tuple{Int,Int,UInt64},Any}() |
| 110 | + logp2 = _marginalize_recursive( |
| 111 | + model, |
| 112 | + env2, |
| 113 | + order_states_first, |
| 114 | + params, |
| 115 | + 1, |
| 116 | + Dict{Any,Int}(), |
| 117 | + memo2, |
| 118 | + keys_states_first, |
| 119 | + ) |
| 120 | + |
| 121 | + @test isapprox(logp1, logp2; atol=1e-10) |
| 122 | + |
| 123 | + # And states-first should lead to equal or larger memo usage (worse frontier) |
| 124 | + @test length(memo2) >= length(memo1) |
| 125 | +end |
0 commit comments