Skip to content

Commit 6de0807

Browse files
committed
fix error
1 parent ad90ec8 commit 6de0807

File tree

4 files changed

+248
-44
lines changed

4 files changed

+248
-44
lines changed

JuliaBUGS/src/model/bugsmodel.jl

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -232,12 +232,7 @@ function BUGSModel(
232232
n = length(gd.sorted_nodes)
233233
JuliaBUGS.Model._precompute_minimal_cache_keys(m, collect(1:n))
234234
end
235-
# Attach minimal order and keys to GraphEvaluationData
236-
order = if isempty(gd.marginalization_order)
237-
collect(1:length(gd.sorted_nodes))
238-
else
239-
gd.marginalization_order
240-
end
235+
# Attach minimal cache keys to GraphEvaluationData (order remains default)
241236
gd2 = GraphEvaluationData(
242237
gd.sorted_nodes,
243238
gd.sorted_parameters,
@@ -247,7 +242,6 @@ function BUGSModel(
247242
gd.loop_vars_vals,
248243
gd.node_types,
249244
gd.is_discrete_finite_vals,
250-
order,
251245
minimal_keys,
252246
)
253247
# Return final model with cached minimal keys
@@ -612,15 +606,33 @@ params_dict = getparams(Dict, model, custom_env)
612606
```
613607
"""
614608
function getparams(model::BUGSModel, evaluation_env=model.evaluation_env)
615-
param_length = if model.transformed
616-
model.transformed_param_length
609+
# Determine which parameters to include based on evaluation mode
610+
gd = model.graph_evaluation_data
611+
param_vars = if model.evaluation_mode isa UseAutoMarginalization
612+
# Only include continuous parameters when auto marginalizing
613+
filter(gd.sorted_parameters) do vn
614+
idx = findfirst(==(vn), gd.sorted_nodes)
615+
idx !== nothing && gd.node_types[idx] == :continuous
616+
end
617617
else
618-
model.untransformed_param_length
618+
gd.sorted_parameters
619+
end
620+
621+
# Compute total length for allocation
622+
param_length = 0
623+
if model.transformed
624+
for vn in param_vars
625+
param_length += model.transformed_var_lengths[vn]
626+
end
627+
else
628+
for vn in param_vars
629+
param_length += model.untransformed_var_lengths[vn]
630+
end
619631
end
620632

621633
param_vals = Vector{Float64}(undef, param_length)
622634
pos = 1
623-
for v in model.graph_evaluation_data.sorted_parameters
635+
for v in param_vars
624636
if !model.transformed
625637
val = AbstractPPL.get(evaluation_env, v)
626638
len = model.untransformed_var_lengths[v]
@@ -651,7 +663,17 @@ function getparams(
651663
T::Type{<:AbstractDict}, model::BUGSModel, evaluation_env=model.evaluation_env
652664
)
653665
d = T()
654-
for v in model.graph_evaluation_data.sorted_parameters
666+
gd = model.graph_evaluation_data
667+
# Respect evaluation mode when selecting parameters
668+
param_vars = if model.evaluation_mode isa UseAutoMarginalization
669+
filter(gd.sorted_parameters) do vn
670+
idx = findfirst(==(vn), gd.sorted_nodes)
671+
idx !== nothing && gd.node_types[idx] == :continuous
672+
end
673+
else
674+
gd.sorted_parameters
675+
end
676+
for v in param_vars
655677
value = AbstractPPL.get(evaluation_env, v)
656678
if !model.transformed
657679
d[v] = value

JuliaBUGS/src/model/evaluation.jl

Lines changed: 105 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,73 @@ function _precompute_minimal_cache_keys(model::BUGSModel, order::Vector{Int})
377377
return minimal_keys
378378
end
379379

380+
"""
381+
_compute_marginalization_order(model::BUGSModel) -> Vector{Int}
382+
383+
Compute a topologically-valid evaluation order that reduces the frontier size
384+
by placing discrete finite variables immediately before their observed dependents
385+
whenever possible. This greatly reduces branching in the recursive enumerator.
386+
"""
387+
function _compute_marginalization_order(model::BUGSModel)
388+
gd = model.graph_evaluation_data
389+
n = length(gd.sorted_nodes)
390+
391+
# Mapping VarName <-> index in sorted_nodes
392+
order = gd.sorted_nodes
393+
pos = Dict(order[i] => i for i in 1:n)
394+
395+
# Direct parents via graph (for topo validity)
396+
function parents(vn)
397+
return collect(MetaGraphsNext.inneighbor_labels(model.g, vn))
398+
end
399+
400+
# Keep track of which nodes are placed
401+
placed = fill(false, n)
402+
out = Int[]
403+
404+
# Recursive placer that ensures all parents are placed first
405+
function place_with_dependencies(vn::VarName)
406+
i = pos[vn]
407+
if placed[i]
408+
return
409+
end
410+
# Place all direct parents first
411+
for p in parents(vn)
412+
place_with_dependencies(p)
413+
end
414+
push!(out, i)
415+
placed[i] = true
416+
end
417+
418+
# Identify observed stochastic nodes and their discrete-finite parents (via stochastic boundary)
419+
# We use the existing helper to traverse through deterministic nodes
420+
stoch_parents = _get_stochastic_parents_indices(model)
421+
422+
# First, for each observed stochastic node, place its discrete-finite parents
423+
# (and dependencies) immediately before placing the node itself.
424+
for (i, vn) in enumerate(order)
425+
if gd.is_stochastic_vals[i] && gd.is_observed_vals[i]
426+
# Place discrete-finite unobserved parents (by label index -> VarName)
427+
for pidx in stoch_parents[i]
428+
if gd.is_discrete_finite_vals[pidx] && !gd.is_observed_vals[pidx]
429+
place_with_dependencies(order[pidx])
430+
end
431+
end
432+
# Then place the observed node itself (ensures mu/sigma/etc. also placed)
433+
place_with_dependencies(vn)
434+
end
435+
end
436+
437+
# Finally, place any remaining nodes in topological order
438+
for vn in order
439+
if !placed[pos[vn]]
440+
place_with_dependencies(vn)
441+
end
442+
end
443+
444+
return out
445+
end
446+
380447
"""
381448
_marginalize_recursive(model, env, remaining_indices, parameter_values, param_idx,
382449
var_lengths, memo, minimal_keys)
@@ -388,8 +455,8 @@ function _marginalize_recursive(
388455
env::NamedTuple,
389456
remaining_indices::AbstractVector{Int},
390457
parameter_values::AbstractVector,
391-
param_idx::Int,
392-
var_lengths::Dict,
458+
param_offsets::Dict{VarName,Int},
459+
var_lengths::Dict{VarName,Int},
393460
memo::Dict,
394461
minimal_keys,
395462
)
@@ -416,7 +483,10 @@ function _marginalize_recursive(
416483
else
417484
minimal_hash = UInt64(0) # Empty frontier
418485
end
419-
memo_key = (current_idx, param_idx, minimal_hash)
486+
# With parameter access keyed by variable name, results depend only on the
487+
# current node and the discrete frontier state. Continuous parameters are
488+
# global and constant for a given input vector.
489+
memo_key = (current_idx, minimal_hash)
420490

421491
if haskey(memo, memo_key)
422492
return memo[memo_key]
@@ -437,7 +507,7 @@ function _marginalize_recursive(
437507
new_env,
438508
@view(remaining_indices[2:end]),
439509
parameter_values,
440-
param_idx,
510+
param_offsets,
441511
var_lengths,
442512
memo,
443513
minimal_keys,
@@ -459,7 +529,7 @@ function _marginalize_recursive(
459529
env,
460530
@view(remaining_indices[2:end]),
461531
parameter_values,
462-
param_idx,
532+
param_offsets,
463533
var_lengths,
464534
memo,
465535
minimal_keys,
@@ -488,7 +558,7 @@ function _marginalize_recursive(
488558
branch_env,
489559
@view(remaining_indices[2:end]),
490560
parameter_values,
491-
param_idx,
561+
param_offsets,
492562
var_lengths,
493563
memo,
494564
minimal_keys,
@@ -512,16 +582,20 @@ function _marginalize_recursive(
512582
end
513583

514584
l = var_lengths[current_vn]
515-
516-
if param_idx + l - 1 > length(parameter_values)
585+
# Fetch the start position for this variable from the precomputed map
586+
start_idx = get(param_offsets, current_vn, 0)
587+
if start_idx == 0
588+
error("Missing parameter offset for variable '$(current_vn)'.")
589+
end
590+
if start_idx + l - 1 > length(parameter_values)
517591
error(
518-
"Parameter index out of bounds: needed $(param_idx + l - 1) elements, " *
592+
"Parameter index out of bounds: needed $(start_idx + l - 1) elements, " *
519593
"but parameter_values has only $(length(parameter_values)) elements.",
520594
)
521595
end
522596

523597
b_inv = Bijectors.inverse(b)
524-
param_slice = view(parameter_values, param_idx:(param_idx + l - 1))
598+
param_slice = view(parameter_values, start_idx:(start_idx + l - 1))
525599

526600
reconstructed_value = reconstruct(b_inv, dist, param_slice)
527601
value, logjac = Bijectors.with_logabsdet_jacobian(b_inv, reconstructed_value)
@@ -535,13 +609,12 @@ function _marginalize_recursive(
535609
dist_logp += logjac
536610
end
537611

538-
next_idx = param_idx + l
539612
remaining_logp = _marginalize_recursive(
540613
model,
541614
new_env,
542615
@view(remaining_indices[2:end]),
543616
parameter_values,
544-
next_idx,
617+
param_offsets,
545618
var_lengths,
546619
memo,
547620
minimal_keys,
@@ -658,18 +731,14 @@ function evaluate_with_marginalization_values!!(
658731
)
659732
end
660733

661-
# Get indices for evaluation order (default to stored marginalization order)
734+
# Compute an order that minimizes frontier growth (interleave discrete parents before observed children)
662735
gd = model.graph_evaluation_data
663736
n = length(gd.sorted_nodes)
664-
sorted_indices = collect(1:n)
737+
sorted_indices = _compute_marginalization_order(model)
665738

666-
# Use precomputed minimal cache keys if available; otherwise compute once for this call
667-
minimal_keys =
668-
if !isempty(gd.minimal_cache_keys) && (length(keys(gd.minimal_cache_keys)) > 0)
669-
gd.minimal_cache_keys
670-
else
671-
_precompute_minimal_cache_keys(model, sorted_indices)
672-
end
739+
# Compute minimal cache keys for this specific order
740+
# (do not reuse cached keys if they were built for a different order)
741+
minimal_keys = _precompute_minimal_cache_keys(model, sorted_indices)
673742

674743
# Initialize memoization cache
675744
# Size hint: at most 2^|discrete_finite| * |nodes| entries
@@ -679,7 +748,7 @@ function evaluate_with_marginalization_values!!(
679748
else
680749
min((1 << n_discrete_finite) * n, 1_000_000)
681750
end
682-
memo = Dict{Tuple{Int,Int,UInt64},Any}()
751+
memo = Dict{Tuple{Int,UInt64},Any}()
683752
sizehint!(memo, expected_entries)
684753

685754
# Start recursive evaluation
@@ -688,23 +757,29 @@ function evaluate_with_marginalization_values!!(
688757
# For marginalization, only continuous parameters need var_lengths
689758
# Discrete finite variables are marginalized over, not sampled
690759
var_lengths = Dict{VarName,Int}()
691-
for (vn, length) in model.transformed_var_lengths
692-
# Find the node index
760+
continuous_param_order = VarName[]
761+
for vn in gd.sorted_parameters
693762
idx = findfirst(==(vn), gd.sorted_nodes)
694-
if idx !== nothing
695-
# Only include if it's continuous (not discrete finite)
696-
if !gd.is_discrete_finite_vals[idx]
697-
var_lengths[vn] = length
698-
end
763+
if idx !== nothing && gd.node_types[idx] == :continuous
764+
push!(continuous_param_order, vn)
765+
var_lengths[vn] = model.transformed_var_lengths[vn]
699766
end
700767
end
701768

769+
# Build mapping from variable -> start index in flattened_values
770+
param_offsets = Dict{VarName,Int}()
771+
start = 1
772+
for vn in continuous_param_order
773+
param_offsets[vn] = start
774+
start += var_lengths[vn]
775+
end
776+
702777
logp = _marginalize_recursive(
703778
model,
704779
evaluation_env,
705780
sorted_indices,
706781
flattened_values,
707-
1,
782+
param_offsets,
708783
var_lengths,
709784
memo,
710785
minimal_keys,

0 commit comments

Comments
 (0)