Skip to content

Commit d8874e4

Browse files
committed
fix: stabilize auto-marginalization caches and tempering
1 parent f6002b3 commit d8874e4

File tree

3 files changed: +102 additions, −27 deletions

JuliaBUGS/src/model/abstractppl.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -622,9 +622,10 @@ function _create_modified_model(
622622
order,
623623
)
624624
return BUGSModel(new_model; graph_evaluation_data=gd_cached)
625-
catch
626-
# If caches cannot be computed (e.g., unsupported model), return model as-is.
627-
return new_model
625+
catch err
626+
@warn "Failed to precompute auto-marginalization caches; falling back to graph evaluation" exception=(err, catch_backtrace())
627+
# Ensure the regenerated model does not stay in an inconsistent evaluation mode
628+
return BangBang.setproperty!!(new_model, :evaluation_mode, UseGraph())
628629
end
629630
end
630631

JuliaBUGS/src/model/evaluation.jl

Lines changed: 45 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -457,12 +457,12 @@ function _marginalize_recursive(
457457
parameter_values::AbstractVector,
458458
param_offsets::Dict{VarName,Int},
459459
var_lengths::Dict{VarName,Int},
460-
memo::Dict,
460+
memo::Dict{Tuple{Int,Tuple,Tuple},Any},
461461
minimal_keys,
462462
)
463463
# Base case: no more nodes to process
464464
if isempty(remaining_indices)
465-
return zero(eltype(parameter_values))
465+
return 0.0, 0.0
466466
end
467467

468468
current_idx = remaining_indices[1]
@@ -479,14 +479,16 @@ function _marginalize_recursive(
479479
AbstractPPL.get(env, model.graph_evaluation_data.sorted_nodes[idx]) for
480480
idx in discrete_frontier_indices
481481
]
482-
minimal_hash = hash((discrete_frontier_indices, frontier_values))
482+
frontier_indices_tuple = Tuple(discrete_frontier_indices)
483+
frontier_values_tuple = Tuple(frontier_values)
483484
else
484-
minimal_hash = UInt64(0) # Empty frontier
485+
frontier_indices_tuple = ()
486+
frontier_values_tuple = ()
485487
end
486488
# With parameter access keyed by variable name, results depend only on the
487489
# current node and the discrete frontier state. Continuous parameters are
488490
# global and constant for a given input vector.
489-
memo_key = (current_idx, minimal_hash)
491+
memo_key = (current_idx, frontier_indices_tuple, frontier_values_tuple)
490492

