Skip to content

Commit 61ed22d

Browse files
authored
Add an argument to condition function to skip generating and evaling log density function (#394)
## Motivation `condition` currently regenerates the log-density function on every call, invoking MacroTools transforms and `eval`. This is too slow for hot loops. We want a fast path for repeated conditioning that: - Avoids regeneration - Lets users update observed values in-place without changing graph structure - Allows explicit regeneration later to regain compiled performance when desired ## Summary of Changes - `condition` now supports a fast path via a keyword flag: - `condition(model, conditioning_spec; regenerate_log_density::Bool=true)` - When `false`: skips regeneration, nulls the compiled function, and forces graph evaluation mode - New helper for value updates without reconditioning: - `set_observed_values!(model, obs::Dict{<:VarName,<:Any})` - Updates evaluation environment values of already-observed stochastic variables in-place - Validates existence, stochasticity, and observation status - New function to regenerate the compiled log-density function without changing the evaluation mode: - `regenerate_log_density_function(model; force=false)` - Re-generates for the model’s current graph + environment and refreshes graph evaluation data - Exports added: - `set_observed_values!` and `regenerate_log_density_function` (from `Model` module) ## Tiny Benchmark for Sanity Check with ```julia @bugs begin for i in 1:N x[i] ~ Normal(0, 1) end y ~ Normal(sum(x[:]), 1) end ``` ``` BenchmarkTools results (N=100, iters=200): Single condition: regular = 13.96 ms, fast = 0.15 ms, speedup ≈ 92.3x Loop conditioning: regular = 28139.5 ms, fast(update only) = 0.2 ms, speedup ≈ 112840.1x ``` Interpretation: - Single call to `condition`: ~90x faster when skipping regeneration - Hot loop: one fast `condition` + repeated `set_observed_values!` is ~1e5x faster than repeatedly reconditioning
1 parent 583f971 commit 61ed22d

File tree

4 files changed

+184
-6
lines changed

4 files changed

+184
-6
lines changed

JuliaBUGS/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "JuliaBUGS"
22
uuid = "ba9fb4c0-828e-4473-b6a1-cd2560fee5bf"
3-
version = "0.10.1"
3+
version = "0.10.2"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

JuliaBUGS/src/model/Model.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ include("abstractppl.jl")
1919
include("logdensityproblems.jl")
2020

2121
export parameters, variables, initialize!, getparams, settrans, set_evaluation_mode
22+
export regenerate_log_density_function, set_observed_values!
2223
export evaluate_with_rng!!, evaluate_with_env!!, evaluate_with_values!!
2324

2425
end # Model

JuliaBUGS/src/model/abstractppl.jl

Lines changed: 75 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ import AbstractPPL: condition, decondition, evaluate!!
2020
#######################
2121

2222
"""
23-
condition(model::BUGSModel, conditioning_spec)
23+
condition(model::BUGSModel, conditioning_spec; regenerate_log_density::Bool=true)
2424
2525
Create a new model by conditioning on specified variables with given values.
2626
@@ -135,7 +135,7 @@ julia> parameters(model_cond4)
135135
y
136136
```
137137
"""
138-
function condition(model::BUGSModel, conditioning_spec)
138+
function condition(model::BUGSModel, conditioning_spec; regenerate_log_density::Bool=true)
139139
# Parse and validate conditioning specification
140140
var_values = _parse_conditioning_spec(conditioning_spec, model)::Dict{<:VarName,<:Any}
141141
vars_to_condition = collect(keys(var_values))::Vector{<:VarName}
@@ -157,6 +157,7 @@ function condition(model::BUGSModel, conditioning_spec)
157157
new_graph,
158158
new_evaluation_env;
159159
base_model=isnothing(model.base_model) ? model : model.base_model,
160+
regenerate_log_density=regenerate_log_density,
160161
)
161162
end
162163

@@ -551,6 +552,7 @@ function _create_modified_model(
551552
new_graph::BUGSGraph,
552553
new_evaluation_env::NamedTuple;
553554
base_model=nothing,
555+
regenerate_log_density::Bool=true,
554556
)
555557
# Create new graph evaluation data
556558
new_graph_evaluation_data = GraphEvaluationData(new_graph)
@@ -562,9 +564,15 @@ function _create_modified_model(
562564
)
563565

