Skip to content

Commit 2bc9e4b

Browse files
committed
add hmm sanity check
1 parent 7243c23 commit 2bc9e4b

File tree

4 files changed

+147
-11
lines changed

4 files changed

+147
-11
lines changed

JuliaBUGS/src/model/bugsmodel.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,12 @@ function BUGSModel(
232232
n = length(gd.sorted_nodes)
233233
JuliaBUGS.Model._precompute_minimal_cache_keys(m, collect(1:n))
234234
end
235-
# Attach minimal keys to GraphEvaluationData
235+
# Attach minimal order and keys to GraphEvaluationData
236+
order = if isempty(gd.marginalization_order)
237+
collect(1:length(gd.sorted_nodes))
238+
else
239+
gd.marginalization_order
240+
end
236241
gd2 = GraphEvaluationData(
237242
gd.sorted_nodes,
238243
gd.sorted_parameters,
@@ -242,6 +247,7 @@ function BUGSModel(
242247
gd.loop_vars_vals,
243248
gd.node_types,
244249
gd.is_discrete_finite_vals,
250+
order,
245251
minimal_keys,
246252
)
247253
# Return final model with cached minimal keys

JuliaBUGS/src/model/evaluation.jl

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,9 @@ function _precompute_minimal_cache_keys(model::BUGSModel, order::Vector{Int})
324324
if is_discrete_finite[p_label] && !is_observed[p_label]
325325
# Default to the position of the variable itself if unseen
326326
default_pos = order_pos[p_label]
327-
last_use_pos[p_label] = max(get(last_use_pos, p_label, default_pos), j_pos)
327+
last_use_pos[p_label] = max(
328+
get(last_use_pos, p_label, default_pos), j_pos
329+
)
328330
end
329331
end
330332
end
@@ -656,16 +658,18 @@ function evaluate_with_marginalization_values!!(
656658
)
657659
end
658660

659-
# Get indices for evaluation order
660-
n = length(model.graph_evaluation_data.sorted_nodes)
661+
# Get indices for evaluation order (default to stored marginalization order)
662+
gd = model.graph_evaluation_data
663+
n = length(gd.sorted_nodes)
661664
sorted_indices = collect(1:n)
662665

663666
# Use precomputed minimal cache keys if available; otherwise compute once for this call
664-
minimal_keys = if !isempty(model.graph_evaluation_data.minimal_cache_keys)
665-
model.graph_evaluation_data.minimal_cache_keys
666-
else
667-
_precompute_minimal_cache_keys(model, sorted_indices)
668-
end
667+
minimal_keys =
668+
if !isempty(gd.minimal_cache_keys) && (length(keys(gd.minimal_cache_keys)) > 0)
669+
gd.minimal_cache_keys
670+
else
671+
_precompute_minimal_cache_keys(model, sorted_indices)
672+
end
669673

670674
# Initialize memoization cache
671675
# Size hint: at most 2^|discrete_finite| * |nodes| entries
@@ -686,10 +690,10 @@ function evaluate_with_marginalization_values!!(
686690
var_lengths = Dict{VarName,Int}()
687691
for (vn, length) in model.transformed_var_lengths
688692
# Find the node index
689-
idx = findfirst(==(vn), model.graph_evaluation_data.sorted_nodes)
693+
idx = findfirst(==(vn), gd.sorted_nodes)
690694
if idx !== nothing
691695
# Only include if it's continuous (not discrete finite)
692-
if !model.graph_evaluation_data.is_discrete_finite_vals[idx]
696+
if !gd.is_discrete_finite_vals[idx]
693697
var_lengths[vn] = length
694698
end
695699
end
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
using Test
2+
using JuliaBUGS
3+
using JuliaBUGS: @bugs, compile, @varname
4+
using JuliaBUGS.Model:
5+
_precompute_minimal_cache_keys, _marginalize_recursive, smart_copy_evaluation_env
6+
7+
@testset "Frontier cache for HMM under different orders" begin
8+
# Simple HMM with fixed emission parameters (no continuous params)
9+
hmm_def = @bugs begin
10+
mu[1] = 0.0
11+
mu[2] = 5.0
12+
sigma = 1.0
13+
14+
trans[1, 1] = 0.7
15+
trans[1, 2] = 0.3
16+
trans[2, 1] = 0.4
17+
trans[2, 2] = 0.6
18+
19+
pi[1] = 0.5
20+
pi[2] = 0.5
21+
22+
z[1] ~ Categorical(pi[1:2])
23+
for t in 2:T
24+
p[t, 1] = trans[z[t - 1], 1]
25+
p[t, 2] = trans[z[t - 1], 2]
26+
z[t] ~ Categorical(p[t, :])
27+
end
28+
29+
for t in 1:T
30+
y[t] ~ Normal(mu[z[t]], sigma)
31+
end
32+
end
33+
34+
T = 3
35+
data = (T=T, y=[0.1, 4.9, 5.1])
36+
model = compile(hmm_def, data)
37+
38+
gd = model.graph_evaluation_data
39+
n = length(gd.sorted_nodes)
40+
41+
# Helper: index lookup for variables of interest
42+
vn = Dict(
43+
:z1 => @varname(z[1]),
44+
:z2 => @varname(z[2]),
45+
:z3 => @varname(z[3]),
46+
:y1 => @varname(y[1]),
47+
:y2 => @varname(y[2]),
48+
:y3 => @varname(y[3]),
49+
)
50+
idx = Dict{Symbol,Int}()
51+
for (k, v) in vn
52+
i = findfirst(==(v), gd.sorted_nodes)
53+
@test i !== nothing # ensure nodes exist
54+
idx[k] = i
55+
end
56+
57+
# Construct two evaluation orders as permutations of 1:n
58+
# Interleaved: z1, y1, z2, y2, z3, y3, then the rest
59+
priority_interleaved = [idx[:z1], idx[:y1], idx[:z2], idx[:y2], idx[:z3], idx[:y3]]
60+
rest_interleaved = [i for i in 1:n if i priority_interleaved]
61+
order_interleaved = vcat(priority_interleaved, rest_interleaved)
62+
63+
# States-first: z1, z2, z3, y1, y2, y3, then the rest
64+
priority_states_first = [idx[:z1], idx[:z2], idx[:z3], idx[:y1], idx[:y2], idx[:y3]]
65+
rest_states_first = [i for i in 1:n if i priority_states_first]
66+
order_states_first = vcat(priority_states_first, rest_states_first)
67+
68+
# Precompute minimal keys for both orders
69+
keys_interleaved = _precompute_minimal_cache_keys(model, order_interleaved)
70+
keys_states_first = _precompute_minimal_cache_keys(model, order_states_first)
71+
72+
# Helper to map frontier indices back to a set of variable symbols we care about
73+
function frontier_syms(keys, key_idx)
74+
frontier = get(keys, key_idx, Int[])
75+
syms = Set{Symbol}()
76+
for (name, i) in idx
77+
if i in frontier
78+
push!(syms, name)
79+
end
80+
end
81+
return syms
82+
end
83+
84+
# Interleaved expectations: frontier size stays 1; y[t] depends on z[t]
85+
@test frontier_syms(keys_interleaved, idx[:z1]) == Set{Symbol}()
86+
@test frontier_syms(keys_interleaved, idx[:y1]) == Set([:z1])
87+
@test frontier_syms(keys_interleaved, idx[:z2]) == Set([:z1])
88+
@test frontier_syms(keys_interleaved, idx[:y2]) == Set([:z2])
89+
@test frontier_syms(keys_interleaved, idx[:z3]) == Set([:z2])
90+
@test frontier_syms(keys_interleaved, idx[:y3]) == Set([:z3])
91+
92+
# States-first expectations: frontier grows across z's, peaks at y1
93+
@test frontier_syms(keys_states_first, idx[:z1]) == Set{Symbol}()
94+
@test frontier_syms(keys_states_first, idx[:z2]) == Set([:z1])
95+
@test frontier_syms(keys_states_first, idx[:z3]) == Set([:z1, :z2])
96+
@test frontier_syms(keys_states_first, idx[:y1]) == Set([:z1, :z2, :z3])
97+
@test frontier_syms(keys_states_first, idx[:y2]) == Set([:z2, :z3])
98+
@test frontier_syms(keys_states_first, idx[:y3]) == Set([:z3])
99+
100+
# Sanity: different orders should not change marginalized log-density
101+
env = smart_copy_evaluation_env(model.evaluation_env, model.mutable_symbols)
102+
params = Float64[]
103+
memo1 = Dict{Tuple{Int,Int,UInt64},Any}()
104+
logp1 = _marginalize_recursive(
105+
model, env, order_interleaved, params, 1, Dict{Any,Int}(), memo1, keys_interleaved
106+
)
107+
108+
env2 = smart_copy_evaluation_env(model.evaluation_env, model.mutable_symbols)
109+
memo2 = Dict{Tuple{Int,Int,UInt64},Any}()
110+
logp2 = _marginalize_recursive(
111+
model,
112+
env2,
113+
order_states_first,
114+
params,
115+
1,
116+
Dict{Any,Int}(),
117+
memo2,
118+
keys_states_first,
119+
)
120+
121+
@test isapprox(logp1, logp2; atol=1e-10)
122+
123+
# And states-first should lead to equal or larger memo usage (worse frontier)
124+
@test length(memo2) >= length(memo1)
125+
end

JuliaBUGS/test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ const TEST_GROUPS = OrderedDict{String,Function}(
7373
"log_density" => () -> begin
7474
include("model/evaluation.jl")
7575
include("model/auto_marginalization.jl")
76+
include("model/frontier_cache_hmm.jl")
7677
end,
7778
"inference" => () -> begin
7879
include("independent_mh.jl")

0 commit comments

Comments
 (0)