@@ -303,54 +303,73 @@ function _precompute_minimal_cache_keys(model::BUGSModel, order::Vector{Int})
303
303
is_discrete_finite = gd. is_discrete_finite_vals
304
304
node_types = gd. node_types
305
305
306
- # Get stochastic parents for each node
306
+ # Get stochastic parents (stochastic boundary) for each node
307
307
parents_idx = _get_stochastic_parents_indices (model)
308
308
309
- # Initialize frontier keys for each position
310
- minimal_keys = Dict {Int,Vector{Int}} ()
309
+ # Build mapping from node index (in gd.sorted_nodes) -> position in the provided order.
310
+ # This lets us reason about liveness w.r.t. the chosen evaluation order.
311
+ order_pos = Vector {Int} (undef, length (gd. sorted_nodes))
312
+ @inbounds for k in 1 : n
313
+ order_pos[order[k]] = k
314
+ end
311
315
312
- # For each node, determine which discrete finite variables should be in its frontier
313
- for k in 1 : n
314
- current_idx = order[k]
315
-
316
- # The frontier should include discrete finite variables that:
317
- # 1. Come before this node in the evaluation order (have been set)
318
- # 2. This node depends on (directly or indirectly)
319
-
320
- filtered_frontier = Int[]
321
-
322
- # For ALL nodes (stochastic and deterministic), we need to track dependencies
323
- # on discrete finite variables that have been set
324
-
325
- if node_types[current_idx] == :discrete_finite
326
- # IMPORTANT: At a discrete node, the remainder of computation may include
327
- # observed likelihood terms that depend on ALL earlier discrete assignments
328
- # (e.g., y[1..t-1] in an HMM when all y's are evaluated after all z's).
329
- # Therefore, to avoid incorrect cache reuse, conservatively include ALL
330
- # prior discrete finite unobserved variables in the frontier.
331
- for j in 1 : (k - 1 )
332
- idx = order[j]
333
- if is_discrete_finite[idx] && ! is_observed[idx]
334
- push! (filtered_frontier, idx)
316
+ # Compute last-use POSITIONS (w.r.t. 'order') for each unobserved finite-discrete variable.
317
+ # A variable stays in the frontier until we pass the last stochastic node
318
+ # (observed or unobserved) whose distribution depends on it.
319
+ last_use_pos = Dict {Int,Int} () # map from variable index -> last position in 'order'
320
+ for j_label in 1 : length (gd. sorted_nodes)
321
+ if gd. is_stochastic_vals[j_label]
322
+ j_pos = order_pos[j_label]
323
+ for p_label in parents_idx[j_label]
324
+ if is_discrete_finite[p_label] && ! is_observed[p_label]
325
+ # Default to the position of the variable itself if unseen
326
+ default_pos = order_pos[p_label]
327
+ last_use_pos[p_label] = max (get (last_use_pos, p_label, default_pos), j_pos)
335
328
end
336
329
end
337
- else
338
- # For deterministic, continuous, or discrete-infinite nodes, we must ensure
339
- # the memo key distinguishes different earlier discrete assignments that can
340
- # influence any of the remaining computation (e.g., later likelihood terms).
341
- # A conservative but correct choice is to include ALL prior discrete finite
342
- # unobserved variables in the frontier.
343
- for j in 1 : (k - 1 )
344
- idx = order[j]
345
- if is_discrete_finite[idx] && ! is_observed[idx]
346
- push! (filtered_frontier, idx)
347
- end
330
+ end
331
+ end
332
+
333
+ # Initialize frontier keys for each position based on liveness
334
+ # Optimized incremental construction to avoid O(n^2) in common patterns
335
+ minimal_keys = Dict {Int,Vector{Int}} ()
336
+
337
+ # Precompute starts and ends in order positions
338
+ starts_at = Dict {Int,Vector{Int}} ()
339
+ for lbl in 1 : length (gd. sorted_nodes)
340
+ pos = order_pos[lbl]
341
+ if is_discrete_finite[lbl] && ! is_observed[lbl]
342
+ push! (get! (starts_at, pos, Int[]), lbl)
343
+ end
344
+ end
345
+
346
+ # Active set of earlier discrete finite variables (by label index)
347
+ active = Int[]
348
+ # Helper: drop labels from the active set once their last-use position has passed
349
# Remove from `active_vec`, in place, every label whose last-use position is
# before `k_pos`. Last-use positions are read from `last_use_pos`, a
# Dict{Int,Int} defined in the enclosing scope; a label with no entry there is
# treated as position 0 and is therefore always removed.
#
# Arguments:
#   active_vec - labels currently considered live; mutated in place
#   k_pos      - current position in the evaluation order
#
# Returns the same (mutated) `active_vec`; surviving labels keep their
# relative order.
function purge_expired!(active_vec::Vector{Int}, k_pos::Int)
    # One-pass `filter!` instead of repeated `deleteat!`, which shifted the
    # tail of the vector on every single removal.
    filter!(lbl -> get(last_use_pos, lbl, 0) >= k_pos, active_vec)
    return active_vec
