Skip to content

Commit 7b8cedd

Browse files
committed
fix performance by moving more computation to the construction
1 parent f441205 commit 7b8cedd

File tree

3 files changed

+67
-19
lines changed

3 files changed

+67
-19
lines changed

JuliaBUGS/src/model/abstractppl.jl

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -569,7 +569,7 @@ function _create_modified_model(
569569
# Recompute mutable symbols for the new graph
570570
new_mutable_symbols = get_mutable_symbols(updated_graph_evaluation_data)
571571

572-
# Create the new model with all updated fields
572+
# Create the new model with all updated fields (without auto-marg caches yet)
573573
kwargs = Dict{Symbol,Any}(
574574
:untransformed_param_length => new_untransformed_param_length,
575575
:transformed_param_length => new_transformed_param_length,
@@ -585,7 +585,33 @@ function _create_modified_model(
585585
kwargs[:base_model] = base_model
586586
end
587587

588-
return BUGSModel(model; kwargs...)
588+
new_model = BUGSModel(model; kwargs...)
589+
590+
# Compute and attach auto-marg caches once for the new graph
591+
try
592+
order = JuliaBUGS.Model._compute_marginalization_order(new_model)
593+
keys = JuliaBUGS.Model._precompute_minimal_cache_keys(new_model, order)
594+
595+
gd = new_model.graph_evaluation_data
596+
gd_cached = GraphEvaluationData{
597+
typeof(gd.node_function_vals),typeof(gd.loop_vars_vals)
598+
}(
599+
gd.sorted_nodes,
600+
gd.sorted_parameters,
601+
gd.is_stochastic_vals,
602+
gd.is_observed_vals,
603+
gd.node_function_vals,
604+
gd.loop_vars_vals,
605+
gd.node_types,
606+
gd.is_discrete_finite_vals,
607+
keys,
608+
order,
609+
)
610+
return BUGSModel(new_model; graph_evaluation_data=gd_cached)
611+
catch
612+
# If caches cannot be computed (e.g., unsupported model), return model as-is.
613+
return new_model
614+
end
589615
end
590616

591617
# Common helper function to regenerate log density function

JuliaBUGS/src/model/bugsmodel.jl

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,8 @@ end
3131
Return the finite support for a discrete univariate distribution.
3232
Relies on Distributions.support to provide an iterable, finite range.
3333
"""
34-
enumerate_discrete_values(dist::Distributions.DiscreteUnivariateDistribution) = Distributions.support(
35-
dist
36-
)
34+
enumerate_discrete_values(dist::Distributions.DiscreteUnivariateDistribution) =
35+
Distributions.support(dist)
3736

3837
"""
3938
classify_node_type(dist)
@@ -77,8 +76,15 @@ struct GraphEvaluationData{TNF,TV}
7776
node_types::Vector{Symbol}
7877
is_discrete_finite_vals::Vector{Bool}
7978
minimal_cache_keys::Dict{Int,Vector{Int}}
79+
marginalization_order::Vector{Int}
8080
end
8181

82+
"""
83+
GraphEvaluationData(compat constructor)
84+
85+
Backward-compatible constructor that fills new caching fields with defaults
86+
when older call sites provide only the first nine fields.
87+
"""
8288
function GraphEvaluationData(
8389
g::BUGSGraph,
8490
sorted_nodes::Vector{<:VarName}=VarName[
@@ -114,7 +120,7 @@ function GraphEvaluationData(
114120
end
115121
end
116122

117-
return GraphEvaluationData(
123+
return GraphEvaluationData{typeof(node_function_vals),typeof(loop_vars_vals)}(
118124
sorted_nodes,
119125
sorted_parameters,
120126
is_stochastic_vals,
@@ -124,6 +130,7 @@ function GraphEvaluationData(
124130
node_types,
125131
is_discrete_finite_vals,
126132
Dict{Int,Vector{Int}}(),
133+
Int[],
127134
)
128135
end
129136

@@ -233,7 +240,7 @@ function BUGSModel(
233240
JuliaBUGS.Model._precompute_minimal_cache_keys(m, collect(1:n))
234241
end
235242
# Attach minimal cache keys to GraphEvaluationData (order remains default)
236-
gd2 = GraphEvaluationData(
243+
gd2 = GraphEvaluationData{typeof(gd.node_function_vals),typeof(gd.loop_vars_vals)}(
237244
gd.sorted_nodes,
238245
gd.sorted_parameters,
239246
gd.is_stochastic_vals,
@@ -243,6 +250,7 @@ function BUGSModel(
243250
gd.node_types,
244251
gd.is_discrete_finite_vals,
245252
minimal_keys,
253+
gd.marginalization_order,
246254
)
247255
# Return final model with cached minimal keys
248256
return BUGSModel(
@@ -368,7 +376,10 @@ function BUGSModel(
368376
end
369377

370378
# Update graph_evaluation_data with the computed node types
371-
graph_evaluation_data = GraphEvaluationData(
379+
graph_evaluation_data = GraphEvaluationData{
380+
typeof(graph_evaluation_data.node_function_vals),
381+
typeof(graph_evaluation_data.loop_vars_vals),
382+
}(
372383
graph_evaluation_data.sorted_nodes,
373384
graph_evaluation_data.sorted_parameters,
374385
graph_evaluation_data.is_stochastic_vals,
@@ -378,6 +389,7 @@ function BUGSModel(
378389
node_types,
379390
is_discrete_finite_vals,
380391
Dict{Int,Vector{Int}}(),
392+
Int[],
381393
)
382394

383395
lowered_model_def, reconstructed_model_def = JuliaBUGS._generate_lowered_model_def(
@@ -429,7 +441,9 @@ function BUGSModel(
429441
end
430442

431443
# Reconstruct GraphEvaluationData while preserving classification
432-
graph_evaluation_data = GraphEvaluationData(
444+
graph_evaluation_data = GraphEvaluationData{
445+
typeof(new_gd.node_function_vals),typeof(new_gd.loop_vars_vals)
446+
}(
433447
new_gd.sorted_nodes,
434448
new_gd.sorted_parameters,
435449
new_gd.is_stochastic_vals,
@@ -439,6 +453,7 @@ function BUGSModel(
439453
new_node_types,
440454
new_is_discrete_finite_vals,
441455
Dict{Int,Vector{Int}}(),
456+
Int[],
442457
)
443458
else
444459
log_density_computation_function = nothing
@@ -464,14 +479,17 @@ function BUGSModel(
464479
mutable_symbols,
465480
nothing,
466481
)
467-
# Precompute minimal cache keys for the default order (1:n)
482+
# Precompute marginalization order and minimal cache keys once
468483
n = length(graph_evaluation_data.sorted_nodes)
469-
sorted_indices = collect(1:n)
484+
sorted_indices = JuliaBUGS.Model._compute_marginalization_order(model_without_min_keys)
470485
minimal_keys = JuliaBUGS.Model._precompute_minimal_cache_keys(
471486
model_without_min_keys, sorted_indices
472487
)
473-
# Attach minimal keys to GraphEvaluationData
474-
graph_evaluation_data_with_keys = GraphEvaluationData(
488+
# Attach cached order and keys to GraphEvaluationData
489+
graph_evaluation_data_with_keys = GraphEvaluationData{
490+
typeof(graph_evaluation_data.node_function_vals),
491+
typeof(graph_evaluation_data.loop_vars_vals),
492+
}(
475493
graph_evaluation_data.sorted_nodes,
476494
graph_evaluation_data.sorted_parameters,
477495
graph_evaluation_data.is_stochastic_vals,
@@ -481,6 +499,7 @@ function BUGSModel(
481499
graph_evaluation_data.node_types,
482500
graph_evaluation_data.is_discrete_finite_vals,
483501
minimal_keys,
502+
sorted_indices,
484503
)
485504

486505
# Return final model with cached minimal keys

JuliaBUGS/src/model/evaluation.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -731,14 +731,17 @@ function evaluate_with_marginalization_values!!(
731731
)
732732
end
733733

734-
# Compute an order that minimizes frontier growth (interleave discrete parents before observed children)
734+
# Use cached marginalization order and minimal frontier keys when available
735735
gd = model.graph_evaluation_data
736736
n = length(gd.sorted_nodes)
737-
sorted_indices = _compute_marginalization_order(model)
738-
739-
# Compute minimal cache keys for this specific order
740-
# (do not reuse cached keys if they were built for a different order)
741-
minimal_keys = _precompute_minimal_cache_keys(model, sorted_indices)
737+
# Strictly require caches to be present for performance
738+
if isempty(gd.marginalization_order) || isempty(gd.minimal_cache_keys)
739+
error(
740+
"Auto marginalization cache missing. This model was not prepared for UseAutoMarginalization.",
741+
)
742+
end
743+
sorted_indices = gd.marginalization_order
744+
minimal_keys = gd.minimal_cache_keys
742745

743746
# Initialize memoization cache
744747
# Size hint: at most 2^|discrete_finite| * |nodes| entries

0 commit comments

Comments
 (0)