diff --git a/HISTORY.md b/HISTORY.md
index e4fca51e7..a0f91a494 100644
--- a/HISTORY.md
+++ b/HISTORY.md
@@ -1,5 +1,10 @@
 # DynamicPPL Changelog
 
+## 0.36.11
+
+Make `ThreadSafeVarInfo` hold a total of `Threads.nthreads() * 2` logp values, instead of just `Threads.nthreads()`.
+This fix helps to paper over the cracks in using `threadid()` to index into the `ThreadSafeVarInfo` object.
+
 ## 0.36.10
 
 Added compatibility with ForwardDiff 1.0.
diff --git a/Project.toml b/Project.toml
index fe4d69ef5..362035eb7 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
 name = "DynamicPPL"
 uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
-version = "0.36.10"
+version = "0.36.11"
 
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl
index abf14b8fc..3ae425896 100644
--- a/src/simple_varinfo.jl
+++ b/src/simple_varinfo.jl
@@ -648,10 +648,12 @@ end
 # Threadsafe stuff.
 # For `SimpleVarInfo` we don't really need `Ref` so let's not use it.
 function ThreadSafeVarInfo(vi::SimpleVarInfo)
-    return ThreadSafeVarInfo(vi, zeros(typeof(getlogp(vi)), Threads.nthreads()))
+    return ThreadSafeVarInfo(vi, zeros(typeof(getlogp(vi)), Threads.nthreads() * 2))
 end
 function ThreadSafeVarInfo(vi::SimpleVarInfo{<:Any,<:Ref})
-    return ThreadSafeVarInfo(vi, [Ref(zero(getlogp(vi))) for _ in 1:Threads.nthreads()])
+    return ThreadSafeVarInfo(
+        vi, [Ref(zero(getlogp(vi))) for _ in 1:(Threads.nthreads() * 2)]
+    )
 end
 
 has_varnamedvector(vi::SimpleVarInfo) = vi.values isa VarNamedVector
diff --git a/src/threadsafe.jl b/src/threadsafe.jl
index 2dc2645de..bd1876a19 100644
--- a/src/threadsafe.jl
+++ b/src/threadsafe.jl
@@ -9,7 +9,18 @@ struct ThreadSafeVarInfo{V<:AbstractVarInfo,L} <: AbstractVarInfo
     logps::L
 end
 function ThreadSafeVarInfo(vi::AbstractVarInfo)
-    return ThreadSafeVarInfo(vi, [Ref(zero(getlogp(vi))) for _ in 1:Threads.nthreads()])
+    # In ThreadSafeVarInfo we use threadid() to index into the array of logp
+    # fields. This is not good practice --- see
+    # https://github.com/TuringLang/DynamicPPL.jl/issues/924 for a full
+    # explanation --- but it has worked okay so far.
+    # The use of nthreads()*2 here ensures that threadid() doesn't exceed
+    # the length of the logps array. Ideally, we would use maxthreadid(),
+    # but Mooncake can't differentiate through that. Empirically, nthreads()*2
+    # seems to provide an upper bound to maxthreadid(), so we use that here.
+    # See https://github.com/TuringLang/DynamicPPL.jl/pull/936
+    return ThreadSafeVarInfo(
+        vi, [Ref(zero(getlogp(vi))) for _ in 1:(Threads.nthreads() * 2)]
+    )
 end
 ThreadSafeVarInfo(vi::ThreadSafeVarInfo) = vi
 
diff --git a/test/threadsafe.jl b/test/threadsafe.jl
index 72c439db8..ededf78b0 100644
--- a/test/threadsafe.jl
+++ b/test/threadsafe.jl
@@ -5,7 +5,7 @@
 
         @test threadsafe_vi.varinfo === vi
         @test threadsafe_vi.logps isa Vector{typeof(Ref(getlogp(vi)))}
-        @test length(threadsafe_vi.logps) == Threads.nthreads()
+        @test length(threadsafe_vi.logps) == Threads.nthreads() * 2
         @test all(iszero(x[]) for x in threadsafe_vi.logps)
     end
 