@@ -457,12 +457,12 @@ function _marginalize_recursive(
     parameter_values::AbstractVector,
     param_offsets::Dict{VarName,Int},
     var_lengths::Dict{VarName,Int},
-    memo::Dict,
+    memo::Dict{Tuple{Int,Tuple,Tuple},Any},
     minimal_keys,
 )
     # Base case: no more nodes to process
     if isempty(remaining_indices)
-        return zero(eltype(parameter_values))
+        return 0.0, 0.0
     end

     current_idx = remaining_indices[1]
@@ -479,14 +479,16 @@ function _marginalize_recursive(
             AbstractPPL.get(env, model.graph_evaluation_data.sorted_nodes[idx]) for
             idx in discrete_frontier_indices
         ]
-        minimal_hash = hash((discrete_frontier_indices, frontier_values))
+        frontier_indices_tuple = Tuple(discrete_frontier_indices)
+        frontier_values_tuple = Tuple(frontier_values)
     else
-        minimal_hash = UInt64(0) # Empty frontier
+        frontier_indices_tuple = ()
+        frontier_values_tuple = ()
     end
     # With parameter access keyed by variable name, results depend only on the
     # current node and the discrete frontier state. Continuous parameters are
     # global and constant for a given input vector.
-    memo_key = (current_idx, minimal_hash)
+    memo_key = (current_idx, frontier_indices_tuple, frontier_values_tuple)

     if haskey(memo, memo_key)
         return memo[memo_key]
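Note on the memo key (a reading of the change above, not part of the patch): replacing the 64-bit `hash` of the frontier with the frontier's index and value tuples makes cache lookups compare by content, so two distinct frontier states can no longer collide on the same entry. Tuples of node indices and discrete values are ordinary `Dict` keys. A minimal sketch with made-up values:

    memo = Dict{Tuple{Int,Tuple,Tuple},Any}()
    memo[(3, (1, 2), (true, 4))] = (-0.7, -1.3)    # (log_prior, log_likelihood)
    haskey(memo, (3, (1, 2), (true, 5)))           # false - different frontier value
    haskey(memo, (3, (1, 2), (true, 4)))           # true  - exact state match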
@@ -498,11 +500,14 @@ function _marginalize_recursive(
     node_function = model.graph_evaluation_data.node_function_vals[current_idx]
     loop_vars = model.graph_evaluation_data.loop_vars_vals[current_idx]

+    result_prior = 0.0
+    result_lik = 0.0
+
     if !is_stochastic
         # Deterministic node
         value = node_function(env, loop_vars)
         new_env = BangBang.setindex!!(env, value, current_vn)
-        result = _marginalize_recursive(
+        result_prior, result_lik = _marginalize_recursive(
             model,
             new_env,
             @view(remaining_indices[2:end]),
@@ -524,7 +529,7 @@ function _marginalize_recursive(
             obs_logp = -Inf
         end

-        remaining_logp = _marginalize_recursive(
+        rest_prior, rest_lik = _marginalize_recursive(
             model,
             env,
             @view(remaining_indices[2:end]),
@@ -534,16 +539,16 @@ function _marginalize_recursive(
             memo,
             minimal_keys,
         )
-        result = obs_logp + remaining_logp
+        result_prior = rest_prior
+        result_lik = obs_logp + rest_lik

     elseif is_discrete_finite
         # Discrete finite unobserved node - marginalize out
         dist = node_function(env, loop_vars)
         possible_values = enumerate_discrete_values(dist)

-        logp_branches = Vector{typeof(zero(eltype(parameter_values)))}(
-            undef, length(possible_values)
-        )
+        total_logpriors = nothing
+        branch_loglikelihoods = nothing

         for (i, value) in enumerate(possible_values)
             branch_env = BangBang.setindex!!(env, value, current_vn)
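The `nothing`-then-allocate pattern introduced above (initialized here, allocated on the first iteration in a later hunk) defers creating the branch accumulators until a branch value exists, presumably so their element type follows whatever the recursion produces (plain `Float64`, dual numbers under AD, etc.) rather than being pinned to `eltype(parameter_values)` as before. A standalone sketch of that pattern, with hypothetical names:

    function collect_lazily(f, xs)
        out = nothing
        for (i, x) in enumerate(xs)
            v = f(x)
            if out === nothing
                # Allocate once the element type is known from the first value.
                out = Vector{typeof(v)}(undef, length(xs))
            end
            out[i] = v
        end
        return out
    end

    collect_lazily(sqrt, [1.0, 4.0, 9.0])   # 3-element Vector{Float64}: [1.0, 2.0, 3.0]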
@@ -553,7 +558,7 @@ function _marginalize_recursive(
                 value_logp = -Inf
             end

-            remaining_logp = _marginalize_recursive(
+            branch_prior, branch_lik = _marginalize_recursive(
                 model,
                 branch_env,
                 @view(remaining_indices[2:end]),
@@ -564,10 +569,26 @@ function _marginalize_recursive(
                 minimal_keys,
             )

-            logp_branches[i] = value_logp + remaining_logp
+            total_val = value_logp + branch_prior
+            lik_val = branch_lik
+            if total_logpriors === nothing
+                total_logpriors = Vector{typeof(total_val)}(undef, length(possible_values))
+                branch_loglikelihoods = Vector{typeof(lik_val)}(undef, length(possible_values))
+            end
+            total_logpriors[i] = total_val
+            branch_loglikelihoods[i] = lik_val
         end

-        result = LogExpFunctions.logsumexp(logp_branches)
+        @assert total_logpriors !== nothing && branch_loglikelihoods !== nothing
+        log_prior_total = LogExpFunctions.logsumexp(total_logpriors)
+        log_joint_total = LogExpFunctions.logsumexp(total_logpriors .+ branch_loglikelihoods)
+        if isfinite(log_prior_total)
+            result_prior = log_prior_total
+            result_lik = log_joint_total - log_prior_total
+        else
+            result_prior = log_prior_total
+            result_lik = log_joint_total
+        end

     else
         # Continuous or discrete infinite unobserved node - use parameter values
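The split computed in the marginalization branch above keeps the invariant that the returned pair satisfies log_prior + log_likelihood = log_joint: with per-branch log-priors `p_k` and log-likelihoods `l_k`, the code takes `logsumexp(p)` as the prior and `logsumexp(p .+ l) - logsumexp(p)` as the likelihood (falling back to the joint when the prior mass is zero, via the `isfinite` guard). When the branch priors are normalized, that difference is the log of the prior-weighted average of the branch likelihoods. A small numeric check of the identity, with made-up numbers:

    using LogExpFunctions

    priors = log.([0.3, 0.7])      # per-branch log-prior mass (sums to 1 here)
    liks   = log.([0.2, 0.05])     # per-branch log-likelihood of the observations
    log_prior_total = logsumexp(priors)            # log(1.0) = 0.0
    log_joint_total = logsumexp(priors .+ liks)    # log(0.3*0.2 + 0.7*0.05)
    log_lik = log_joint_total - log_prior_total
    log_lik ≈ log(0.3 * 0.2 + 0.7 * 0.05)          # true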
@@ -609,7 +630,7 @@ function _marginalize_recursive(
             dist_logp += logjac
         end

-        remaining_logp = _marginalize_recursive(
+        rest_prior, rest_lik = _marginalize_recursive(
             model,
             new_env,
             @view(remaining_indices[2:end]),
@@ -620,11 +641,12 @@ function _marginalize_recursive(
             minimal_keys,
         )

-        result = dist_logp + remaining_logp
+        result_prior = dist_logp + rest_prior
+        result_lik = rest_lik
     end

-    memo[memo_key] = result
-    return result
+    memo[memo_key] = (result_prior, result_lik)
+    return result_prior, result_lik
 end

 """
@@ -751,7 +773,7 @@ function evaluate_with_marginalization_values!!(
     else
         min((1 << n_discrete_finite) * n, 1_000_000)
     end
-    memo = Dict{Tuple{Int,UInt64},Any}()
+    memo = Dict{Tuple{Int,Tuple,Tuple},Any}()
     sizehint!(memo, expected_entries)

     # Start recursive evaluation
@@ -777,7 +799,7 @@ function evaluate_with_marginalization_values!!(
         start += var_lengths[vn]
     end

-    logp = _marginalize_recursive(
+    log_prior, log_likelihood = _marginalize_recursive(
         model,
         evaluation_env,
         sorted_indices,
@@ -792,8 +814,8 @@ function evaluate_with_marginalization_values!!(
     # and split the log probability (though marginalization combines them)
     return evaluation_env,
     (
-        logprior=logp,
-        loglikelihood=0.0, # Combined in logprior for marginalization
-        tempered_logjoint=logp * temperature,
+        logprior=log_prior,
+        loglikelihood=log_likelihood,
+        tempered_logjoint=log_prior + temperature * log_likelihood,
     )
 end
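With `logprior` and `loglikelihood` now returned separately, tempering scales only the likelihood term: the tempered log-joint becomes log p(θ) + T · log p(y | θ) rather than T · (log p(θ) + log p(y | θ)) as in the old code. A tiny sketch with hypothetical numbers:

    log_prior, log_likelihood, temperature = -1.2, -3.4, 0.5
    tempered_logjoint = log_prior + temperature * log_likelihood   # -2.9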