end
350
362
351
- # Store as sorted vector
352
- sort! (unique! (filtered_frontier))
353
- minimal_keys[current_idx] = filtered_frontier
363
+ for k in 1 : n
364
+ # Add labels that start at previous position so they count as "earlier"
365
+ if haskey (starts_at, k - 1 )
366
+ append! (active, starts_at[k - 1 ])
367
+ end
368
+ # Drop any labels that have expired before current position
369
+ purge_expired! (active, k)
370
+ # Sort for stable key representation
371
+ sort! (active)
372
+ minimal_keys[order[k]] = copy (active)
354
373
end
355
374
356
375
return minimal_keys
@@ -409,7 +428,7 @@ function _marginalize_recursive(
409
428
410
429
if ! is_stochastic
411
430
# Deterministic node
412
- value = node_function ( env, loop_vars)
431
+ value = Base . invokelatest (node_function, env, loop_vars)
413
432
new_env = BangBang. setindex!! (env, value, current_vn)
414
433
result = _marginalize_recursive (
415
434
model,
@@ -424,7 +443,7 @@ function _marginalize_recursive(
424
443
425
444
elseif is_observed
426
445
# Observed stochastic node
427
- dist = node_function ( env, loop_vars)
446
+ dist = Base . invokelatest (node_function, env, loop_vars)
428
447
obs_value = AbstractPPL. get (env, current_vn)
429
448
obs_logp = logpdf (dist, obs_value)
430
449
@@ -447,7 +466,7 @@ function _marginalize_recursive(
447
466
448
467
elseif is_discrete_finite
449
468
# Discrete finite unobserved node - marginalize out
450
- dist = node_function ( env, loop_vars)
469
+ dist = Base . invokelatest (node_function, env, loop_vars)
451
470
possible_values = enumerate_discrete_values (dist)
452
471
453
472
logp_branches = Vector {typeof(zero(eltype(parameter_values)))} (
@@ -480,7 +499,7 @@ function _marginalize_recursive(
480
499
481
500
else
482
501
# Continuous or discrete infinite unobserved node - use parameter values
483
- dist = node_function ( env, loop_vars)
502
+ dist = Base . invokelatest (node_function, env, loop_vars)
484
503
b = Bijectors. bijector (dist)
485
504
486
505
if ! haskey (var_lengths, current_vn)
@@ -641,8 +660,12 @@ function evaluate_with_marginalization_values!!(
641
660
n = length (model. graph_evaluation_data. sorted_nodes)
642
661
sorted_indices = collect (1 : n)
643
662
644
- # Precompute minimal cache keys
645
- minimal_keys = _precompute_minimal_cache_keys (model, sorted_indices)
663
+ # Use precomputed minimal cache keys if available; otherwise compute once for this call
664
+ minimal_keys = if ! isempty (model. graph_evaluation_data. minimal_cache_keys)
665
+ model. graph_evaluation_data. minimal_cache_keys
666
+ else
667
+ _precompute_minimal_cache_keys (model, sorted_indices)
668
+ end
646
669
647
670
# Initialize memoization cache
648
671
# Size hint: at most 2^|discrete_finite| * |nodes| entries
0 commit comments