Skip to content

Commit 7243c23

Browse files
committed
fix some errors
1 parent 6b1206e commit 7243c23

File tree

2 files changed

+152
-50
lines changed

2 files changed

+152
-50
lines changed

JuliaBUGS/src/model/bugsmodel.jl

Lines changed: 83 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ struct GraphEvaluationData{TNF,TV}
7676
loop_vars_vals::TV
7777
node_types::Vector{Symbol}
7878
is_discrete_finite_vals::Vector{Bool}
79+
minimal_cache_keys::Dict{Int,Vector{Int}}
7980
end
8081

8182
function GraphEvaluationData(
@@ -122,6 +123,7 @@ function GraphEvaluationData(
122123
map(identity, loop_vars_vals),
123124
node_types,
124125
is_discrete_finite_vals,
126+
Dict{Int,Vector{Int}}(),
125127
)
126128
end
127129

@@ -205,7 +207,8 @@ function BUGSModel(
205207
model_def::Expr=model.model_def,
206208
data=model.data,
207209
)
208-
return BUGSModel(
210+
# Build an intermediate model
211+
m = BUGSModel(
209212
model_def,
210213
data,
211214
g,
@@ -221,6 +224,43 @@ function BUGSModel(
221224
mutable_symbols,
222225
base_model,
223226
)
227+
# Precompute minimal cache keys for current evaluation order if not present
228+
gd = m.graph_evaluation_data
229+
minimal_keys = if !isempty(gd.minimal_cache_keys)
230+
gd.minimal_cache_keys
231+
else
232+
n = length(gd.sorted_nodes)
233+
JuliaBUGS.Model._precompute_minimal_cache_keys(m, collect(1:n))
234+
end
235+
# Attach minimal keys to GraphEvaluationData
236+
gd2 = GraphEvaluationData(
237+
gd.sorted_nodes,
238+
gd.sorted_parameters,
239+
gd.is_stochastic_vals,
240+
gd.is_observed_vals,
241+
gd.node_function_vals,
242+
gd.loop_vars_vals,
243+
gd.node_types,
244+
gd.is_discrete_finite_vals,
245+
minimal_keys,
246+
)
247+
# Return final model with cached minimal keys
248+
return BUGSModel(
249+
model_def,
250+
data,
251+
g,
252+
evaluation_env,
253+
transformed,
254+
evaluation_mode,
255+
untransformed_param_length,
256+
transformed_param_length,
257+
untransformed_var_lengths,
258+
transformed_var_lengths,
259+
gd2,
260+
log_density_computation_function,
261+
mutable_symbols,
262+
base_model,
263+
)
224264
end
225265

226266
function Base.show(io::IO, model::BUGSModel)
@@ -337,6 +377,7 @@ function BUGSModel(
337377
graph_evaluation_data.loop_vars_vals,
338378
node_types,
339379
is_discrete_finite_vals,
380+
Dict{Int,Vector{Int}}(),
340381
)
341382

342383
lowered_model_def, reconstructed_model_def = JuliaBUGS._generate_lowered_model_def(
@@ -397,6 +438,7 @@ function BUGSModel(
397438
new_gd.loop_vars_vals,
398439
new_node_types,
399440
new_is_discrete_finite_vals,
441+
Dict{Int,Vector{Int}}(),
400442
)
401443
else
402444
log_density_computation_function = nothing
@@ -405,7 +447,8 @@ function BUGSModel(
405447
# Compute mutable symbols from graph evaluation data
406448
mutable_symbols = get_mutable_symbols(graph_evaluation_data)
407449

408-
return BUGSModel(
450+
# Build initial model (without minimal cache keys precomputed)
451+
model_without_min_keys = BUGSModel(
409452
model_def,
410453
data,
411454
g,
@@ -421,6 +464,42 @@ function BUGSModel(
421464
mutable_symbols,
422465
nothing,
423466
)
467+
# Precompute minimal cache keys for the default order (1:n)
468+
n = length(graph_evaluation_data.sorted_nodes)
469+
sorted_indices = collect(1:n)
470+
minimal_keys = JuliaBUGS.Model._precompute_minimal_cache_keys(
471+
model_without_min_keys, sorted_indices
472+
)
473+
# Attach minimal keys to GraphEvaluationData
474+
graph_evaluation_data_with_keys = GraphEvaluationData(
475+
graph_evaluation_data.sorted_nodes,
476+
graph_evaluation_data.sorted_parameters,
477+
graph_evaluation_data.is_stochastic_vals,
478+
graph_evaluation_data.is_observed_vals,
479+
graph_evaluation_data.node_function_vals,
480+
graph_evaluation_data.loop_vars_vals,
481+
graph_evaluation_data.node_types,
482+
graph_evaluation_data.is_discrete_finite_vals,
483+
minimal_keys,
484+
)
485+
486+
# Return final model with cached minimal keys
487+
return BUGSModel(
488+
model_def,
489+
data,
490+
g,
491+
evaluation_env,
492+
is_transformed,
493+
UseGraph(),
494+
untransformed_param_length,
495+
transformed_param_length,
496+
untransformed_var_lengths,
497+
transformed_var_lengths,
498+
graph_evaluation_data_with_keys,
499+
log_density_computation_function,
500+
mutable_symbols,
501+
nothing,
502+
)
424503
end
425504

426505
## Model interface
@@ -546,7 +625,7 @@ function getparams(model::BUGSModel, evaluation_env=model.evaluation_env)
546625
end
547626
else
548627
(; node_function, loop_vars) = model.g[v]
549-
dist = node_function(evaluation_env, loop_vars)
628+
dist = Base.invokelatest(node_function, evaluation_env, loop_vars)
550629
transformed_value = Bijectors.transform(
551630
Bijectors.bijector(dist), AbstractPPL.get(evaluation_env, v)
552631
)
@@ -572,7 +651,7 @@ function getparams(
572651
d[v] = value
573652
else
574653
(; node_function, loop_vars) = model.g[v]
575-
dist = node_function(evaluation_env, loop_vars)
654+
dist = Base.invokelatest(node_function, evaluation_env, loop_vars)
576655
d[v] = Bijectors.transform(Bijectors.bijector(dist), value)
577656
end
578657
end

JuliaBUGS/src/model/evaluation.jl

Lines changed: 69 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -303,54 +303,73 @@ function _precompute_minimal_cache_keys(model::BUGSModel, order::Vector{Int})
303303
is_discrete_finite = gd.is_discrete_finite_vals
304304
node_types = gd.node_types
305305

306-
# Get stochastic parents for each node
306+
# Get stochastic parents (stochastic boundary) for each node
307307
parents_idx = _get_stochastic_parents_indices(model)
308308

309-
# Initialize frontier keys for each position
310-
minimal_keys = Dict{Int,Vector{Int}}()
309+
# Build mapping from node index (in gd.sorted_nodes) -> position in the provided order.
310+
# This lets us reason about liveness w.r.t. the chosen evaluation order.
311+
order_pos = Vector{Int}(undef, length(gd.sorted_nodes))
312+
@inbounds for k in 1:n
313+
order_pos[order[k]] = k
314+
end
311315

312-
# For each node, determine which discrete finite variables should be in its frontier
313-
for k in 1:n
314-
current_idx = order[k]
315-
316-
# The frontier should include discrete finite variables that:
317-
# 1. Come before this node in the evaluation order (have been set)
318-
# 2. This node depends on (directly or indirectly)
319-
320-
filtered_frontier = Int[]
321-
322-
# For ALL nodes (stochastic and deterministic), we need to track dependencies
323-
# on discrete finite variables that have been set
324-
325-
if node_types[current_idx] == :discrete_finite
326-
# IMPORTANT: At a discrete node, the remainder of computation may include
327-
# observed likelihood terms that depend on ALL earlier discrete assignments
328-
# (e.g., y[1..t-1] in an HMM when all y's are evaluated after all z's).
329-
# Therefore, to avoid incorrect cache reuse, conservatively include ALL
330-
# prior discrete finite unobserved variables in the frontier.
331-
for j in 1:(k - 1)
332-
idx = order[j]
333-
if is_discrete_finite[idx] && !is_observed[idx]
334-
push!(filtered_frontier, idx)
316+
# Compute last-use POSITIONS (w.r.t. 'order') for each unobserved finite-discrete variable.
317+
# A variable stays in the frontier until we pass the last stochastic node
318+
# (observed or unobserved) whose distribution depends on it.
319+
last_use_pos = Dict{Int,Int}() # map from variable index -> last position in 'order'
320+
for j_label in 1:length(gd.sorted_nodes)
321+
if gd.is_stochastic_vals[j_label]
322+
j_pos = order_pos[j_label]
323+
for p_label in parents_idx[j_label]
324+
if is_discrete_finite[p_label] && !is_observed[p_label]
325+
# Default to the position of the variable itself if unseen
326+
default_pos = order_pos[p_label]
327+
last_use_pos[p_label] = max(get(last_use_pos, p_label, default_pos), j_pos)
335328
end
336329
end
337-
else
338-
# For deterministic, continuous, or discrete-infinite nodes, we must ensure
339-
# the memo key distinguishes different earlier discrete assignments that can
340-
# influence any of the remaining computation (e.g., later likelihood terms).
341-
# A conservative but correct choice is to include ALL prior discrete finite
342-
# unobserved variables in the frontier.
343-
for j in 1:(k - 1)
344-
idx = order[j]
345-
if is_discrete_finite[idx] && !is_observed[idx]
346-
push!(filtered_frontier, idx)
347-
end
330+
end
331+
end
332+
333+
# Initialize frontier keys for each position based on liveness
334+
# Optimized incremental construction to avoid O(n^2) in common patterns
335+
minimal_keys = Dict{Int,Vector{Int}}()
336+
337+
# Precompute starts and ends in order positions
338+
starts_at = Dict{Int,Vector{Int}}()
339+
for lbl in 1:length(gd.sorted_nodes)
340+
pos = order_pos[lbl]
341+
if is_discrete_finite[lbl] && !is_observed[lbl]
342+
push!(get!(starts_at, pos, Int[]), lbl)
343+
end
344+
end
345+
346+
# Active set of earlier discrete finite variables (by label index)
347+
active = Int[]
348+
# Track end positions for active labels
349+
function purge_expired!(active_vec::Vector{Int}, k_pos::Int)
350+
# Remove any with last_use_pos < k_pos
351+
i = 1
352+
while i <= length(active_vec)
353+
lbl = active_vec[i]
354+
if get(last_use_pos, lbl, 0) < k_pos
355+
deleteat!(active_vec, i)
356+
else
357+
i += 1
348358
end
349359
end
360+
return active_vec
361+
end
350362

351-
# Store as sorted vector
352-
sort!(unique!(filtered_frontier))
353-
minimal_keys[current_idx] = filtered_frontier
363+
for k in 1:n
364+
# Add labels that start at previous position so they count as "earlier"
365+
if haskey(starts_at, k - 1)
366+
append!(active, starts_at[k - 1])
367+
end
368+
# Drop any labels that have expired before current position
369+
purge_expired!(active, k)
370+
# Sort for stable key representation
371+
sort!(active)
372+
minimal_keys[order[k]] = copy(active)
354373
end
355374

356375
return minimal_keys
@@ -409,7 +428,7 @@ function _marginalize_recursive(
409428

410429
if !is_stochastic
411430
# Deterministic node
412-
value = node_function(env, loop_vars)
431+
value = Base.invokelatest(node_function, env, loop_vars)
413432
new_env = BangBang.setindex!!(env, value, current_vn)
414433
result = _marginalize_recursive(
415434
model,
@@ -424,7 +443,7 @@ function _marginalize_recursive(
424443

425444
elseif is_observed
426445
# Observed stochastic node
427-
dist = node_function(env, loop_vars)
446+
dist = Base.invokelatest(node_function, env, loop_vars)
428447
obs_value = AbstractPPL.get(env, current_vn)
429448
obs_logp = logpdf(dist, obs_value)
430449

@@ -447,7 +466,7 @@ function _marginalize_recursive(
447466

448467
elseif is_discrete_finite
449468
# Discrete finite unobserved node - marginalize out
450-
dist = node_function(env, loop_vars)
469+
dist = Base.invokelatest(node_function, env, loop_vars)
451470
possible_values = enumerate_discrete_values(dist)
452471

453472
logp_branches = Vector{typeof(zero(eltype(parameter_values)))}(
@@ -480,7 +499,7 @@ function _marginalize_recursive(
480499

481500
else
482501
# Continuous or discrete infinite unobserved node - use parameter values
483-
dist = node_function(env, loop_vars)
502+
dist = Base.invokelatest(node_function, env, loop_vars)
484503
b = Bijectors.bijector(dist)
485504

486505
if !haskey(var_lengths, current_vn)
@@ -641,8 +660,12 @@ function evaluate_with_marginalization_values!!(
641660
n = length(model.graph_evaluation_data.sorted_nodes)
642661
sorted_indices = collect(1:n)
643662

644-
# Precompute minimal cache keys
645-
minimal_keys = _precompute_minimal_cache_keys(model, sorted_indices)
663+
# Use precomputed minimal cache keys if available; otherwise compute once for this call
664+
minimal_keys = if !isempty(model.graph_evaluation_data.minimal_cache_keys)
665+
model.graph_evaluation_data.minimal_cache_keys
666+
else
667+
_precompute_minimal_cache_keys(model, sorted_indices)
668+
end
646669

647670
# Initialize memoization cache
648671
# Size hint: at most 2^|discrete_finite| * |nodes| entries

0 commit comments

Comments
 (0)