diff --git a/HISTORY.md b/HISTORY.md index 0f0102ce4..aabb7c8b4 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -4,6 +4,35 @@ ### Breaking changes +#### Threadsafe evaluation + +DynamicPPL models are by default no longer thread-safe. +If you have threading in a model, you **must** now manually mark it as so, using: + +```julia +@model f() = ... +model = f() +model = setthreadsafe(model, true) +``` + +It used to be that DynamicPPL would 'automatically' enable thread-safe evaluation if Julia was launched with more than one thread (i.e., by checking `Threads.nthreads() > 1`). + +The problem with this approach is that it sacrifices a huge amount of performance. +Furthermore, it is not actually the correct approach: just because Julia has multiple threads does not mean that a particular model actually requires threadsafe evaluation. + +**A model requires threadsafe evaluation if, and only if, the VarInfo object used inside the model is manipulated in parallel.** +This can occur if any of the following are inside `Threads.@threads` or other concurrency functions / macros: + + - tilde-statements + - calls to `@addlogprob!` + - any direct manipulation of the special `__varinfo__` variable + +If you have none of these inside threaded blocks, then you do not need to mark your model as threadsafe. +**Notably, the following do not require threadsafe evaluation:** + + - Using threading for anything that does not involve VarInfo. For example, you can calculate a log-probability in parallel, and then add it using `@addlogprob!` outside of the threaded block. This does not require threadsafe evaluation. + - Sampling with `AbstractMCMC.MCMCThreads()`. + #### Parent and leaf contexts The `DynamicPPL.NodeTrait` function has been removed. 
diff --git a/docs/src/api.md b/docs/src/api.md index e81f18dc7..220235eaa 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -42,6 +42,13 @@ The context of a model can be set using [`contextualize`](@ref): contextualize ``` +Some models require threadsafe evaluation (see https://turinglang.org/docs/THIS_DOESNT_EXIST_YET for more information on when this is necessary). +If this is the case, one must enable threadsafe evaluation for a model: + +```@docs +setthreadsafe +``` + ## Evaluation With [`rand`](@ref) one can draw samples from the prior distribution of a [`Model`](@ref). diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index e9b902363..e97ff8a98 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -90,6 +90,7 @@ export AbstractVarInfo, Model, getmissings, getargnames, + setthreadsafe, extract_priors, values_as_in_model, # LogDensityFunction diff --git a/src/compiler.jl b/src/compiler.jl index 3324780ca..67f6b7937 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -301,7 +301,7 @@ function model(mod, linenumbernode, expr, warn) modeldef = build_model_definition(expr) # Generate main body - modeldef[:body] = generate_mainbody(mod, modeldef[:body], warn) + modeldef[:body] = generate_mainbody(mod, modeldef[:body], warn, false) return build_output(modeldef, linenumbernode) end @@ -346,10 +346,11 @@ Generate the body of the main evaluation function from expression `expr` and arg If `warn` is true, a warning is displayed if internal variables are used in the model definition. 
""" -generate_mainbody(mod, expr, warn) = generate_mainbody!(mod, Symbol[], expr, warn) +generate_mainbody(mod, expr, warn, warn_threads) = + generate_mainbody!(mod, Symbol[], expr, warn, warn_threads) -generate_mainbody!(mod, found, x, warn) = x -function generate_mainbody!(mod, found, sym::Symbol, warn) +generate_mainbody!(mod, found, x, warn, warn_threads) = x +function generate_mainbody!(mod, found, sym::Symbol, warn, warn_threads) if warn && sym in INTERNALNAMES && sym ∉ found @warn "you are using the internal variable `$sym`" push!(found, sym) @@ -357,17 +358,38 @@ function generate_mainbody!(mod, found, sym::Symbol, warn) return sym end -function generate_mainbody!(mod, found, expr::Expr, warn) +function generate_mainbody!(mod, found, expr::Expr, warn, warn_threads) # Do not touch interpolated expressions expr.head === :$ && return expr.args[1] + # Flag to determine whether we've issued a warning for threadsafe macros Note that this + # detection is not fully correct. We can only detect the presence of a macro that has + # the symbol `Threads.@threads`, however, we can't detect if that *is actually* + # Threads.@threads from Base.Threads. + # Do we don't want escaped expressions because we unfortunately # escape the entire body afterwards. - Meta.isexpr(expr, :escape) && return generate_mainbody(mod, found, expr.args[1], warn) + Meta.isexpr(expr, :escape) && + return generate_mainbody(mod, found, expr.args[1], warn, warn_threads) # If it's a macro, we expand it if Meta.isexpr(expr, :macrocall) - return generate_mainbody!(mod, found, macroexpand(mod, expr; recursive=true), warn) + if ( + expr.args[1] == Symbol("@threads") || + expr.args[1] == Expr(:., :Threads, QuoteNode(Symbol("@threads"))) && + !warn_threads + ) + warn_threads = true + @warn ( + "It looks like you are using `Threads.@threads` in your model definition." * + "\n\nNote that since version 0.39 of DynamicPPL, threadsafe evaluation of models is disabled by default." 
* + " If you need it, you will need to explicitly enable it by creating the model, and then running `model = setthreadsafe(model, true)`." * + "\n\nAvoiding threadsafe evaluation can often lead to significant performance improvements. Please see https://turinglang.org/docs/THIS_PAGE_DOESNT_EXIST_YET for more details of when threadsafe evaluation is actually required." + ) + end + return generate_mainbody!( + mod, found, macroexpand(mod, expr; recursive=true), warn, warn_threads + ) end # Modify dotted tilde operators. @@ -375,7 +397,7 @@ function generate_mainbody!(mod, found, expr::Expr, warn) if args_dottilde !== nothing L, R = args_dottilde return generate_mainbody!( - mod, found, Base.remove_linenums!(generate_dot_tilde(L, R)), warn + mod, found, Base.remove_linenums!(generate_dot_tilde(L, R)), warn, warn_threads ) end @@ -385,8 +407,8 @@ function generate_mainbody!(mod, found, expr::Expr, warn) L, R = args_tilde return Base.remove_linenums!( generate_tilde( - generate_mainbody!(mod, found, L, warn), - generate_mainbody!(mod, found, R, warn), + generate_mainbody!(mod, found, L, warn, warn_threads), + generate_mainbody!(mod, found, R, warn, warn_threads), ), ) end @@ -397,13 +419,16 @@ function generate_mainbody!(mod, found, expr::Expr, warn) L, R = args_assign return Base.remove_linenums!( generate_assign( - generate_mainbody!(mod, found, L, warn), - generate_mainbody!(mod, found, R, warn), + generate_mainbody!(mod, found, L, warn, warn_threads), + generate_mainbody!(mod, found, R, warn, warn_threads), ), ) end - return Expr(expr.head, map(x -> generate_mainbody!(mod, found, x, warn), expr.args)...) 
+ return Expr( + expr.head, + map(x -> generate_mainbody!(mod, found, x, warn, warn_threads), expr.args)..., + ) end function generate_assign(left, right) diff --git a/src/debug_utils.jl b/src/debug_utils.jl index e8b50a0b7..40a3cb3c1 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -424,8 +424,11 @@ function check_model_and_trace( # Perform checks before evaluating the model. issuccess = check_model_pre_evaluation(model) - # Force single-threaded execution. - _, varinfo = DynamicPPL.evaluate_threadunsafe!!(model, varinfo) + # TODO(penelopeysm): Implement merge, etc. for DebugAccumulator, and then perform a + # check on the merged accumulator, rather than checking it in the accumulate_assume + # calls. That way we can also support multi-threaded evaluation and use `evaluate!!` + # here instead of `_evaluate!!`. + _, varinfo = DynamicPPL._evaluate!!(model, varinfo) # Perform checks after evaluating the model. debug_acc = DynamicPPL.getacc(varinfo, Val(_DEBUG_ACC_NAME)) diff --git a/src/fasteval.jl b/src/fasteval.jl index 4f402f4a8..b137ef20d 100644 --- a/src/fasteval.jl +++ b/src/fasteval.jl @@ -221,23 +221,16 @@ function (f::FastLogDensityAt)(params::AbstractVector{<:Real}) ) model = DynamicPPL.setleafcontext(f.model, ctx) accs = fast_ldf_accs(f.getlogdensity) - # Calling `evaluate!!` would be fine, but would lead to an extra call to resetaccs!!, - # which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!` - # directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic - # here. - # TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what - # it _should_ do, but this is wrong regardless. 
- # https://github.com/TuringLang/DynamicPPL.jl/issues/1086 - vi = if Threads.nthreads() > 1 + _, vi = if DynamicPPL._requires_threadsafe(model) accs = map( acc -> DynamicPPL.convert_eltype(float_type_with_fallback(eltype(params)), acc), accs, ) - ThreadSafeVarInfo(OnlyAccsVarInfo(accs)) + vi_wrapped = ThreadSafeVarInfo(OnlyAccsVarInfo(accs)) + _, vi_wrapped = DynamicPPL._evaluate!!(model, vi_wrapped) else - OnlyAccsVarInfo(accs) + DynamicPPL._evaluate!!(model, OnlyAccsVarInfo(accs)) end - _, vi = DynamicPPL._evaluate!!(model, vi) return f.getlogdensity(vi) end diff --git a/src/model.jl b/src/model.jl index 2bcfe8f98..8ed08de9d 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1,5 +1,5 @@ """ - struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext} + struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext,Threaded} f::F args::NamedTuple{argnames,Targs} defaults::NamedTuple{defaultnames,Tdefaults} @@ -17,6 +17,10 @@ An argument with a type of `Missing` will be in `missings` by default. However, non-traditional use-cases `missings` can be defined differently. All variables in `missings` are treated as random variables rather than observations. +The `Threaded` type parameter indicates whether the model requires threadsafe evaluation +(i.e., whether the model contains statements which modify the internal VarInfo that are +executed in parallel). By default, this is set to `false`. + The default arguments are used internally when constructing instances of the same model with different arguments. 
@@ -33,8 +37,9 @@ julia> Model{(:y,)}(f, (x = 1.0, y = 2.0), (x = 42,)) # with special definition Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,)) ``` """ -struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext} <: - AbstractProbabilisticProgram +struct Model{ + F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext,Threaded +} <: AbstractProbabilisticProgram f::F args::NamedTuple{argnames,Targs} defaults::NamedTuple{defaultnames,Tdefaults} @@ -46,13 +51,13 @@ struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractConte Create a model with evaluation function `f` and missing arguments overwritten by `missings`. """ - function Model{missings}( + function Model{missings,Threaded}( f::F, args::NamedTuple{argnames,Targs}, defaults::NamedTuple{defaultnames,Tdefaults}, context::Ctx=DefaultContext(), - ) where {missings,F,argnames,Targs,defaultnames,Tdefaults,Ctx} - return new{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx}( + ) where {missings,F,argnames,Targs,defaultnames,Tdefaults,Ctx,Threaded} + return new{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx,Threaded}( f, args, defaults, context ) end @@ -71,6 +76,7 @@ model with different arguments. args::NamedTuple{argnames,Targs}, defaults::NamedTuple{kwargnames,Tkwargs}, context::AbstractContext=DefaultContext(), + threadsafe::Bool=false, ) where {F,argnames,Targs,kwargnames,Tkwargs} missing_args = Tuple( name for (name, typ) in zip(argnames, Targs.types) if typ <: Missing @@ -78,11 +84,25 @@ model with different arguments. 
missing_kwargs = Tuple( name for (name, typ) in zip(kwargnames, Tkwargs.types) if typ <: Missing ) - return :(Model{$(missing_args..., missing_kwargs...)}(f, args, defaults, context)) + return :(Model{$(missing_args..., missing_kwargs...),threadsafe}( + f, args, defaults, context + )) +end + +function Model( + f, + args::NamedTuple, + context::AbstractContext=DefaultContext(), + threadsafe=false; + kwargs..., +) + return Model(f, args, NamedTuple(kwargs), context, threadsafe) end -function Model(f, args::NamedTuple, context::AbstractContext=DefaultContext(); kwargs...) - return Model(f, args, NamedTuple(kwargs), context) +function _requires_threadsafe( + ::Model{F,A,D,M,Ta,Td,Ctx,Threaded} +) where {F,A,D,M,Ta,Td,Ctx,Threaded} + return Threaded end """ @@ -92,7 +112,7 @@ Return a new `Model` with the same evaluation function and other arguments, but with its underlying context set to `context`. """ function contextualize(model::Model, context::AbstractContext) - return Model(model.f, model.args, model.defaults, context) + return Model(model.f, model.args, model.defaults, context, _requires_threadsafe(model)) end """ @@ -105,6 +125,33 @@ function setleafcontext(model::Model, context::AbstractContext) return contextualize(model, setleafcontext(model.context, context)) end +""" + setthreadsafe(model::Model, threadsafe::Bool) + +Returns a new `Model` with its threadsafe flag set to `threadsafe`. + +Threadsafe evaluation ensures correctness when executing model statements that mutate the +internal `VarInfo` object in parallel. For example, this is needed if tilde-statements are +nested inside `Threads.@threads` or similar constructs. + +It is not needed for generic multithreaded operations that don't involve VarInfo. For +example, calculating a log-likelihood term in parallel and then calling `@addlogprob!` +outside of the parallel region is safe without needing to set `threadsafe=true`. 
+ +It is also not needed for multithreaded sampling with AbstractMCMC's `MCMCThreads()`. + +Setting `threadsafe` to `true` increases the overhead in evaluating the model. See +[https://turinglang.org/docs/THIS_DOESNT_EXIST_YET](https://turinglang.org/docs/THIS_DOESNT_EXIST_YET) +for more details. +""" +function setthreadsafe(model::Model{F,A,D,M}, threadsafe::Bool) where {F,A,D,M} + return if _requires_threadsafe(model) == threadsafe + model + else + Model{M,threadsafe}(model.f, model.args, model.defaults, model.context) + end +end + """ model | (x = 1.0, ...) @@ -863,16 +910,6 @@ function (model::Model)(rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInf return first(init!!(rng, model, varinfo)) end -""" - use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo) - -Return `true` if evaluation of a model using `context` and `varinfo` should -wrap `varinfo` in `ThreadSafeVarInfo`, i.e. threadsafe evaluation, and `false` otherwise. -""" -function use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo) - return Threads.nthreads() > 1 -end - """ init!!( [rng::Random.AbstractRNG,] @@ -912,55 +949,26 @@ end Evaluate the `model` with the given `varinfo`. -If multiple threads are available, the varinfo provided will be wrapped in a -`ThreadSafeVarInfo` before evaluation. +If the model has been marked as requiring threadsafe evaluation, the varinfo +provided will be wrapped in a `ThreadSafeVarInfo` before evaluation. Returns a tuple of the model's return value, plus the updated `varinfo` (unwrapped if necessary).
""" function AbstractPPL.evaluate!!(model::Model, varinfo::AbstractVarInfo) - return if use_threadsafe_eval(model.context, varinfo) - evaluate_threadsafe!!(model, varinfo) + return if _requires_threadsafe(model) + wrapper = ThreadSafeVarInfo(resetaccs!!(varinfo)) + result, wrapper_new = _evaluate!!(model, wrapper) + # TODO(penelopeysm): If seems that if you pass a TSVI to this method, it + # will return the underlying VI, which is a bit counterintuitive (because + # calling TSVI(::TSVI) returns the original TSVI, instead of wrapping it + # again). + return result, setaccs!!(wrapper_new.varinfo, getaccs(wrapper_new)) else - evaluate_threadunsafe!!(model, varinfo) + _evaluate!!(model, resetaccs!!(varinfo)) end end -""" - evaluate_threadunsafe!!(model, varinfo) - -Evaluate the `model` without wrapping `varinfo` inside a `ThreadSafeVarInfo`. - -If the `model` makes use of Julia's multithreading this will lead to undefined behaviour. -This method is not exposed and supposed to be used only internally in DynamicPPL. - -See also: [`evaluate_threadsafe!!`](@ref) -""" -function evaluate_threadunsafe!!(model, varinfo) - return _evaluate!!(model, resetaccs!!(varinfo)) -end - -""" - evaluate_threadsafe!!(model, varinfo, context) - -Evaluate the `model` with `varinfo` wrapped inside a `ThreadSafeVarInfo`. - -With the wrapper, Julia's multithreading can be used for observe statements in the `model` -but parallel sampling will lead to undefined behaviour. -This method is not exposed and supposed to be used only internally in DynamicPPL. - -See also: [`evaluate_threadunsafe!!`](@ref) -""" -function evaluate_threadsafe!!(model, varinfo) - wrapper = ThreadSafeVarInfo(resetaccs!!(varinfo)) - result, wrapper_new = _evaluate!!(model, wrapper) - # TODO(penelopeysm): If seems that if you pass a TSVI to this method, it - # will return the underlying VI, which is a bit counterintuitive (because - # calling TSVI(::TSVI) returns the original TSVI, instead of wrapping it - # again). 
- return result, setaccs!!(wrapper_new.varinfo, getaccs(wrapper_new)) -end - """ _evaluate!!(model::Model, varinfo) diff --git a/test/compiler.jl b/test/compiler.jl index b1309254e..9056f666a 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -606,12 +606,7 @@ module Issue537 end @model demo() = return __varinfo__ retval, svi = DynamicPPL.init!!(demo(), SimpleVarInfo()) @test svi == SimpleVarInfo() - if Threads.nthreads() > 1 - @test retval isa DynamicPPL.ThreadSafeVarInfo{<:SimpleVarInfo} - @test retval.varinfo == svi - else - @test retval == svi - end + @test retval == svi # We should not be altering return-values other than at top-level. @model function demo() @@ -793,4 +788,39 @@ module Issue537 end res = model() @test res == (a=1, b=1, c=2, d=2, t=DynamicPPL.TypeWrap{Int}()) end + + @testset "Threads.@threads detection" begin + # Check that the compiler detects when `Threads.@threads` is used inside a model + + e1 = quote + @model function f1() + Threads.@threads for i in 1:10 + x[i] ~ Normal() + end + end + end + @test_logs (:warn, r"threadsafe evaluation") eval(e1) + + e2 = quote + @model function f2() + for j in 1:10 + Threads.@threads for i in 1:10 + x[i] ~ Normal() + end + end + end + end + @test_logs (:warn, r"threadsafe evaluation") eval(e2) + + e3 = quote + @model function f3() + begin + Threads.@threads for i in 1:10 + x[i] ~ Normal() + end + end + end + end + @test_logs (:warn, r"threadsafe evaluation") eval(e3) + end end diff --git a/test/fasteval.jl b/test/fasteval.jl index db2333711..d582649f8 100644 --- a/test/fasteval.jl +++ b/test/fasteval.jl @@ -77,55 +77,49 @@ end end @testset "Threaded observe" begin - if Threads.nthreads() > 1 - @model function threaded(y) - x ~ Normal() - Threads.@threads for i in eachindex(y) - y[i] ~ Normal(x) - end + @model function threaded(y) + x ~ Normal() + Threads.@threads for i in eachindex(y) + y[i] ~ Normal(x) end - N = 100 - model = threaded(zeros(N)) - ldf = DynamicPPL.Experimental.FastLDF(model) - - xs = 
[1.0] - @test LogDensityProblems.logdensity(ldf, xs) ≈ - logpdf(Normal(), xs[1]) + N * logpdf(Normal(xs[1]), 0.0) end + N = 100 + model = setthreadsafe(threaded(zeros(N)), true) + ldf = DynamicPPL.Experimental.FastLDF(model) + + xs = [1.0] + @test LogDensityProblems.logdensity(ldf, xs) ≈ + logpdf(Normal(), xs[1]) + N * logpdf(Normal(xs[1]), 0.0) end end @testset "FastLDF: performance" begin - if Threads.nthreads() == 1 - # Evaluating these three models should not lead to any allocations (but only when - # not using TSVI). - @model function f() - x ~ Normal() - return 1.0 ~ Normal(x) - end - @model function submodel_inner() - m ~ Normal(0, 1) - s ~ Exponential() - return (m=m, s=s) - end - # Note that for the allocation tests to work on this one, `inner` has - # to be passed as an argument to `submodel_outer`, instead of just - # being called inside the model function itself - @model function submodel_outer(inner) - params ~ to_submodel(inner) - y ~ Normal(params.m, params.s) - return 1.0 ~ Normal(y) - end - @testset for model in - (f(), submodel_inner() | (; s=0.0), submodel_outer(submodel_inner())) - vi = VarInfo(model) - fldf = DynamicPPL.Experimental.FastLDF( - model, DynamicPPL.getlogjoint_internal, vi - ) - x = vi[:] - bench = median(@be LogDensityProblems.logdensity(fldf, x)) - @test iszero(bench.allocs) - end + # Evaluating these three models should not lead to any allocations (but only when + # not using TSVI). 
+ @model function f() + x ~ Normal() + return 1.0 ~ Normal(x) + end + @model function submodel_inner() + m ~ Normal(0, 1) + s ~ Exponential() + return (m=m, s=s) + end + # Note that for the allocation tests to work on this one, `inner` has + # to be passed as an argument to `submodel_outer`, instead of just + # being called inside the model function itself + @model function submodel_outer(inner) + params ~ to_submodel(inner) + y ~ Normal(params.m, params.s) + return 1.0 ~ Normal(y) + end + @testset for model in + (f(), submodel_inner() | (; s=0.0), submodel_outer(submodel_inner())) + vi = VarInfo(model) + fldf = DynamicPPL.Experimental.FastLDF(model, DynamicPPL.getlogjoint_internal, vi) + x = vi[:] + bench = median(@be LogDensityProblems.logdensity(fldf, x)) + @test iszero(bench.allocs) end end diff --git a/test/threadsafe.jl b/test/threadsafe.jl index 522730566..027e51422 100644 --- a/test/threadsafe.jl +++ b/test/threadsafe.jl @@ -52,63 +52,24 @@ x[i] ~ Normal(x[i - 1], 1) end end - model = wthreads(x) + model = setthreadsafe(wthreads(x), true) - vi = VarInfo() - model(vi) - lp_w_threads = getlogjoint(vi) - if Threads.nthreads() == 1 - @test vi_ isa VarInfo - else - @test vi_ isa DynamicPPL.ThreadSafeVarInfo - end - - println("With `@threads`:") - println(" default:") - @time model(vi) - - # Ensure that we use `ThreadSafeVarInfo` to handle multithreaded observe statements. 
- DynamicPPL.evaluate_threadsafe!!(model, vi) - @test getlogjoint(vi) ≈ lp_w_threads - # check that it's wrapped during the model evaluation - @test vi_ isa DynamicPPL.ThreadSafeVarInfo - # ensure that it's unwrapped after evaluation finishes - @test vi isa VarInfo - - println(" evaluate_threadsafe!!:") - @time DynamicPPL.evaluate_threadsafe!!(model, vi) - - @model function wothreads(x) - global vi_ = __varinfo__ - x[1] ~ Normal(0, 1) + function correct_lp(x) + lp = logpdf(Normal(0, 1), x[1]) for i in 2:length(x) - x[i] ~ Normal(x[i - 1], 1) + lp += logpdf(Normal(x[i - 1], 1), x[i]) end + return lp end - model = wothreads(x) vi = VarInfo() - model(vi) - lp_wo_threads = getlogjoint(vi) - if Threads.nthreads() == 1 - @test vi_ isa VarInfo - else - @test vi_ isa DynamicPPL.ThreadSafeVarInfo - end - - println("Without `@threads`:") - println(" default:") - @time model(vi) + _, vi = DynamicPPL.evaluate!!(model, vi) - @test lp_w_threads ≈ lp_wo_threads - - # Ensure that we use `VarInfo`. - DynamicPPL.evaluate_threadunsafe!!(model, vi) - @test getlogjoint(vi) ≈ lp_w_threads - @test vi_ isa VarInfo + # check that logp is correct + @test getlogjoint(vi) ≈ correct_lp(x) + # check that varinfo was wrapped during the model evaluation + @test vi_ isa DynamicPPL.ThreadSafeVarInfo + # ensure that it's unwrapped after evaluation finishes @test vi isa VarInfo - - println(" evaluate_threadunsafe!!:") - @time DynamicPPL.evaluate_threadunsafe!!(model, vi) end end