Skip to content

Commit bb6501f

Browse files
committed
Merge remote-tracking branch 'origin/mt/fixes_threaded' into fixes_threaded
2 parents b81363a + 008392a commit bb6501f

File tree

4 files changed

+11
-20
lines changed

4 files changed

+11
-20
lines changed

src/compat/ad.jl

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,3 @@ ZygoteRules.@adjoint function push!(
99
)
1010
return push!(vi, vn, r, dist, gidset), _ -> nothing
1111
end
12-
13-
# Multithreaded evaluation is not compatible with Zygote.
14-
ZygoteRules.@adjoint function (model::Model)(
15-
vi::AbstractVarInfo,
16-
spl::AbstractSampler,
17-
ctx::AbstractContext
18-
)
19-
function evaluate(vi, spl, ctx)
20-
return evaluate_singlethreaded(model, vi, spl, ctx)
21-
end
22-
return ZygoteRules.pullback(evaluate, vi, spl, ctx)
23-
end
24-

src/model.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ function evaluate_multithreaded(model, varinfo, sampler, context)
154154
end
155155
wrapper = ThreadSafeVarInfo(varinfo)
156156
result = model.f(model, wrapper, sampler, context)
157-
acclogp!(varinfo, sum(wrapper.logps))
157+
setlogp!(varinfo, getlogp(wrapper))
158158
return result
159159
end
160160

src/threadsafe.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,29 +9,33 @@ struct ThreadSafeVarInfo{V<:AbstractVarInfo,L} <: AbstractVarInfo
99
logps::L
1010
end
1111
function ThreadSafeVarInfo(vi::AbstractVarInfo)
12-
return ThreadSafeVarInfo(vi, [zero(getlogp(vi)) for _ in 1:Threads.nthreads()])
12+
return ThreadSafeVarInfo(vi, [Ref(zero(getlogp(vi))) for _ in 1:Threads.nthreads()])
1313
end
1414
ThreadSafeVarInfo(vi::ThreadSafeVarInfo) = vi
1515

1616
# Instead of updating the log probability of the underlying variables we
1717
# just update the array of log probabilities.
1818
function acclogp!(vi::ThreadSafeVarInfo, logp)
19-
vi.logps[Threads.threadid()] += logp
19+
vi.logps[Threads.threadid()][] += logp
2020
return vi
2121
end
2222

2323
# The current log probability of the variables has to be computed from
2424
# both the wrapped variables and the thread-specific log probabilities.
25-
getlogp(vi::ThreadSafeVarInfo) = getlogp(vi.varinfo) + sum(vi.logps)
25+
getlogp(vi::ThreadSafeVarInfo) = getlogp(vi.varinfo) + sum(getindex, vi.logps)
2626

2727
# TODO: Make remaining methods thread-safe.
2828

2929
function resetlogp!(vi::ThreadSafeVarInfo)
30-
fill!(vi.logps, zero(getlogp(vi)))
30+
for x in vi.logps
31+
x[] = zero(x[])
32+
end
3133
return resetlogp!(vi.varinfo)
3234
end
3335
function setlogp!(vi::ThreadSafeVarInfo, logp)
34-
fill!(vi.logps, zero(logp))
36+
for x in vi.logps
37+
x[] = zero(x[])
38+
end
3539
return setlogp!(vi.varinfo, logp)
3640
end
3741

test/compat/ad.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,6 @@ using Tracker
6262

6363
y, back = Zygote.pullback(logp_model, x)
6464
@test y lp
65-
@test back(1) grad
65+
@test back(1)[1] grad
6666
end
6767

0 commit comments

Comments
 (0)