@@ -377,6 +377,73 @@ function _precompute_minimal_cache_keys(model::BUGSModel, order::Vector{Int})
377
377
return minimal_keys
378
378
end
379
379
380
+ """
381
+ _compute_marginalization_order(model::BUGSModel) -> Vector{Int}
382
+
383
+ Compute a topologically-valid evaluation order that reduces the frontier size
384
+ by placing discrete finite variables immediately before their observed dependents
385
+ whenever possible. This greatly reduces branching in the recursive enumerator.
386
+ """
387
+ function _compute_marginalization_order (model:: BUGSModel )
388
+ gd = model. graph_evaluation_data
389
+ n = length (gd. sorted_nodes)
390
+
391
+ # Mapping VarName <-> index in sorted_nodes
392
+ order = gd. sorted_nodes
393
+ pos = Dict (order[i] => i for i in 1 : n)
394
+
395
+ # Direct parents via graph (for topo validity)
396
+ function parents (vn)
397
+ return collect (MetaGraphsNext. inneighbor_labels (model. g, vn))
398
+ end
399
+
400
+ # Keep track of which nodes are placed
401
+ placed = fill (false , n)
402
+ out = Int[]
403
+
404
+ # Recursive placer that ensures all parents are placed first
405
+ function place_with_dependencies (vn:: VarName )
406
+ i = pos[vn]
407
+ if placed[i]
408
+ return
409
+ end
410
+ # Place all direct parents first
411
+ for p in parents (vn)
412
+ place_with_dependencies (p)
413
+ end
414
+ push! (out, i)
415
+ placed[i] = true
416
+ end
417
+
418
+ # Identify observed stochastic nodes and their discrete-finite parents (via stochastic boundary)
419
+ # We use the existing helper to traverse through deterministic nodes
420
+ stoch_parents = _get_stochastic_parents_indices (model)
421
+
422
+ # First, for each observed stochastic node, place its discrete-finite parents
423
+ # (and dependencies) immediately before placing the node itself.
424
+ for (i, vn) in enumerate (order)
425
+ if gd. is_stochastic_vals[i] && gd. is_observed_vals[i]
426
+ # Place discrete-finite unobserved parents (by label index -> VarName)
427
+ for pidx in stoch_parents[i]
428
+ if gd. is_discrete_finite_vals[pidx] && ! gd. is_observed_vals[pidx]
429
+ place_with_dependencies (order[pidx])
430
+ end
431
+ end
432
+ # Then place the observed node itself (ensures mu/sigma/etc. also placed)
433
+ place_with_dependencies (vn)
434
+ end
435
+ end
436
+
437
+ # Finally, place any remaining nodes in topological order
438
+ for vn in order
439
+ if ! placed[pos[vn]]
440
+ place_with_dependencies (vn)
441
+ end
442
+ end
443
+
444
+ return out
445
+ end
446
+
380
447
"""
381
448
_marginalize_recursive(model, env, remaining_indices, parameter_values, param_idx,
382
449
var_lengths, memo, minimal_keys)
@@ -388,8 +455,8 @@ function _marginalize_recursive(
388
455
env:: NamedTuple ,
389
456
remaining_indices:: AbstractVector{Int} ,
390
457
parameter_values:: AbstractVector ,
391
- param_idx :: Int ,
392
- var_lengths:: Dict ,
458
+ param_offsets :: Dict{VarName, Int} ,
459
+ var_lengths:: Dict{VarName,Int} ,
393
460
memo:: Dict ,
394
461
minimal_keys,
395
462
)
@@ -416,7 +483,10 @@ function _marginalize_recursive(
416
483
else
417
484
minimal_hash = UInt64 (0 ) # Empty frontier
418
485
end
419
- memo_key = (current_idx, param_idx, minimal_hash)
486
+ # With parameter access keyed by variable name, results depend only on the
487
+ # current node and the discrete frontier state. Continuous parameters are
488
+ # global and constant for a given input vector.
489
+ memo_key = (current_idx, minimal_hash)
420
490
421
491
if haskey (memo, memo_key)
422
492
return memo[memo_key]
@@ -437,7 +507,7 @@ function _marginalize_recursive(
437
507
new_env,
438
508
@view (remaining_indices[2 : end ]),
439
509
parameter_values,
440
- param_idx ,
510
+ param_offsets ,
441
511
var_lengths,
442
512
memo,
443
513
minimal_keys,
@@ -459,7 +529,7 @@ function _marginalize_recursive(
459
529
env,
460
530
@view (remaining_indices[2 : end ]),
461
531
parameter_values,
462
- param_idx ,
532
+ param_offsets ,
463
533
var_lengths,
464
534
memo,
465
535
minimal_keys,
@@ -488,7 +558,7 @@ function _marginalize_recursive(
488
558
branch_env,
489
559
@view (remaining_indices[2 : end ]),
490
560
parameter_values,
491
- param_idx ,
561
+ param_offsets ,
492
562
var_lengths,
493
563
memo,
494
564
minimal_keys,
@@ -512,16 +582,20 @@ function _marginalize_recursive(
512
582
end
513
583
514
584
l = var_lengths[current_vn]
515
-
516
- if param_idx + l - 1 > length (parameter_values)
585
+ # Fetch the start position for this variable from the precomputed map
586
+ start_idx = get (param_offsets, current_vn, 0 )
587
+ if start_idx == 0
588
+ error (" Missing parameter offset for variable '$(current_vn) '." )
589
+ end
590
+ if start_idx + l - 1 > length (parameter_values)
517
591
error (
518
- " Parameter index out of bounds: needed $(param_idx + l - 1 ) elements, " *
592
+ " Parameter index out of bounds: needed $(start_idx + l - 1 ) elements, " *
519
593
" but parameter_values has only $(length (parameter_values)) elements." ,
520
594
)
521
595
end
522
596
523
597
b_inv = Bijectors. inverse (b)
524
- param_slice = view (parameter_values, param_idx : (param_idx + l - 1 ))
598
+ param_slice = view (parameter_values, start_idx : (start_idx + l - 1 ))
525
599
526
600
reconstructed_value = reconstruct (b_inv, dist, param_slice)
527
601
value, logjac = Bijectors. with_logabsdet_jacobian (b_inv, reconstructed_value)
@@ -535,13 +609,12 @@ function _marginalize_recursive(
535
609
dist_logp += logjac
536
610
end
537
611
538
- next_idx = param_idx + l
539
612
remaining_logp = _marginalize_recursive (
540
613
model,
541
614
new_env,
542
615
@view (remaining_indices[2 : end ]),
543
616
parameter_values,
544
- next_idx ,
617
+ param_offsets ,
545
618
var_lengths,
546
619
memo,
547
620
minimal_keys,
@@ -658,18 +731,14 @@ function evaluate_with_marginalization_values!!(
658
731
)
659
732
end
660
733
661
- # Get indices for evaluation order (default to stored marginalization order )
734
+ # Compute an order that minimizes frontier growth (interleave discrete parents before observed children )
662
735
gd = model. graph_evaluation_data
663
736
n = length (gd. sorted_nodes)
664
- sorted_indices = collect ( 1 : n )
737
+ sorted_indices = _compute_marginalization_order (model )
665
738
666
- # Use precomputed minimal cache keys if available; otherwise compute once for this call
667
- minimal_keys =
668
- if ! isempty (gd. minimal_cache_keys) && (length (keys (gd. minimal_cache_keys)) > 0 )
669
- gd. minimal_cache_keys
670
- else
671
- _precompute_minimal_cache_keys (model, sorted_indices)
672
- end
739
+ # Compute minimal cache keys for this specific order
740
+ # (do not reuse cached keys if they were built for a different order)
741
+ minimal_keys = _precompute_minimal_cache_keys (model, sorted_indices)
673
742
674
743
# Initialize memoization cache
675
744
# Size hint: at most 2^|discrete_finite| * |nodes| entries
@@ -679,7 +748,7 @@ function evaluate_with_marginalization_values!!(
679
748
else
680
749
min ((1 << n_discrete_finite) * n, 1_000_000 )
681
750
end
682
- memo = Dict {Tuple{Int,Int, UInt64},Any} ()
751
+ memo = Dict {Tuple{Int,UInt64},Any} ()
683
752
sizehint! (memo, expected_entries)
684
753
685
754
# Start recursive evaluation
@@ -688,23 +757,29 @@ function evaluate_with_marginalization_values!!(
688
757
# For marginalization, only continuous parameters need var_lengths
689
758
# Discrete finite variables are marginalized over, not sampled
690
759
var_lengths = Dict {VarName,Int} ()
691
- for (vn, length) in model . transformed_var_lengths
692
- # Find the node index
760
+ continuous_param_order = VarName[]
761
+ for vn in gd . sorted_parameters
693
762
idx = findfirst (== (vn), gd. sorted_nodes)
694
- if idx != = nothing
695
- # Only include if it's continuous (not discrete finite)
696
- if ! gd. is_discrete_finite_vals[idx]
697
- var_lengths[vn] = length
698
- end
763
+ if idx != = nothing && gd. node_types[idx] == :continuous
764
+ push! (continuous_param_order, vn)
765
+ var_lengths[vn] = model. transformed_var_lengths[vn]
699
766
end
700
767
end
701
768
769
+ # Build mapping from variable -> start index in flattened_values
770
+ param_offsets = Dict {VarName,Int} ()
771
+ start = 1
772
+ for vn in continuous_param_order
773
+ param_offsets[vn] = start
774
+ start += var_lengths[vn]
775
+ end
776
+
702
777
logp = _marginalize_recursive (
703
778
model,
704
779
evaluation_env,
705
780
sorted_indices,
706
781
flattened_values,
707
- 1 ,
782
+ param_offsets ,
708
783
var_lengths,
709
784
memo,
710
785
minimal_keys,
0 commit comments