Skip to content

Commit 618c23e

Browse files
committed
stop using invokelatest
1 parent e812006 commit 618c23e

File tree

4 files changed

+25
-21
lines changed

4 files changed

+25
-21
lines changed

JuliaBUGS/src/model/abstractppl.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -698,8 +698,14 @@ function evaluate!!(
698698
temperature=1.0,
699699
transformed=model.transformed,
700700
)
701-
evaluation_env, log_densities = evaluate_with_values!!(
702-
model, flattened_values; temperature=temperature, transformed=transformed
703-
)
701+
if model.evaluation_mode isa UseAutoMarginalization
702+
evaluation_env, log_densities = evaluate_with_marginalization_values!!(
703+
model, flattened_values; temperature=temperature, transformed=transformed
704+
)
705+
else
706+
evaluation_env, log_densities = evaluate_with_values!!(
707+
model, flattened_values; temperature=temperature, transformed=transformed
708+
)
709+
end
704710
return evaluation_env, log_densities.tempered_logjoint
705711
end

JuliaBUGS/src/model/bugsmodel.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -643,7 +643,7 @@ function getparams(model::BUGSModel, evaluation_env=model.evaluation_env)
643643
end
644644
else
645645
(; node_function, loop_vars) = model.g[v]
646-
dist = Base.invokelatest(node_function, evaluation_env, loop_vars)
646+
dist = node_function(evaluation_env, loop_vars)
647647
transformed_value = Bijectors.transform(
648648
Bijectors.bijector(dist), AbstractPPL.get(evaluation_env, v)
649649
)
@@ -679,7 +679,7 @@ function getparams(
679679
d[v] = value
680680
else
681681
(; node_function, loop_vars) = model.g[v]
682-
dist = Base.invokelatest(node_function, evaluation_env, loop_vars)
682+
dist = node_function(evaluation_env, loop_vars)
683683
d[v] = Bijectors.transform(Bijectors.bijector(dist), value)
684684
end
685685
end

JuliaBUGS/src/model/evaluation.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -500,7 +500,7 @@ function _marginalize_recursive(
500500

501501
if !is_stochastic
502502
# Deterministic node
503-
value = Base.invokelatest(node_function, env, loop_vars)
503+
value = node_function(env, loop_vars)
504504
new_env = BangBang.setindex!!(env, value, current_vn)
505505
result = _marginalize_recursive(
506506
model,
@@ -515,7 +515,7 @@ function _marginalize_recursive(
515515

516516
elseif is_observed
517517
# Observed stochastic node
518-
dist = Base.invokelatest(node_function, env, loop_vars)
518+
dist = node_function(env, loop_vars)
519519
obs_value = AbstractPPL.get(env, current_vn)
520520
obs_logp = logpdf(dist, obs_value)
521521

@@ -538,7 +538,7 @@ function _marginalize_recursive(
538538

539539
elseif is_discrete_finite
540540
# Discrete finite unobserved node - marginalize out
541-
dist = Base.invokelatest(node_function, env, loop_vars)
541+
dist = node_function(env, loop_vars)
542542
possible_values = enumerate_discrete_values(dist)
543543

544544
logp_branches = Vector{typeof(zero(eltype(parameter_values)))}(
@@ -571,7 +571,7 @@ function _marginalize_recursive(
571571

572572
else
573573
# Continuous or discrete infinite unobserved node - use parameter values
574-
dist = Base.invokelatest(node_function, env, loop_vars)
574+
dist = node_function(env, loop_vars)
575575
b = Bijectors.bijector(dist)
576576

577577
if !haskey(var_lengths, current_vn)

JuliaBUGS/test/model/evaluation.jl

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,16 @@ function test_bugs_model_log_density(
1818
transformed_model = compile(model_def, data, inits)
1919
untransformed_model = JuliaBUGS.settrans(transformed_model, false)
2020

21-
# Allow world age to advance by calling the functions in a separate evaluation
22-
Base.invokelatest() do
23-
@test _logjoint(untransformed_model) expected_untransformed rtol = 1E-6
24-
@test _logjoint(transformed_model) expected_transformed rtol = 1E-6
25-
26-
@test LogDensityProblems.logdensity(
27-
transformed_model, JuliaBUGS.getparams(transformed_model)
28-
) expected_transformed rtol = 1E-6
29-
@test LogDensityProblems.logdensity(
30-
untransformed_model, JuliaBUGS.getparams(untransformed_model)
31-
) expected_untransformed rtol = 1E-6
32-
end
21+
# Evaluate directly; model compilation happens before these calls
22+
@test _logjoint(untransformed_model) expected_untransformed rtol = 1E-6
23+
@test _logjoint(transformed_model) expected_transformed rtol = 1E-6
24+
25+
@test LogDensityProblems.logdensity(
26+
transformed_model, JuliaBUGS.getparams(transformed_model)
27+
) expected_transformed rtol = 1E-6
28+
@test LogDensityProblems.logdensity(
29+
untransformed_model, JuliaBUGS.getparams(untransformed_model)
30+
) expected_untransformed rtol = 1E-6
3331
end
3432

3533
@testset "evaluate_with_rng!! - controlling sampling behavior" begin

0 commit comments

Comments
 (0)