564566
# Generate new log density function and update graph evaluation data
565-
new_log_density_computation_function, updated_graph_evaluation_data = _regenerate_log_density_function(
566-
model.model_def, new_graph, new_evaluation_env, new_graph_evaluation_data
567-
)
567+
new_log_density_computation_function, updated_graph_evaluation_data =
568+
if regenerate_log_density
569+
_regenerate_log_density_function(
570+
model.model_def, new_graph, new_evaluation_env, new_graph_evaluation_data
571+
)
572+
else
573+
# Skip regeneration (fast path): ensure stale code isn't used
574+
nothing, new_graph_evaluation_data
575+
end
568576

569577
# Recompute mutable symbols for the new graph
570578
new_mutable_symbols = get_mutable_symbols(updated_graph_evaluation_data)
@@ -580,6 +588,12 @@ function _create_modified_model(
580588
:mutable_symbols => new_mutable_symbols,
581589
)
582590

591+
# Force graph evaluation mode when skipping regeneration to avoid stale compiled code
592+
if !regenerate_log_density
593+
kwargs[:evaluation_mode] = UseGraph()
594+
kwargs[:log_density_computation_function] = nothing
595+
end
596+
583597
# Add base_model if provided
584598
if !isnothing(base_model)
585599
kwargs[:base_model] = base_model
@@ -623,6 +637,62 @@ function _regenerate_log_density_function(
623637
end
624638
end
625639

640+
#######################
641+
# Observed Value Updates
642+
#######################
643+
644+
"""
645+
set_observed_values!(model::BUGSModel, obs::Dict{<:VarName,<:Any})
646+
647+
Update values of observed stochastic variables without reconditioning or regenerating code.
648+
649+
Validates that each variable exists in the model, is stochastic, and is currently observed.
650+
Updates the evaluation environment in place and returns the updated model.
651+
"""
652+
function set_observed_values!(model::BUGSModel, obs::Dict{<:VarName,<:Any})
653+
new_env = model.evaluation_env
654+
for (vn, val) in obs
655+
if vn labels(model.g)
656+
throw(ArgumentError("Variable $vn does not exist in the model"))
657+
end
658+
node_info = model.g[vn]
659+
if !node_info.is_stochastic
660+
throw(ArgumentError("Cannot update $vn: it is deterministic (logical)"))
661+
end
662+
if !node_info.is_observed
663+
throw(ArgumentError("Cannot update $vn: it is not observed"))
664+
end
665+
new_env = BangBang.setindex!!(new_env, val, vn)
666+
end
667+
return BangBang.setproperty!!(model, :evaluation_env, new_env)
668+
end
669+
670+
"""
671+
regenerate_log_density_function(model::BUGSModel; force::Bool=false)
672+
673+
Generate and attach a compiled log-density function for the model's current graph and evaluation environment.
674+
675+
Does not change the evaluation mode. When `force=false`, preserves an existing compiled function; when `force=true`,
676+
overwrites it if a new one can be generated. Returns the updated model (or the original if generation is not possible).
677+
"""
678+
function regenerate_log_density_function(model::BUGSModel; force::Bool=false)
679+
new_fn, updated_graph_eval_data = _regenerate_log_density_function(
680+
model.model_def, model.g, model.evaluation_env, model.graph_evaluation_data
681+
)
682+
# Always refresh graph_evaluation_data from regeneration helper (it may refine ordering)
683+
model = BangBang.setproperty!!(model, :graph_evaluation_data, updated_graph_eval_data)
684+
685+
if isnothing(new_fn)
686+
# Cannot generate compiled function; leave as-is
687+
return model
688+
end
689+
690+
if force || isnothing(model.log_density_computation_function)
691+
model = BangBang.setproperty!!(model, :log_density_computation_function, new_fn)
692+
end
693+
return model
694+
end
695+
626696
#######################
627697
# Evaluation API
628698
#######################

JuliaBUGS/test/model/abstractppl.jl

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ using JuliaBUGS.Model:
55
decondition,
66
parameters,
77
set_evaluation_mode,
8+
set_observed_values!,
9+
regenerate_log_density_function,
810
UseGeneratedLogDensityFunction,
911
UseGraph
1012
using LogDensityProblems
@@ -149,6 +151,111 @@ JuliaBUGS.@bugs_primitive Normal Gamma
149151
logp = LogDensityProblems.logdensity(model_cond_gen, Float64[])
150152
@test isfinite(logp)
151153
end
154+
155+
@testset "Fast conditioning path and observed value updates" begin
156+
@testset "Fast conditioning skips regeneration and forces graph mode" begin
157+
model_def = @bugs begin
158+
x ~ Normal(0, 1)
159+
y ~ Normal(x, 1)
160+
end
161+
162+
model = compile(model_def, (;))
163+
164+
# Regular conditioning (default regenerates)
165+
model_reg = condition(model, Dict(@varname(x) => 1.0))
166+
@test !isnothing(model_reg.log_density_computation_function)
167+
168+
# Fast conditioning (no regeneration)
169+
model_fast = condition(
170+
model, Dict(@varname(x) => 1.0); regenerate_log_density=false
171+
)
172+
@test model_fast.log_density_computation_function === nothing
173+
@test model_fast.evaluation_mode isa UseGraph
174+
175+
# Same parameters
176+
@test parameters(model_reg) == parameters(model_fast)
177+
178+
# Compare log density via graph evaluation
179+
params = zeros(length(parameters(model_fast)))
180+
logp_fast = LogDensityProblems.logdensity(
181+
set_evaluation_mode(model_fast, UseGraph()), params
182+
)
183+
logp_reg = LogDensityProblems.logdensity(
184+
set_evaluation_mode(model_reg, UseGraph()), params
185+
)
186+
@test logp_fast logp_reg
187+
end
188+
189+
@testset "set_observed_values! updates values and validates" begin
190+
# Model with a deterministic node to hit validation
191+
model_def = @bugs begin
192+
x ~ Normal(0, 1)
193+
y = x^2 # deterministic
194+
z ~ Normal(y, 1)
195+
end
196+
197+
model = compile(model_def, (; z=2.0))
198+
199+
# Fast condition on x
200+
m = condition(model, Dict(@varname(x) => 1.0); regenerate_log_density=false)
201+
@test m.evaluation_mode isa UseGraph
202+
203+
# Update observed x value without reconditioning
204+
m2 = set_observed_values!(m, Dict(@varname(x) => 2.0))
205+
@test m2.evaluation_env.x == 2.0
206+
207+
# Structure unchanged (x observed; only z observed from data; no parameters)
208+
# Here parameters is empty because y is deterministic and z is observed
209+
@test parameters(m2) == parameters(m)
210+
211+
# Errors on invalid updates
212+
@test_throws ArgumentError set_observed_values!(
213+
m2, Dict(@varname(y) => 3.0)
214+
) # deterministic
215+
216+
# Updating originally observed data should be allowed
217+
m3 = set_observed_values!(m2, Dict(@varname(z) => 1.0))
218+
@test m3.evaluation_env.z == 1.0
219+
# To test non-observed error, try updating a parameter in a different model.
220+
end
221+
222+
@testset "set_observed_values! errors on non-observed variables" begin
223+
model_def = @bugs begin
224+
x ~ Normal(0, 1)
225+
y ~ Normal(x, 1)
226+
end
227+
model = compile(model_def, (;))
228+
m = condition(model, Dict(@varname(x) => 1.0); regenerate_log_density=false)
229+
# y is not observed
230+
@test_throws ArgumentError set_observed_values!(m, Dict(@varname(y) => 0.0))
231+
end
232+
233+
@testset "Regeneration after fast conditioning (no mode change)" begin
234+
model_def = @bugs begin
235+
x ~ Normal(0, 1)
236+
y ~ Normal(x, 1)
237+
end
238+
239+
model = compile(model_def, (;))
240+
m = condition(model, Dict(@varname(x) => 1.0); regenerate_log_density=false)
241+
@test m.log_density_computation_function === nothing
242+
243+
# Regenerate compiled function without changing mode
244+
m2 = regenerate_log_density_function(m)
245+
@test !isnothing(m2.log_density_computation_function)
246+
@test m2.evaluation_mode isa UseGraph
247+
248+
# Can switch to generated mode explicitly and match graph
249+
params = zeros(length(parameters(m2)))
250+
logp_gen = LogDensityProblems.logdensity(
251+
set_evaluation_mode(m2, UseGeneratedLogDensityFunction()), params
252+
)
253+
logp_graph = LogDensityProblems.logdensity(
254+
set_evaluation_mode(m2, UseGraph()), params
255+
)
256+
@test logp_gen logp_graph
257+
end
258+
end
152259
end
153260

154261
@testset "decondition" begin

0 commit comments

Comments
 (0)