491493
if haskey(memo, memo_key)
492494
return memo[memo_key]
@@ -498,11 +500,14 @@ function _marginalize_recursive(
498500
node_function = model.graph_evaluation_data.node_function_vals[current_idx]
499501
loop_vars = model.graph_evaluation_data.loop_vars_vals[current_idx]
500502

503+
result_prior = 0.0
504+
result_lik = 0.0
505+
501506
if !is_stochastic
502507
# Deterministic node
503508
value = node_function(env, loop_vars)
504509
new_env = BangBang.setindex!!(env, value, current_vn)
505-
result = _marginalize_recursive(
510+
result_prior, result_lik = _marginalize_recursive(
506511
model,
507512
new_env,
508513
@view(remaining_indices[2:end]),
@@ -524,7 +529,7 @@ function _marginalize_recursive(
524529
obs_logp = -Inf
525530
end
526531

527-
remaining_logp = _marginalize_recursive(
532+
rest_prior, rest_lik = _marginalize_recursive(
528533
model,
529534
env,
530535
@view(remaining_indices[2:end]),
@@ -534,16 +539,16 @@ function _marginalize_recursive(
534539
memo,
535540
minimal_keys,
536541
)
537-
result = obs_logp + remaining_logp
542+
result_prior = rest_prior
543+
result_lik = obs_logp + rest_lik
538544

539545
elseif is_discrete_finite
540546
# Discrete finite unobserved node - marginalize out
541547
dist = node_function(env, loop_vars)
542548
possible_values = enumerate_discrete_values(dist)
543549

544-
logp_branches = Vector{typeof(zero(eltype(parameter_values)))}(
545-
undef, length(possible_values)
546-
)
550+
total_logpriors = nothing
551+
branch_loglikelihoods = nothing
547552

548553
for (i, value) in enumerate(possible_values)
549554
branch_env = BangBang.setindex!!(env, value, current_vn)
@@ -553,7 +558,7 @@ function _marginalize_recursive(
553558
value_logp = -Inf
554559
end
555560

556-
remaining_logp = _marginalize_recursive(
561+
branch_prior, branch_lik = _marginalize_recursive(
557562
model,
558563
branch_env,
559564
@view(remaining_indices[2:end]),
@@ -564,10 +569,26 @@ function _marginalize_recursive(
564569
minimal_keys,
565570
)
566571

567-
logp_branches[i] = value_logp + remaining_logp
572+
total_val = value_logp + branch_prior
573+
lik_val = branch_lik
574+
if total_logpriors === nothing
575+
total_logpriors = Vector{typeof(total_val)}(undef, length(possible_values))
576+
branch_loglikelihoods = Vector{typeof(lik_val)}(undef, length(possible_values))
577+
end
578+
total_logpriors[i] = total_val
579+
branch_loglikelihoods[i] = lik_val
568580
end
569581

570-
result = LogExpFunctions.logsumexp(logp_branches)
582+
@assert total_logpriors !== nothing && branch_loglikelihoods !== nothing
583+
log_prior_total = LogExpFunctions.logsumexp(total_logpriors)
584+
log_joint_total = LogExpFunctions.logsumexp(total_logpriors .+ branch_loglikelihoods)
585+
if isfinite(log_prior_total)
586+
result_prior = log_prior_total
587+
result_lik = log_joint_total - log_prior_total
588+
else
589+
result_prior = log_prior_total
590+
result_lik = log_joint_total
591+
end
571592

572593
else
573594
# Continuous or discrete infinite unobserved node - use parameter values
@@ -609,7 +630,7 @@ function _marginalize_recursive(
609630
dist_logp += logjac
610631
end
611632

612-
remaining_logp = _marginalize_recursive(
633+
rest_prior, rest_lik = _marginalize_recursive(
613634
model,
614635
new_env,
615636
@view(remaining_indices[2:end]),
@@ -620,11 +641,12 @@ function _marginalize_recursive(
620641
minimal_keys,
621642
)
622643

623-
result = dist_logp + remaining_logp
644+
result_prior = dist_logp + rest_prior
645+
result_lik = rest_lik
624646
end
625647

626-
memo[memo_key] = result
627-
return result
648+
memo[memo_key] = (result_prior, result_lik)
649+
return result_prior, result_lik
628650
end
629651

630652
"""
@@ -751,7 +773,7 @@ function evaluate_with_marginalization_values!!(
751773
else
752774
min((1 << n_discrete_finite) * n, 1_000_000)
753775
end
754-
memo = Dict{Tuple{Int,UInt64},Any}()
776+
memo = Dict{Tuple{Int,Tuple,Tuple},Any}()
755777
sizehint!(memo, expected_entries)
756778

757779
# Start recursive evaluation
@@ -777,7 +799,7 @@ function evaluate_with_marginalization_values!!(
777799
start += var_lengths[vn]
778800
end
779801

780-
logp = _marginalize_recursive(
802+
log_prior, log_likelihood = _marginalize_recursive(
781803
model,
782804
evaluation_env,
783805
sorted_indices,
@@ -792,8 +814,8 @@ function evaluate_with_marginalization_values!!(
792814
# and split the log probability (though marginalization combines them)
793815
return evaluation_env,
794816
(
795-
logprior=logp,
796-
loglikelihood=0.0, # Combined in logprior for marginalization
797-
tempered_logjoint=logp * temperature,
817+
logprior=log_prior,
818+
loglikelihood=log_likelihood,
819+
tempered_logjoint=log_prior + temperature * log_likelihood,
798820
)
799821
end

JuliaBUGS/test/model/auto_marginalization.jl

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,11 @@
22
# This file is included from runtests.jl which provides all necessary imports
33

44
using JuliaBUGS: @bugs, compile, settrans, initialize!, getparams
5-
using JuliaBUGS.Model: set_evaluation_mode, UseAutoMarginalization, UseGraph
5+
using JuliaBUGS.Model:
6+
set_evaluation_mode,
7+
UseAutoMarginalization,
8+
UseGraph,
9+
evaluate_with_marginalization_values!!
610

711
@testset "Auto-Marginalization" begin
812
println("[AutoMargTest] Starting Auto-Marginalization test suite...");
@@ -707,6 +711,54 @@ using JuliaBUGS.Model: set_evaluation_mode, UseAutoMarginalization, UseGraph
707711
@test rel_err < 5e-5
708712
end
709713

714+
@testset "Log prior/likelihood split and tempering" begin
715+
println("[AutoMargTest] Log split: compiling model...");
716+
flush(stdout)
717+
simple_def = @bugs begin
718+
mu ~ Normal(0, 1)
719+
z ~ Categorical(w[1:K])
720+
y ~ Normal(mu + delta[z], sigma)
721+
end
722+
723+
data = (
724+
K=2,
725+
w=[0.3, 0.7],
726+
delta=[0.0, 2.0],
727+
sigma=1.0,
728+
y=1.5,
729+
)
730+
731+
model = compile(simple_def, data)
732+
model = settrans(model, true)
733+
model = set_evaluation_mode(model, UseAutoMarginalization())
734+
735+
θ = [0.0] # mu in transformed space (identity bijector)
736+
_, stats = evaluate_with_marginalization_values!!(model, θ; temperature=0.4)
737+
738+
expected_logprior = logpdf(Normal(0, 1), 0.0)
739+
log_weighted = [
740+
log(data.w[i]) + logpdf(Normal(0.0 + data.delta[i], data.sigma), data.y) for
741+
i in 1:data.K
742+
]
743+
expected_loglik = LogExpFunctions.logsumexp(log_weighted)
744+
745+
@test isapprox(stats.logprior, expected_logprior; atol=1e-10)
746+
@test isapprox(stats.loglikelihood, expected_loglik; atol=1e-10)
747+
@test isapprox(
748+
stats.tempered_logjoint, expected_logprior + 0.4 * expected_loglik; atol=1e-10
749+
)
750+
751+
ad_model = ADgradient(AutoForwardDiff(), model)
752+
val, grad = LogDensityProblems.logdensity_and_gradient(ad_model, θ)
753+
@test isapprox(val, expected_logprior + expected_loglik; atol=1e-10)
754+
function f_scalar(mu_val)
755+
LogDensityProblems.logdensity(model, [mu_val])
756+
end
757+
ϵ = 1e-6
758+
fd_grad = (f_scalar(θ[1] + ϵ) - f_scalar(θ[1] - ϵ)) / (2ϵ)
759+
@test isapprox(grad[1], fd_grad; atol=1e-6)
760+
end
761+
710762
@testset "Efficiency smoke: AutoMarg+NUTS vs Graph+IndependentMH" begin
711763
println("[AutoMargTest] Efficiency smoke: compiling models...");
712764
flush(stdout)

0 commit comments

Comments (0)