Skip to content

Commit 924f635

Browse files
committed
fix Zygote support
1 parent b81363a commit 924f635

File tree

2 files changed

+34
-13
lines changed

2 files changed

+34
-13
lines changed

src/compat/ad.jl

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,12 @@ ZygoteRules.@adjoint function push!(
77
dist::Distribution,
88
gidset::Set{Selector}
99
)
10-
return push!(vi, vn, r, dist, gidset), _ -> nothing
10+
return push!(vi, vn, r, dist, gidset), _ -> ntuple(_ -> nothing, 5)
1111
end
1212

13-
# Multithreaded evaluation is not compatible with Zygote.
14-
ZygoteRules.@adjoint function (model::Model)(
15-
vi::AbstractVarInfo,
16-
spl::AbstractSampler,
17-
ctx::AbstractContext
18-
)
19-
function evaluate(vi, spl, ctx)
20-
return evaluate_singlethreaded(model, vi, spl, ctx)
21-
end
22-
return ZygoteRules.pullback(evaluate, vi, spl, ctx)
13+
ZygoteRules.@adjoint function Threads.nthreads()
14+
Threads.nthreads(), _ -> (nothing,)
15+
end
16+
ZygoteRules.@adjoint function Threads.threadid()
17+
Threads.threadid(), _ -> (nothing,)
2318
end
24-

src/threadsafe.jl

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,30 @@
1+
#################
2+
# VectorOfLogps #
3+
#################
4+
5+
struct VectorOfLogps{T1, T2 <: Vector{Base.RefValue{T1}}}
6+
v::T2
7+
end
8+
VectorOfLogps(::Type{T}, n::Int) where {T} = VectorOfLogps(zero(T), n)
9+
function VectorOfLogps(val::T, n::Int) where {T}
10+
v = [val for i in 1:Threads.nthreads()]
11+
return VectorOfLogps(v)
12+
end
13+
VectorOfLogps(v::Vector) = VectorOfLogps(Ref.(v))
14+
Base.getindex(v::VectorOfLogps, i::Integer) = v.v[i][]
15+
function Base.setindex!(v::VectorOfLogps, val, i::Integer)
16+
v.v[i][] = val
17+
return v
18+
end
19+
Base.sum(v::VectorOfLogps) = sum(v -> v[], v.v)
20+
function Base.fill!(v::VectorOfLogps, val)
21+
for i in 1:length(v.v)
22+
v.v[i][] = val
23+
end
24+
return v
25+
end
26+
27+
128
"""
229
ThreadSafeVarInfo
330
@@ -9,7 +36,7 @@ struct ThreadSafeVarInfo{V<:AbstractVarInfo,L} <: AbstractVarInfo
936
logps::L
1037
end
1138
function ThreadSafeVarInfo(vi::AbstractVarInfo)
12-
return ThreadSafeVarInfo(vi, [zero(getlogp(vi)) for _ in 1:Threads.nthreads()])
39+
return ThreadSafeVarInfo(vi, VectorOfLogps(zero(getlogp(vi)), Threads.nthreads()))
1340
end
1441
ThreadSafeVarInfo(vi::ThreadSafeVarInfo) = vi
1542

0 commit comments

Comments
 (0)