Skip to content

Commit 2d2628c

Browse files
committed
fix
1 parent 9940386 commit 2d2628c

File tree

6 files changed

+9
-9
lines changed

6 files changed

+9
-9
lines changed

src/debug_utils.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -451,9 +451,8 @@ function check_model_and_trace(
451451
issuccess = check_model_pre_evaluation(debug_context, model)
452452

453453
# Force single-threaded execution.
454-
retval, varinfo_result = DynamicPPL.evaluate_threadunsafe!!(
455-
model, varinfo, debug_context
456-
)
454+
debug_model = DynamicPPL.contextualize(model, debug_context)
455+
DynamicPPL.evaluate_threadunsafe!!(debug_model, varinfo)
457456

458457
# Perform checks after evaluating the model.
459458
issuccess &= check_model_post_evaluation(debug_context, model)

src/model.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -929,7 +929,8 @@ end
929929
Evaluate the `model` with the given `varinfo`. If an additional `context` is provided,
930930
the model's context is combined with that context.
931931
932-
This function does not wrap the varinfo in a `ThreadSafeVarInfo`.
932+
This function does not wrap the varinfo in a `ThreadSafeVarInfo`. It also does not
933+
reset the log probability of the `varinfo` before running.
933934
"""
934935
function _evaluate!!(model::Model, varinfo::AbstractVarInfo)
935936
args, kwargs = make_evaluate_args_and_kwargs(model, varinfo)

src/sampler.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ function AbstractMCMC.step(
5858
kwargs...,
5959
)
6060
vi = VarInfo()
61-
model(rng, vi, sampler)
61+
DynamicPPL.sample!!(rng, model, vi, sampler)
6262
return vi, nothing
6363
end
6464

test/ad.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ using DynamicPPL: LogDensityFunction
110110
# Compiling the ReverseDiff tape used to fail here
111111
spl = Sampler(MyEmptyAlg())
112112
vi = VarInfo(model)
113-
sampling_model = contextualize(model, SamplingConext(model.context))
113+
sampling_model = contextualize(model, SamplingContext(model.context))
114114
ldf = LogDensityFunction(sampling_model, vi; adtype=AutoReverseDiff(; compile=true))
115115
@test LogDensityProblems.logdensity_and_gradient(ldf, vi[:]) isa Any
116116
end

test/context_implementations.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55
μ ~ MvNormal(zeros(2), 4 * I)
66
z = Vector{Int}(undef, length(x))
77
z ~ product_distribution(Categorical.(fill([0.5, 0.5], length(x))))
8-
for i in 1:length(x)
8+
for i in eachindex(x)
99
x[i] ~ Normal(μ[z[i]], 0.1)
1010
end
1111
end
1212

13-
test([1, 1, -1])(VarInfo(), SampleFromPrior(), DefaultContext())
13+
test([1, 1, -1])(VarInfo())
1414
end
1515

1616
@testset "dot tilde with varying sizes" begin

test/model.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
332332
@test logjoint(model, x) !=
333333
DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian(model, x...)
334334
# Ensure `varnames` is implemented.
335-
vi = last(DynamicPPL.sample!!(sampling_model, SimpleVarInfo(OrderedDict())))
335+
vi = last(DynamicPPL.sample!!(model, SimpleVarInfo(OrderedDict())))
336336
@test all(collect(keys(vi)) .== DynamicPPL.TestUtils.varnames(model))
337337
# Ensure `posterior_mean` is implemented.
338338
@test DynamicPPL.TestUtils.posterior_mean(model) isa typeof(x)

0 commit comments

Comments
 (0)