diff --git a/HISTORY.md b/HISTORY.md index ddbe67842..309bbe011 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -2,6 +2,41 @@ ## 0.38.0 +**Breaking changes** + +foo + +**Other changes** + +### Thread-safe execution + +This release removes `ThreadSafeVarInfo`, which was a construction used to ensure thread-safe accumulation of log-likelihood terms using the `Threads.@threads` macro. +However, `Threads.@threads` is no longer the recommended way to perform multithreaded tasks: see e.g. [this Julia blog post](https://julialang.org/blog/2023/07/PSA-dont-use-threadid/). + +In its place a new macro, `@pobserve`, is introduced, which under the hood uses `Threads.@spawn`. +**From a user's point of view you simply need to replace `Threads.@threads` with `@pobserve`.** +For example, here the likelihood contributions for each element of `y` are calculated in parallel: + +```julia +@model function f(y) + mu ~ Normal() + yplusones = @pobserve for i in eachindex(y) + y[i] ~ Normal(mu) + y[i] + 1 + end +end +``` + +Furthermore, the `@pobserve` block also returns the value of the final expression in the loop body for each iteration, so it can also be used to parallelise arbitrary computation. In the model above, `yplusones` will be a vector with the same length as `y`, where each element is `y[i] + 1`. + +Please note that this only works for **likelihood terms**, i.e., observed variables (hence the macro name). +It is a long-term goal to be able to parallelise other parts of model execution such as the sampling of new variables, but this is not presently possible. + +`@pobserve` is also not currently compatible with Turing's particle samplers (because Libtask.jl does not work with `Threads.@spawn`). +This is, in fact, a good thing, because the previous behaviour of particle samplers with `Threads.@threads` was to silently give a wrong result. + +### Other minor changes + The `varname_leaves` and `varname_and_value_leaves` functions have been moved to AbstractPPL.jl. Their behaviour is otherwise identical. 
diff --git a/docs/src/api.md b/docs/src/api.md index d1dddb560..f424c7836 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -160,6 +160,12 @@ It is possible to manually increase (or decrease) the accumulated log likelihood @addlogprob! ``` +If you want to perform observations in parallel (using Julia threads), you can use the following macro. + +```@docs +@pobserve +``` + Return values of the model function can be obtained with [`returned(model, sample)`](@ref), where `sample` is either a `MCMCChains.Chains` object (which represents a collection of samples) or a single sample represented as a `NamedTuple`. ```@docs diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 6a01884a9..22011251b 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -133,6 +133,7 @@ export AbstractVarInfo, to_submodel, # Convenience macros @addlogprob!, + @pobserve, value_iterator_from_chain, check_model, check_model_and_trace, @@ -186,11 +187,11 @@ include("varnamedvector.jl") include("accumulators.jl") include("default_accumulators.jl") include("abstract_varinfo.jl") -include("threadsafe.jl") include("varinfo.jl") include("simple_varinfo.jl") include("context_implementations.jl") include("compiler.jl") +include("pobserve_macro.jl") include("pointwise_logdensities.jl") include("transforming.jl") include("logdensityfunction.jl") diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 786d7c913..326850fdf 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -135,7 +135,7 @@ function assume( sampler::Union{SampleFromPrior,SampleFromUniform}, dist::Distribution, vn::VarName, - vi::VarInfoOrThreadSafeVarInfo, + vi::VarInfo, ) if haskey(vi, vn) # Always overwrite the parameters with new ones for `SampleFromUniform`. 
diff --git a/src/debug_utils.jl b/src/debug_utils.jl index c2be4b46b..19b88ec3f 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -425,8 +425,7 @@ function check_model_and_trace( # Perform checks before evaluating the model. issuccess = check_model_pre_evaluation(model) - # Force single-threaded execution. - _, varinfo = DynamicPPL.evaluate_threadunsafe!!(model, varinfo) + _, varinfo = DynamicPPL.evaluate!!(model, varinfo) # Perform checks after evaluating the model. debug_acc = DynamicPPL.getacc(varinfo, Val(_DEBUG_ACC_NAME)) diff --git a/src/model.jl b/src/model.jl index a6a3e0685..22f8d5b21 100644 --- a/src/model.jl +++ b/src/model.jl @@ -853,16 +853,6 @@ function (model::Model)(rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInf return first(init!!(rng, model, varinfo)) end -""" - use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo) - -Return `true` if evaluation of a model using `context` and `varinfo` should -wrap `varinfo` in `ThreadSafeVarInfo`, i.e. threadsafe evaluation, and `false` otherwise. -""" -function use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo) - return Threads.nthreads() > 1 -end - """ init!!( [rng::Random.AbstractRNG,] @@ -903,62 +893,19 @@ end Evaluate the `model` with the given `varinfo`. -If multiple threads are available, the varinfo provided will be wrapped in a -`ThreadSafeVarInfo` before evaluation. - Returns a tuple of the model's return value, plus the updated `varinfo` (unwrapped if necessary). """ function AbstractPPL.evaluate!!(model::Model, varinfo::AbstractVarInfo) - return if use_threadsafe_eval(model.context, varinfo) - evaluate_threadsafe!!(model, varinfo) - else - evaluate_threadunsafe!!(model, varinfo) - end -end - -""" - evaluate_threadunsafe!!(model, varinfo) - -Evaluate the `model` without wrapping `varinfo` inside a `ThreadSafeVarInfo`. - -If the `model` makes use of Julia's multithreading this will lead to undefined behaviour. 
-This method is not exposed and supposed to be used only internally in DynamicPPL. - -See also: [`evaluate_threadsafe!!`](@ref) -""" -function evaluate_threadunsafe!!(model, varinfo) return _evaluate!!(model, resetaccs!!(varinfo)) end -""" - evaluate_threadsafe!!(model, varinfo, context) - -Evaluate the `model` with `varinfo` wrapped inside a `ThreadSafeVarInfo`. - -With the wrapper, Julia's multithreading can be used for observe statements in the `model` -but parallel sampling will lead to undefined behaviour. -This method is not exposed and supposed to be used only internally in DynamicPPL. - -See also: [`evaluate_threadunsafe!!`](@ref) -""" -function evaluate_threadsafe!!(model, varinfo) - wrapper = ThreadSafeVarInfo(resetaccs!!(varinfo)) - result, wrapper_new = _evaluate!!(model, wrapper) - # TODO(penelopeysm): If seems that if you pass a TSVI to this method, it - # will return the underlying VI, which is a bit counterintuitive (because - # calling TSVI(::TSVI) returns the original TSVI, instead of wrapping it - # again). - return result, setaccs!!(wrapper_new.varinfo, getaccs(wrapper_new)) -end - """ _evaluate!!(model::Model, varinfo) Evaluate the `model` with the given `varinfo`. -This function does not wrap the varinfo in a `ThreadSafeVarInfo`. It also does not -reset the log probability of the `varinfo` before running. +This function does not reset the accumulators in the `varinfo` before running. """ function _evaluate!!(model::Model, varinfo::AbstractVarInfo) args, kwargs = make_evaluate_args_and_kwargs(model, varinfo) diff --git a/src/pobserve_macro.jl b/src/pobserve_macro.jl new file mode 100644 index 000000000..b964e1de4 --- /dev/null +++ b/src/pobserve_macro.jl @@ -0,0 +1,90 @@ +using MacroTools: @capture, @q + +""" + @pobserve + +Perform observations in parallel. 
+""" +macro pobserve(expr) + return _pobserve(expr) +end + +function _pobserve(expr::Expr) + @capture( + expr, + for ctr_ in iterable_ + block_ + end + ) || error("expected for loop") + # reconstruct the for loop with the processed block + return_expr = @q begin + likelihood_tasks = map($(esc(iterable))) do $(esc(ctr)) + Threads.@spawn begin + $(process_tilde_statements(block)) + end + end + retvals_and_likelihoods = fetch.(likelihood_tasks) + total_likelihoods = sum(last, retvals_and_likelihoods) + if $(DynamicPPL.hasacc)($(esc(:(__varinfo__))), Val(:LogLikelihood)) + $(esc(:(__varinfo__))) = $(DynamicPPL.accloglikelihood!!)( + $(esc(:(__varinfo__))), total_likelihoods + ) + end + map(first, retvals_and_likelihoods) + end + return return_expr +end + +""" + process_tilde_statements(expr) + +This function traverses a block expression `expr` and transforms any +lines in it that look like `lhs ~ rhs` into a simple accumulation of +likelihoods, i.e., `Distributions.logpdf(rhs, lhs)`. +""" +function process_tilde_statements(expr::Expr) + @capture( + expr, + begin + statements__ + end + ) || error("expected block") + @gensym loglike + beginning_expr = :( + $loglike = if $(DynamicPPL.hasacc)($(esc(:(__varinfo__))), Val(:LogLikelihood)) + zero($(DynamicPPL.getloglikelihood)($(esc(:(__varinfo__))))) + else + zero($(DynamicPPL.LogProbType)) + end + ) + n_statements = length(statements) + transformed_statements::Vector{Vector{Expr}} = map(enumerate(statements)) do (i, stmt) + is_last = i == n_statements + if @capture(stmt, lhs_ ~ rhs_) + # TODO: We should probably perform some checks to make sure that this + # indeed was meant to be an observe statement. 
+ @gensym left + e = [ + :($left = $(esc(lhs))), + :($loglike += $(Distributions.logpdf)($(esc(rhs)), $left)), + ] + is_last && push!(e, :(($left, $loglike))) + e + elseif @capture(stmt, lhs_ .~ rhs_) + @gensym val + e = [ + # TODO: dot-tilde + :($val = $(esc(stmt))), + ] + is_last && push!(e, :(($val, $loglike))) + e + else + @gensym val + e = [:($val = $(esc(stmt)))] + is_last && push!(e, :(($val, $loglike))) + e + end + end + new_statements = [beginning_expr, reduce(vcat, transformed_statements)...] + return Expr(:block, new_statements...) +end diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 27365e4dc..bdf36a750 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -411,12 +411,8 @@ function BangBang.push!!( return Accessors.@set vi.values = setindex!!(vi.values, value, vn) end -const SimpleOrThreadSafeSimple{T,V,C} = Union{ - SimpleVarInfo{T,V,C},ThreadSafeVarInfo{<:SimpleVarInfo{T,V,C}} -} - # Necessary for `matchingvalue` to work properly. -Base.eltype(::SimpleOrThreadSafeSimple{<:Any,V}) where {V} = V +Base.eltype(::SimpleVarInfo{<:Any,V}) where {V} = V # `subset` function subset(varinfo::SimpleVarInfo, vns::AbstractVector{<:VarName}) @@ -474,7 +470,7 @@ function assume( sampler::Union{SampleFromPrior,SampleFromUniform}, dist::Distribution, vn::VarName, - vi::SimpleOrThreadSafeSimple, + vi::SimpleVarInfo, ) value = init(rng, dist, sampler) # Transform if we're working in unconstrained space. @@ -485,16 +481,13 @@ function assume( return value, vi end -function settrans!!(vi::SimpleVarInfo, trans) +function settrans!!(vi::SimpleVarInfo, trans::Bool) return settrans!!(vi, trans ? 
DynamicTransformation() : NoTransformation()) end function settrans!!(vi::SimpleVarInfo, transformation::AbstractTransformation) return Accessors.@set vi.transformation = transformation end -function settrans!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans) - return Accessors.@set vi.varinfo = settrans!!(vi.varinfo, trans) -end -function settrans!!(vi::SimpleOrThreadSafeSimple, trans::Bool, ::VarName) +function settrans!!(vi::SimpleVarInfo, trans::Bool, ::VarName) # We keep this method around just to obey the AbstractVarInfo interface. # However, note that this would only be a valid operation if it would be a # no-op, which we check here. @@ -507,8 +500,6 @@ end istrans(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation) istrans(vi::SimpleVarInfo, ::VarName) = istrans(vi) -istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName) = istrans(vi.varinfo, vn) -istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}) = istrans(vi.varinfo) islinked(vi::SimpleVarInfo) = istrans(vi) diff --git a/src/test_utils/varinfo.jl b/src/test_utils/varinfo.jl index 26e2aa7ca..e3026ba6c 100644 --- a/src/test_utils/varinfo.jl +++ b/src/test_utils/varinfo.jl @@ -15,17 +15,13 @@ function test_values(vi::AbstractVarInfo, vals::NamedTuple, vns; compare=isequal end """ - setup_varinfos(model::Model, example_values::NamedTuple, varnames; include_threadsafe::Bool=false) + setup_varinfos(model::Model, example_values::NamedTuple, varnames) Return a tuple of instances for different implementations of `AbstractVarInfo` with each `vi`, supposedly, satisfying `vi[vn] == get(example_values, vn)` for `vn` in `varnames`. -If `include_threadsafe` is `true`, then the returned tuple will also include thread-safe versions -of the varinfo instances. 
""" -function setup_varinfos( - model::Model, example_values::NamedTuple, varnames; include_threadsafe::Bool=false -) +function setup_varinfos(model::Model, example_values::NamedTuple, varnames) # VarInfo vi_untyped_metadata = DynamicPPL.untyped_varinfo(model) vi_untyped_vnv = DynamicPPL.untyped_vector_varinfo(model) @@ -51,9 +47,5 @@ function setup_varinfos( last(DynamicPPL.evaluate!!(model, vi)) end - if include_threadsafe - varinfos = (varinfos..., map(DynamicPPL.ThreadSafeVarInfo ∘ deepcopy, varinfos)...) - end - return varinfos end diff --git a/src/threadsafe.jl b/src/threadsafe.jl deleted file mode 100644 index 6ca3b9852..000000000 --- a/src/threadsafe.jl +++ /dev/null @@ -1,236 +0,0 @@ -""" - ThreadSafeVarInfo - -A `ThreadSafeVarInfo` object wraps an [`AbstractVarInfo`](@ref) object and an -array of accumulators for thread-safe execution of a probabilistic model. -""" -struct ThreadSafeVarInfo{V<:AbstractVarInfo,L<:AccumulatorTuple} <: AbstractVarInfo - varinfo::V - accs_by_thread::Vector{L} -end -function ThreadSafeVarInfo(vi::AbstractVarInfo) - # In ThreadSafeVarInfo we use threadid() to index into the array of logp - # fields. This is not good practice --- see - # https://github.com/TuringLang/DynamicPPL.jl/issues/924 for a full - # explanation --- but it has worked okay so far. - # The use of nthreads()*2 here ensures that threadid() doesn't exceed - # the length of the logps array. Ideally, we would use maxthreadid(), - # but Mooncake can't differentiate through that. Empirically, nthreads()*2 - # seems to provide an upper bound to maxthreadid(), so we use that here. 
- # See https://github.com/TuringLang/DynamicPPL.jl/pull/936 - accs_by_thread = [map(split, getaccs(vi)) for _ in 1:(Threads.nthreads() * 2)] - return ThreadSafeVarInfo(vi, accs_by_thread) -end -ThreadSafeVarInfo(vi::ThreadSafeVarInfo) = vi - -transformation(vi::ThreadSafeVarInfo) = transformation(vi.varinfo) - -# Set the accumulator in question in vi.varinfo, and set the thread-specific -# accumulators of the same type to be empty. -function setacc!!(vi::ThreadSafeVarInfo, acc::AbstractAccumulator) - inner_vi = setacc!!(vi.varinfo, acc) - news_accs_by_thread = map(accs -> setacc!!(accs, split(acc)), vi.accs_by_thread) - return ThreadSafeVarInfo(inner_vi, news_accs_by_thread) -end - -# Get both the main accumulator and the thread-specific accumulators of the same type and -# combine them. -function getacc(vi::ThreadSafeVarInfo, accname::Val) - main_acc = getacc(vi.varinfo, accname) - other_accs = map(accs -> getacc(accs, accname), vi.accs_by_thread) - return foldl(combine, other_accs; init=main_acc) -end - -hasacc(vi::ThreadSafeVarInfo, accname::Val) = hasacc(vi.varinfo, accname) -acckeys(vi::ThreadSafeVarInfo) = acckeys(vi.varinfo) - -function getaccs(vi::ThreadSafeVarInfo) - # This method is a bit finicky to maintain type stability. For instance, moving the - # accname -> Val(accname) part in the main `map` call makes constant propagation fail - # and this becomes unstable. Do check the effects if you make edits. - accnames = acckeys(vi) - accname_vals = map(Val, accnames) - return AccumulatorTuple(map(anv -> getacc(vi, anv), accname_vals)) -end - -# Calls to map_accumulator(s)!! are thread-specific by default. For any use of them that -# should _not_ be thread-specific a specific method has to be written. 
-function map_accumulator!!(func::Function, vi::ThreadSafeVarInfo, accname::Val) - tid = Threads.threadid() - vi.accs_by_thread[tid] = map_accumulator(func, vi.accs_by_thread[tid], accname) - return vi -end - -function map_accumulators!!(func::Function, vi::ThreadSafeVarInfo) - tid = Threads.threadid() - vi.accs_by_thread[tid] = map(func, vi.accs_by_thread[tid]) - return vi -end - -has_varnamedvector(vi::ThreadSafeVarInfo) = has_varnamedvector(vi.varinfo) - -function BangBang.push!!(vi::ThreadSafeVarInfo, vn::VarName, r, dist::Distribution) - return Accessors.@set vi.varinfo = push!!(vi.varinfo, vn, r, dist) -end - -syms(vi::ThreadSafeVarInfo) = syms(vi.varinfo) - -setval!(vi::ThreadSafeVarInfo, val, vn::VarName) = setval!(vi.varinfo, val, vn) - -keys(vi::ThreadSafeVarInfo) = keys(vi.varinfo) -haskey(vi::ThreadSafeVarInfo, vn::VarName) = haskey(vi.varinfo, vn) - -islinked(vi::ThreadSafeVarInfo) = islinked(vi.varinfo) - -function link!!(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...) - return Accessors.@set vi.varinfo = link!!(t, vi.varinfo, args...) -end - -function invlink!!(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...) - return Accessors.@set vi.varinfo = invlink!!(t, vi.varinfo, args...) -end - -function link(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...) - return Accessors.@set vi.varinfo = link(t, vi.varinfo, args...) -end - -function invlink(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...) - return Accessors.@set vi.varinfo = invlink(t, vi.varinfo, args...) -end - -# Need to define explicitly for `DynamicTransformation` to avoid method ambiguity. -# NOTE: We also can't just defer to the wrapped varinfo, because we need to ensure -# consistency between `vi.accs_by_thread` field and `getacc(vi.varinfo)`, which accumulates -# to define `getacc(vi)`. 
-function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) - model = contextualize( - model, setleafcontext(model.context, DynamicTransformationContext{false}()) - ) - return settrans!!(last(evaluate!!(model, vi)), t) -end - -function invlink!!(::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) - model = contextualize( - model, setleafcontext(model.context, DynamicTransformationContext{true}()) - ) - return settrans!!(last(evaluate!!(model, vi)), NoTransformation()) -end - -function link(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) - return link!!(t, deepcopy(vi), model) -end - -function invlink(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) - return invlink!!(t, deepcopy(vi), model) -end - -# These two StaticTransformation methods needed to resolve ambiguities -function link!!( - t::StaticTransformation{<:Bijectors.Transform}, vi::ThreadSafeVarInfo, model::Model -) - return Accessors.@set vi.varinfo = link!!(t, vi.varinfo, model) -end - -function invlink!!( - t::StaticTransformation{<:Bijectors.Transform}, vi::ThreadSafeVarInfo, model::Model -) - return Accessors.@set vi.varinfo = invlink!!(t, vi.varinfo, model) -end - -function maybe_invlink_before_eval!!(vi::ThreadSafeVarInfo, model::Model) - # Defer to the wrapped `AbstractVarInfo` object. - # NOTE: When computing `getacc` for `ThreadSafeVarInfo` we do include the - # `getacc(vi.varinfo)` hence the log-absdet-jacobian term will correctly be included in - # the `getlogprior(vi)`. 
- return Accessors.@set vi.varinfo = maybe_invlink_before_eval!!(vi.varinfo, model) -end - -# `getindex` -getindex(vi::ThreadSafeVarInfo, ::Colon) = getindex(vi.varinfo, Colon()) -getindex(vi::ThreadSafeVarInfo, vn::VarName) = getindex(vi.varinfo, vn) -getindex(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}) = getindex(vi.varinfo, vns) -function getindex(vi::ThreadSafeVarInfo, vn::VarName, dist::Distribution) - return getindex(vi.varinfo, vn, dist) -end -function getindex(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}, dist::Distribution) - return getindex(vi.varinfo, vns, dist) -end - -function BangBang.setindex!!(vi::ThreadSafeVarInfo, vals, vn::VarName) - return Accessors.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, vals, vn) -end -function BangBang.setindex!!(vi::ThreadSafeVarInfo, vals, vns::AbstractVector{<:VarName}) - return Accessors.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, vals, vns) -end - -vector_length(vi::ThreadSafeVarInfo) = vector_length(vi.varinfo) -vector_getrange(vi::ThreadSafeVarInfo, vn::VarName) = vector_getrange(vi.varinfo, vn) -function vector_getranges(vi::ThreadSafeVarInfo, vns::Vector{<:VarName}) - return vector_getranges(vi.varinfo, vns) -end - -isempty(vi::ThreadSafeVarInfo) = isempty(vi.varinfo) -function BangBang.empty!!(vi::ThreadSafeVarInfo) - return resetaccs!!(Accessors.@set(vi.varinfo = empty!!(vi.varinfo))) -end - -function resetaccs!!(vi::ThreadSafeVarInfo) - vi = Accessors.@set vi.varinfo = resetaccs!!(vi.varinfo) - for i in eachindex(vi.accs_by_thread) - vi.accs_by_thread[i] = map(reset, vi.accs_by_thread[i]) - end - return vi -end - -values_as(vi::ThreadSafeVarInfo) = values_as(vi.varinfo) -values_as(vi::ThreadSafeVarInfo, ::Type{T}) where {T} = values_as(vi.varinfo, T) - -function unset_flag!( - vi::ThreadSafeVarInfo, vn::VarName, flag::String, ignoreable::Bool=false -) - return unset_flag!(vi.varinfo, vn, flag, ignoreable) -end -function is_flagged(vi::ThreadSafeVarInfo, vn::VarName, 
flag::String) - return is_flagged(vi.varinfo, vn, flag) -end - -function settrans!!(vi::ThreadSafeVarInfo, trans::Bool, vn::VarName) - return Accessors.@set vi.varinfo = settrans!!(vi.varinfo, trans, vn) -end - -istrans(vi::ThreadSafeVarInfo, vn::VarName) = istrans(vi.varinfo, vn) -istrans(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}) = istrans(vi.varinfo, vns) - -getindex_internal(vi::ThreadSafeVarInfo, vn::VarName) = getindex_internal(vi.varinfo, vn) - -function unflatten(vi::ThreadSafeVarInfo, x::AbstractVector) - return Accessors.@set vi.varinfo = unflatten(vi.varinfo, x) -end - -function subset(varinfo::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}) - return Accessors.@set varinfo.varinfo = subset(varinfo.varinfo, vns) -end - -function Base.merge(varinfo_left::ThreadSafeVarInfo, varinfo_right::ThreadSafeVarInfo) - return Accessors.@set varinfo_left.varinfo = merge( - varinfo_left.varinfo, varinfo_right.varinfo - ) -end - -function invlink_with_logpdf(vi::ThreadSafeVarInfo, vn::VarName, dist, y) - return invlink_with_logpdf(vi.varinfo, vn, dist, y) -end - -function from_internal_transform(varinfo::ThreadSafeVarInfo, vn::VarName) - return from_internal_transform(varinfo.varinfo, vn) -end -function from_internal_transform(varinfo::ThreadSafeVarInfo, vn::VarName, dist) - return from_internal_transform(varinfo.varinfo, vn, dist) -end - -function from_linked_internal_transform(varinfo::ThreadSafeVarInfo, vn::VarName) - return from_linked_internal_transform(varinfo.varinfo, vn) -end -function from_linked_internal_transform(varinfo::ThreadSafeVarInfo, vn::VarName, dist) - return from_linked_internal_transform(varinfo.varinfo, vn, dist) -end diff --git a/src/varinfo.jl b/src/varinfo.jl index 081f65ea1..6311be9e0 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -152,9 +152,6 @@ const UntypedVarInfo = VarInfo{<:Metadata} # something which carried both its keys as well as its values' types as type # parameters. 
const NTVarInfo = VarInfo{<:NamedTuple} -const VarInfoOrThreadSafeVarInfo{Tmeta} = Union{ - VarInfo{Tmeta},ThreadSafeVarInfo{<:VarInfo{Tmeta}} -} function Base.:(==)(vi1::VarInfo, vi2::VarInfo) return (vi1.metadata == vi2.metadata && vi1.accs == vi2.accs) @@ -382,6 +379,7 @@ function unflatten(vi::VarInfo, x::AbstractVector) # The below line is finicky for type stability. For instance, assigning the eltype to # convert to into an intermediate variable makes this unstable (constant propagation) # fails. Take care when editing. + # TODO(penelopeysm): Can this be simplified if TSVI is gone? accs = map( acc -> convert_eltype(float_type_with_fallback(eltype(x)), acc), copy(getaccs(vi)) ) @@ -962,12 +960,6 @@ function link!!(::DynamicTransformation, vi::VarInfo, model::Model) return vi end -function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo{<:VarInfo}, model::Model) - # By default this will simply evaluate the model with `DynamicTransformationContext`, - # and so we need to specialize to avoid this. - return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, model) -end - function link!!(::DynamicTransformation, vi::VarInfo, vns::VarNameTuple, model::Model) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return _link(model, vi, vns) @@ -975,17 +967,6 @@ function link!!(::DynamicTransformation, vi::VarInfo, vns::VarNameTuple, model:: return vi end -function link!!( - t::DynamicTransformation, - vi::ThreadSafeVarInfo{<:VarInfo}, - vns::VarNameTuple, - model::Model, -) - # By default this will simply evaluate the model with `DynamicTransformationContext`, - # and so we need to specialize to avoid this. 
- return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, vns, model) -end - function _link!!(vi::UntypedVarInfo, vns) # TODO: Change to a lazy iterator over `vns` if ~istrans(vi, vns[1]) @@ -1067,12 +1048,6 @@ function invlink!!(::DynamicTransformation, vi::VarInfo, model::Model) return vi end -function invlink!!(t::DynamicTransformation, vi::ThreadSafeVarInfo{<:VarInfo}, model::Model) - # By default this will simply evaluate the model with `DynamicTransformationContext`, - # and so we need to specialize to avoid this. - return Accessors.@set vi.varinfo = DynamicPPL.invlink!!(t, vi.varinfo, model) -end - function invlink!!(::DynamicTransformation, vi::VarInfo, vns::VarNameTuple, model::Model) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return _invlink(model, vi, vns) @@ -1080,17 +1055,6 @@ function invlink!!(::DynamicTransformation, vi::VarInfo, vns::VarNameTuple, mode return vi end -function invlink!!( - ::DynamicTransformation, - vi::ThreadSafeVarInfo{<:VarInfo}, - vns::VarNameTuple, - model::Model, -) - # By default this will simply evaluate the model with `DynamicTransformationContext`, and so - # we need to specialize to avoid this. - return Accessors.@set vi.varinfo = DynamicPPL.invlink!!(vi.varinfo, vns, model) -end - function maybe_invlink_before_eval!!(vi::VarInfo, model::Model) # Because `VarInfo` does not contain any information about what the transformation # other than whether or not it has actually been transformed, the best we can do @@ -1180,27 +1144,10 @@ function link(::DynamicTransformation, varinfo::VarInfo, model::Model) return _link(model, varinfo, keys(varinfo)) end -function link(::DynamicTransformation, varinfo::ThreadSafeVarInfo{<:VarInfo}, model::Model) - # By default this will simply evaluate the model with `DynamicTransformationContext`, and so - # we need to specialize to avoid this. 
- return Accessors.@set varinfo.varinfo = link(varinfo.varinfo, model) -end - function link(::DynamicTransformation, varinfo::VarInfo, vns::VarNameTuple, model::Model) return _link(model, varinfo, vns) end -function link( - ::DynamicTransformation, - varinfo::ThreadSafeVarInfo{<:VarInfo}, - vns::VarNameTuple, - model::Model, -) - # By default this will simply evaluate the model with `DynamicTransformationContext`, - # and so we need to specialize to avoid this. - return Accessors.@set varinfo.varinfo = link(varinfo.varinfo, vns, model) -end - function _link(model::Model, varinfo::VarInfo, vns) varinfo = deepcopy(varinfo) md, logjac = _link_metadata!!(model, varinfo, varinfo.metadata, vns) @@ -1344,29 +1291,10 @@ function invlink(::DynamicTransformation, vi::VarInfo, model::Model) return _invlink(model, vi, keys(vi)) end -function invlink( - ::DynamicTransformation, varinfo::ThreadSafeVarInfo{<:VarInfo}, model::Model -) - # By default this will simply evaluate the model with `DynamicTransformationContext`, and so - # we need to specialize to avoid this. - return Accessors.@set varinfo.varinfo = invlink(varinfo.varinfo, model) -end - function invlink(::DynamicTransformation, varinfo::VarInfo, vns::VarNameTuple, model::Model) return _invlink(model, varinfo, vns) end -function invlink( - ::DynamicTransformation, - varinfo::ThreadSafeVarInfo{<:VarInfo}, - vns::VarNameTuple, - model::Model, -) - # By default this will simply evaluate the model with `DynamicTransformationContext`, and so - # we need to specialize to avoid this. - return Accessors.@set varinfo.varinfo = invlink(varinfo.varinfo, vns, model) -end - function _invlink(model::Model, varinfo::VarInfo, vns) varinfo = deepcopy(varinfo) md, inv_logjac = _invlink_metadata!!(model, varinfo, varinfo.metadata, vns) @@ -1814,7 +1742,7 @@ end Calls `kernel!(vi, vn, values, keys)` for every `vn` in `vi`. 
""" -function _apply!(kernel!, vi::VarInfoOrThreadSafeVarInfo, values, keys) +function _apply!(kernel!, vi::VarInfo, values, keys) keys_strings = map(string, collect_maybe(keys)) num_indices_seen = 0 @@ -1872,7 +1800,7 @@ end end end -function _find_missing_keys(vi::VarInfoOrThreadSafeVarInfo, keys) +function _find_missing_keys(vi::VarInfo, keys) string_vns = map(string, collect_maybe(Base.keys(vi))) # If `key` isn't subsumed by any element of `string_vns`, it is not present in `vi`. missing_keys = filter(keys) do key @@ -1937,7 +1865,7 @@ function setval!(vi::VarInfo, chains::AbstractChains, sample_idx::Int, chain_idx return setval!(vi, chains.value[sample_idx, :, chain_idx], keys(chains)) end -function _setval_kernel!(vi::VarInfoOrThreadSafeVarInfo, vn::VarName, values, keys) +function _setval_kernel!(vi::VarInfo, vn::VarName, values, keys) indices = findall(Base.Fix1(subsumes_string, string(vn)), keys) if !isempty(indices) val = reduce(vcat, values[indices]) diff --git a/test/compiler.jl b/test/compiler.jl index b1309254e..3c451b6b0 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -606,12 +606,7 @@ module Issue537 end @model demo() = return __varinfo__ retval, svi = DynamicPPL.init!!(demo(), SimpleVarInfo()) @test svi == SimpleVarInfo() - if Threads.nthreads() > 1 - @test retval isa DynamicPPL.ThreadSafeVarInfo{<:SimpleVarInfo} - @test retval.varinfo == svi - else - @test retval == svi - end + @test retval == svi # We should not be altering return-values other than at top-level. 
@model function demo() diff --git a/test/model.jl b/test/model.jl index 7374f73aa..ffc4c23fe 100644 --- a/test/model.jl +++ b/test/model.jl @@ -142,19 +142,6 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() end end - @testset "DynamicPPL#684: threadsafe evaluation with multiple types" begin - @model function multiple_types(x) - ns ~ filldist(Normal(0, 2.0), 3) - m ~ Uniform(0, 1) - return x ~ Normal(m, 1) - end - model = multiple_types(1) - chain = make_chain_from_prior(model, 10) - loglikelihood(model, chain) - logprior(model, chain) - logjoint(model, chain) - end - @testset "defaults without VarInfo, Sampler, and Context" begin model = GDEMO_DEFAULT diff --git a/test/pobserve_macro.jl b/test/pobserve_macro.jl new file mode 100644 index 000000000..ff523c3b5 --- /dev/null +++ b/test/pobserve_macro.jl @@ -0,0 +1,74 @@ +module DynamicPPLPobserveMacroTests + +using DynamicPPL, Distributions, Test + +@testset verbose = true "pobserve_macro.jl" begin + @testset "loglikelihood is correctly accumulated" begin + @model function f(x) + @pobserve for i in eachindex(x) + x[i] ~ Normal() + end + end + x = randn(3) + expected_loglike = loglikelihood(Normal(), x) + vi = VarInfo(f(x)) + @test isapprox(DynamicPPL.getloglikelihood(vi), expected_loglike) + end + + @testset "doesn't error when varinfo has no likelihood acc" begin + @model function f(x) + @pobserve for i in eachindex(x) + x[i] ~ Normal() + end + end + x = randn(3) + vi = VarInfo() + vi = DynamicPPL.setaccs!!(vi, (DynamicPPL.LogPriorAccumulator(),)) + @test DynamicPPL.evaluate!!(f(x), vi) isa Any + end + + @testset "return values are correct" begin + @testset "single expression at the end" begin + @model function f(x) + xplusone = @pobserve for i in eachindex(x) + x[i] ~ Normal() + x[i] + 1 + end + return xplusone + end + x = randn(3) + @test f(x)() == x .+ 1 + + @testset "calculations are not repeated" begin + # Make sure that the final expression inside pobserve is not evaluated + # 
multiple times. + counter = 0 + increment_and_return(y) = (counter += 1; y) + @model function g(x) + xs = @pobserve for i in eachindex(x) + x[i] ~ Normal() + increment_and_return(x[i]) + end + return xs + end + x = randn(3) + @test g(x)() == x + @test counter == length(x) + end + end + + @testset "tilde expression at the end" begin + @model function f(x) + xs = @pobserve for i in eachindex(x) + # This should behave as if it returns `x[i]` + x[i] ~ Normal() + end + return xs + end + x = randn(3) + @test f(x)() == x + end + end +end + +end diff --git a/test/runtests.jl b/test/runtests.jl index c60c06786..ef417350f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -57,6 +57,7 @@ include("test_util.jl") include("utils.jl") include("accumulators.jl") include("compiler.jl") + include("pobserve_macro.jl") include("varnamedvector.jl") include("varinfo.jl") include("simple_varinfo.jl") @@ -70,7 +71,6 @@ include("test_util.jl") include("lkj.jl") include("contexts.jl") include("context_implementations.jl") - include("threadsafe.jl") include("debug_utils.jl") include("submodels.jl") include("bijector.jl") diff --git a/test/test_util.jl b/test/test_util.jl index 164751c7b..ab2a80dc0 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -13,9 +13,6 @@ const gdemo_default = gdemo_d() Return string representing a short description of `vi`. 
""" -function short_varinfo_name(vi::DynamicPPL.ThreadSafeVarInfo) - return "threadsafe($(short_varinfo_name(vi.varinfo)))" -end function short_varinfo_name(vi::DynamicPPL.NTVarInfo) return if DynamicPPL.has_varnamedvector(vi) "TypedVectorVarInfo" diff --git a/test/threadsafe.jl b/test/threadsafe.jl deleted file mode 100644 index 0421c89e2..000000000 --- a/test/threadsafe.jl +++ /dev/null @@ -1,116 +0,0 @@ -@testset "threadsafe.jl" begin - @testset "constructor" begin - vi = VarInfo(gdemo_default) - threadsafe_vi = @inferred DynamicPPL.ThreadSafeVarInfo(vi) - - @test threadsafe_vi.varinfo === vi - @test threadsafe_vi.accs_by_thread isa Vector{<:DynamicPPL.AccumulatorTuple} - @test length(threadsafe_vi.accs_by_thread) == Threads.nthreads() * 2 - expected_accs = DynamicPPL.AccumulatorTuple( - (DynamicPPL.split(acc) for acc in DynamicPPL.getaccs(vi))... - ) - @test all(accs == expected_accs for accs in threadsafe_vi.accs_by_thread) - end - - # TODO: Add more tests of the public API - @testset "API" begin - vi = VarInfo(gdemo_default) - threadsafe_vi = DynamicPPL.ThreadSafeVarInfo(vi) - - lp = getlogjoint(vi) - @test getlogjoint(threadsafe_vi) == lp - - threadsafe_vi = DynamicPPL.acclogprior!!(threadsafe_vi, 42) - @test threadsafe_vi.accs_by_thread[Threads.threadid()][:LogPrior].logp == 42 - @test getlogjoint(vi) == lp - @test getlogjoint(threadsafe_vi) == lp + 42 - - threadsafe_vi = DynamicPPL.resetaccs!!(threadsafe_vi) - @test iszero(getlogjoint(threadsafe_vi)) - expected_accs = DynamicPPL.AccumulatorTuple( - (DynamicPPL.split(acc) for acc in DynamicPPL.getaccs(threadsafe_vi.varinfo))... - ) - @test all(accs == expected_accs for accs in threadsafe_vi.accs_by_thread) - - threadsafe_vi = setlogprior!!(threadsafe_vi, 42) - @test getlogjoint(threadsafe_vi) == 42 - expected_accs = DynamicPPL.AccumulatorTuple( - (DynamicPPL.split(acc) for acc in DynamicPPL.getaccs(threadsafe_vi.varinfo))... 
- ) - @test all(accs == expected_accs for accs in threadsafe_vi.accs_by_thread) - end - - @testset "model" begin - println("Peforming threading tests with $(Threads.nthreads()) threads") - - x = rand(10_000) - - @model function wthreads(x) - global vi_ = __varinfo__ - x[1] ~ Normal(0, 1) - Threads.@threads for i in 2:length(x) - x[i] ~ Normal(x[i - 1], 1) - end - end - model = wthreads(x) - - vi = VarInfo() - model(vi) - lp_w_threads = getlogjoint(vi) - if Threads.nthreads() == 1 - @test vi_ isa VarInfo - else - @test vi_ isa DynamicPPL.ThreadSafeVarInfo - end - - println("With `@threads`:") - println(" default:") - @time model(vi) - - # Ensure that we use `ThreadSafeVarInfo` to handle multithreaded observe statements. - sampling_model = contextualize(model, SamplingContext(model.context)) - DynamicPPL.evaluate_threadsafe!!(sampling_model, vi) - @test getlogjoint(vi) ≈ lp_w_threads - # check that it's wrapped during the model evaluation - @test vi_ isa DynamicPPL.ThreadSafeVarInfo - # ensure that it's unwrapped after evaluation finishes - @test vi isa VarInfo - - println(" evaluate_threadsafe!!:") - @time DynamicPPL.evaluate_threadsafe!!(sampling_model, vi) - - @model function wothreads(x) - global vi_ = __varinfo__ - x[1] ~ Normal(0, 1) - for i in 2:length(x) - x[i] ~ Normal(x[i - 1], 1) - end - end - model = wothreads(x) - - vi = VarInfo() - model(vi) - lp_wo_threads = getlogjoint(vi) - if Threads.nthreads() == 1 - @test vi_ isa VarInfo - else - @test vi_ isa DynamicPPL.ThreadSafeVarInfo - end - - println("Without `@threads`:") - println(" default:") - @time model(vi) - - @test lp_w_threads ≈ lp_wo_threads - - # Ensure that we use `VarInfo`. 
- sampling_model = contextualize(model, SamplingContext(model.context)) - DynamicPPL.evaluate_threadunsafe!!(sampling_model, vi) - @test getlogjoint(vi) ≈ lp_w_threads - @test vi_ isa VarInfo - @test vi isa VarInfo - - println(" evaluate_threadunsafe!!:") - @time DynamicPPL.evaluate_threadunsafe!!(sampling_model, vi) - end -end diff --git a/test/varinfo.jl b/test/varinfo.jl index 75d8e062b..745697315 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -1,5 +1,5 @@ function check_varinfo_keys(varinfo, vns) - if varinfo isa DynamicPPL.SimpleOrThreadSafeSimple{<:NamedTuple} + if varinfo isa DynamicPPL.SimpleVarInfo{<:NamedTuple} # NOTE: We can't compare the `keys(varinfo_merged)` directly with `vns`, # since `keys(varinfo_merged)` only contains `VarName` with `identity`. # So we just check that the original keys are present. @@ -600,9 +600,7 @@ end vns = DynamicPPL.TestUtils.varnames(model) # Set up the different instances of `AbstractVarInfo` with the desired values. - varinfos = DynamicPPL.TestUtils.setup_varinfos( - model, example_values, vns; include_threadsafe=true - ) + varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns) @testset "$(short_varinfo_name(vi))" for vi in varinfos # Just making sure. 
DynamicPPL.TestUtils.test_values(vi, example_values, vns) @@ -645,11 +643,9 @@ end @testset "mutating=$mutating" for mutating in [false, true] value_true = DynamicPPL.TestUtils.rand_prior_true(model) varnames = DynamicPPL.TestUtils.varnames(model) - varinfos = DynamicPPL.TestUtils.setup_varinfos( - model, value_true, varnames; include_threadsafe=true - ) + varinfos = DynamicPPL.TestUtils.setup_varinfos(model, value_true, varnames) @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - if varinfo isa DynamicPPL.SimpleOrThreadSafeSimple{<:NamedTuple} + if varinfo isa DynamicPPL.SimpleVarInfo{<:NamedTuple} # NOTE: this is broken since we'll end up trying to set # # varinfo[@varname(x[4:5])] = [x[4],] @@ -722,14 +718,11 @@ end end model = demo(0.0) - varinfos = DynamicPPL.TestUtils.setup_varinfos( - model, (; x=1.0), (@varname(x),); include_threadsafe=true - ) + varinfos = DynamicPPL.TestUtils.setup_varinfos(model, (; x=1.0), (@varname(x),)) @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos # Skip the inconcrete `SimpleVarInfo` types, since checking for type # stability for them doesn't make much sense anyway. - if varinfo isa SimpleVarInfo{<:AbstractDict} || - varinfo isa DynamicPPL.ThreadSafeVarInfo{<:SimpleVarInfo{<:AbstractDict}} + if varinfo isa SimpleVarInfo{<:AbstractDict} continue end @inferred DynamicPPL.unflatten(varinfo, varinfo[:]) @@ -749,13 +742,9 @@ end vns = [@varname(s), @varname(m), @varname(x[1]), @varname(x[2])] # `VarInfo` supports, effectively, arbitrary subsetting. 
- varinfos = DynamicPPL.TestUtils.setup_varinfos( - model, model(), vns; include_threadsafe=true - ) + varinfos = DynamicPPL.TestUtils.setup_varinfos(model, model(), vns) varinfos_standard = filter(Base.Fix2(isa, VarInfo), varinfos) - varinfos_simple = filter( - Base.Fix2(isa, DynamicPPL.SimpleOrThreadSafeSimple), varinfos - ) + varinfos_simple = filter(Base.Fix2(isa, DynamicPPL.SimpleVarInfo), varinfos) # `VarInfo` supports subsetting using, basically, arbitrary varnames. vns_supported_standard = [ @@ -795,8 +784,7 @@ end # `SimpleVarInfo{<:NamedTuple}` only supports subsetting with "simple" varnames, ## i.e. `VarName{sym}()` without any indexing, etc. vns_supported = - if varinfo isa DynamicPPL.SimpleOrThreadSafeSimple && - values_as(varinfo) isa NamedTuple + if varinfo isa DynamicPPL.SimpleVarInfo && values_as(varinfo) isa NamedTuple vns_supported_simple else vns_supported_standard @@ -868,10 +856,7 @@ end @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS vns = DynamicPPL.TestUtils.varnames(model) varinfos = DynamicPPL.TestUtils.setup_varinfos( - model, - DynamicPPL.TestUtils.rand_prior_true(model), - vns; - include_threadsafe=true, + model, DynamicPPL.TestUtils.rand_prior_true(model), vns ) @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos @testset "with itself" begin @@ -965,13 +950,9 @@ end @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS vns = DynamicPPL.TestUtils.varnames(model) nt = DynamicPPL.TestUtils.rand_prior_true(model) - varinfos = DynamicPPL.TestUtils.setup_varinfos( - model, nt, vns; include_threadsafe=true - ) + varinfos = DynamicPPL.TestUtils.setup_varinfos(model, nt, vns) # Only keep `VarInfo` types. 
- varinfos = filter( - Base.Fix2(isa, DynamicPPL.VarInfoOrThreadSafeVarInfo), varinfos - ) + varinfos = filter(Base.Fix2(isa, DynamicPPL.VarInfo), varinfos) @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos x = values_as(varinfo, Vector) diff --git a/test/varnamedvector.jl b/test/varnamedvector.jl index af24be86f..de7a7c186 100644 --- a/test/varnamedvector.jl +++ b/test/varnamedvector.jl @@ -586,9 +586,7 @@ end value_true = DynamicPPL.TestUtils.rand_prior_true(model) vns = DynamicPPL.TestUtils.varnames(model) varnames = DynamicPPL.TestUtils.varnames(model) - varinfos = DynamicPPL.TestUtils.setup_varinfos( - model, value_true, varnames; include_threadsafe=false - ) + varinfos = DynamicPPL.TestUtils.setup_varinfos(model, value_true, varnames) # Filter out those which are not based on `VarNamedVector`. varinfos = filter(DynamicPPL.has_varnamedvector, varinfos) # Get the true log joint.