Skip to content

Commit 07c0cee

Browse files
authored
Merge pull request #97 from TuringLang/fixes_threaded
2 parents 0758d01 + 6d8caca commit 07c0cee

File tree

9 files changed

+189
-46
lines changed

9 files changed

+189
-46
lines changed

Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
33
authors = ["mohamed82008 <[email protected]>"]
4-
version = "0.7.0"
4+
version = "0.7.1"
55

66
[deps]
77
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -45,6 +45,7 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
4545
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4646
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
4747
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
48+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
4849

4950
[targets]
50-
test = ["AdvancedHMC", "AdvancedMH", "DistributionsAD", "DocStringExtensions", "EllipticalSliceSampling", "ForwardDiff", "Libtask", "LinearAlgebra", "LogDensityProblems", "Logging", "MCMCChains", "Markdown", "PDMats", "ProgressLogging", "Random", "Reexport", "Requires", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test", "Tracker", "UUIDs"]
51+
test = ["AdvancedHMC", "AdvancedMH", "DistributionsAD", "DocStringExtensions", "EllipticalSliceSampling", "ForwardDiff", "Libtask", "LinearAlgebra", "LogDensityProblems", "Logging", "MCMCChains", "Markdown", "PDMats", "ProgressLogging", "Random", "Reexport", "Requires", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test", "Tracker", "UUIDs", "Zygote"]

src/context_implementations.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,8 @@ function get_and_set_val!(
302302
else
303303
r = init(dist, spl, n)
304304
for i in 1:n
305-
push!(vi, vns[i], r[:,i], dist, spl)
305+
vn = vns[i]
306+
push!(vi, vn, r[:,i], dist, spl)
306307
settrans!(vi, false, vn)
307308
end
308309
end

src/model.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ function evaluate_multithreaded(model, varinfo, sampler, context)
154154
end
155155
wrapper = ThreadSafeVarInfo(varinfo)
156156
result = model.f(model, wrapper, sampler, context)
157-
acclogp!(varinfo, sum(wrapper.logps))
157+
setlogp!(varinfo, getlogp(wrapper))
158158
return result
159159
end
160160

src/threadsafe.jl

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,29 +9,33 @@ struct ThreadSafeVarInfo{V<:AbstractVarInfo,L} <: AbstractVarInfo
99
logps::L
1010
end
1111
function ThreadSafeVarInfo(vi::AbstractVarInfo)
12-
return ThreadSafeVarInfo(vi, [zero(getlogp(vi)) for _ in 1:Threads.nthreads()])
12+
return ThreadSafeVarInfo(vi, [Ref(zero(getlogp(vi))) for _ in 1:Threads.nthreads()])
1313
end
1414
ThreadSafeVarInfo(vi::ThreadSafeVarInfo) = vi
1515

1616
# Instead of updating the log probability of the underlying variables we
1717
# just update the array of log probabilities.
1818
function acclogp!(vi::ThreadSafeVarInfo, logp)
19-
vi.logps[Threads.threadid()] += logp
19+
vi.logps[Threads.threadid()][] += logp
2020
return vi
2121
end
2222

2323
# The current log probability of the variables has to be computed from
2424
# both the wrapped variables and the thread-specific log probabilities.
25-
getlogp(vi::ThreadSafeVarInfo) = getlogp(vi.varinfo) + sum(vi.logps)
25+
getlogp(vi::ThreadSafeVarInfo) = getlogp(vi.varinfo) + sum(getindex, vi.logps)
2626

2727
# TODO: Make remaining methods thread-safe.
2828

2929
function resetlogp!(vi::ThreadSafeVarInfo)
30-
fill!(vi.logps, zero(getlogp(vi)))
30+
for x in vi.logps
31+
x[] = zero(x[])
32+
end
3133
return resetlogp!(vi.varinfo)
3234
end
3335
function setlogp!(vi::ThreadSafeVarInfo, logp)
34-
fill!(vi.logps, zero(logp))
36+
for x in vi.logps
37+
x[] = zero(x[])
38+
end
3539
return setlogp!(vi.varinfo, logp)
3640
end
3741

@@ -45,6 +49,7 @@ syms(vi::ThreadSafeVarInfo) = syms(vi.varinfo)
4549
function setgid!(vi::ThreadSafeVarInfo, gid::Selector, vn::VarName)
4650
setgid!(vi.varinfo, gid, vn)
4751
end
52+
setorder!(vi::ThreadSafeVarInfo, vn::VarName, index::Int) = setorder!(vi.varinfo, vn, index)
4853
setval!(vi::ThreadSafeVarInfo, val, vn::VarName) = setval!(vi.varinfo, val, vn)
4954

5055
keys(vi::ThreadSafeVarInfo) = keys(vi.varinfo)

test/compat/ad.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
@testset "ad.jl" begin
2+
@testset "logp" begin
3+
# Hand-written log probabilities for vector `x = [s, m]`.
4+
function logp_gdemo_default(x)
5+
s = x[1]
6+
m = x[2]
7+
dist = Normal(m, sqrt(s))
8+
9+
return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m) +
10+
logpdf(dist, 1.5) + logpdf(dist, 2.0)
11+
end
12+
13+
test_model_ad(gdemo_default, logp_gdemo_default)
14+
15+
@model function wishart_ad()
16+
v ~ Wishart(7, [1 0.5; 0.5 1])
17+
end
18+
19+
# Hand-written log probabilities for `x = [v]`.
20+
function logp_wishart_ad(x)
21+
dist = Wishart(7, [1 0.5; 0.5 1])
22+
return logpdf(dist, reshape(x, 2, 2))
23+
end
24+
25+
test_model_ad(wishart_ad(), logp_wishart_ad)
26+
end
27+
end

test/compiler.jl

Lines changed: 0 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -580,42 +580,6 @@ end
580580
model = demo()
581581
@test all(iszero(model()) for _ in 1:1000)
582582
end
583-
@testset "threading" begin
584-
@info "Peforming threading tests with $(Threads.nthreads()) threads"
585-
586-
x = rand(10_000)
587-
588-
@model function wthreads(x)
589-
x[1] ~ Normal(0, 1)
590-
Threads.@threads for i in 2:length(x)
591-
x[i] ~ Normal(x[i-1], 1)
592-
end
593-
end
594-
595-
vi = VarInfo()
596-
wthreads(x)(vi)
597-
lp_w_threads = getlogp(vi)
598-
599-
println("With threading:")
600-
@time wthreads(x)(vi)
601-
602-
@model function wothreads(x)
603-
x[1] ~ Normal(0, 1)
604-
for i in 2:length(x)
605-
x[i] ~ Normal(x[i-1], 1)
606-
end
607-
end
608-
609-
vi = VarInfo()
610-
wothreads(x)(vi)
611-
lp_wo_threads = getlogp(vi)
612-
613-
println("Without threading:")
614-
@time wothreads(x)(vi)
615-
616-
@test lp_w_threads lp_wo_threads
617-
end
618-
619583
@testset "docstring" begin
620584
"This is a test"
621585
@model function demo(x)

test/runtests.jl

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,20 @@
1-
using Test, DynamicPPL
1+
using DynamicPPL
2+
using Distributions
3+
using ForwardDiff
4+
using Tracker
5+
using Zygote
6+
7+
using Random
8+
using Test
9+
210
dir = splitdir(splitdir(pathof(DynamicPPL))[1])[1]
311
include(dir*"/test/Turing/Turing.jl")
412
using .Turing
513

614
turnprogress(false)
715

16+
include("test_util.jl")
17+
818
@testset "DynamicPPL.jl" begin
919
include("utils.jl")
1020
include("compiler.jl")
@@ -13,4 +23,10 @@ turnprogress(false)
1323
include("prob_macro.jl")
1424
include("independence.jl")
1525
include("distribution_wrappers.jl")
26+
27+
include("threadsafe.jl")
28+
29+
@testset "compat" begin
30+
include(joinpath("compat", "ad.jl"))
31+
end
1632
end

test/test_util.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
function test_model_ad(model, logp_manual)
2+
vi = VarInfo(model)
3+
model(vi, SampleFromPrior())
4+
x = DynamicPPL.getall(vi)
5+
6+
# Log probabilities using the model.
7+
function logp_model(x)
8+
new_vi = VarInfo(vi, SampleFromPrior(), x)
9+
model(new_vi, SampleFromPrior())
10+
return getlogp(new_vi)
11+
end
12+
13+
# Check that both functions return the same values.
14+
lp = logp_manual(x)
15+
@test logp_model(x) lp
16+
17+
# Gradients based on the manual implementation.
18+
grad = ForwardDiff.gradient(logp_manual, x)
19+
20+
y, back = Tracker.forward(logp_manual, x)
21+
@test Tracker.data(y) lp
22+
@test Tracker.data(back(1)[1]) grad
23+
24+
y, back = Zygote.pullback(logp_manual, x)
25+
@test y lp
26+
@test back(1)[1] grad
27+
28+
# Gradients based on the model.
29+
@test ForwardDiff.gradient(logp_model, x) grad
30+
31+
y, back = Tracker.forward(logp_model, x)
32+
@test Tracker.data(y) lp
33+
@test Tracker.data(back(1)[1]) grad
34+
35+
y, back = Zygote.pullback(logp_model, x)
36+
@test y lp
37+
@test back(1)[1] grad
38+
end

test/threadsafe.jl

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
@testset "threadsafe.jl" begin
2+
@testset "constructor" begin
3+
vi = VarInfo(gdemo_default)
4+
threadsafe_vi = @inferred DynamicPPL.ThreadSafeVarInfo(vi)
5+
6+
@test threadsafe_vi.varinfo === vi
7+
@test threadsafe_vi.logps isa Vector{typeof(Ref(getlogp(vi)))}
8+
@test length(threadsafe_vi.logps) == Threads.nthreads()
9+
@test all(iszero(x[]) for x in threadsafe_vi.logps)
10+
end
11+
12+
# TODO: Add more tests of the public API
13+
@testset "API" begin
14+
vi = VarInfo(gdemo_default)
15+
threadsafe_vi = DynamicPPL.ThreadSafeVarInfo(vi)
16+
17+
lp = getlogp(vi)
18+
@test getlogp(threadsafe_vi) == lp
19+
20+
acclogp!(threadsafe_vi, 42)
21+
@test threadsafe_vi.logps[Threads.threadid()][] == 42
22+
@test getlogp(vi) == lp
23+
@test getlogp(threadsafe_vi) == lp + 42
24+
25+
resetlogp!(threadsafe_vi)
26+
@test iszero(getlogp(vi))
27+
@test iszero(getlogp(threadsafe_vi))
28+
@test all(iszero(x[]) for x in threadsafe_vi.logps)
29+
30+
setlogp!(threadsafe_vi, 42)
31+
@test getlogp(vi) == 42
32+
@test getlogp(threadsafe_vi) == 42
33+
@test all(iszero(x[]) for x in threadsafe_vi.logps)
34+
end
35+
36+
@testset "model" begin
37+
println("Peforming threading tests with $(Threads.nthreads()) threads")
38+
39+
x = rand(10_000)
40+
41+
@model function wthreads(x)
42+
x[1] ~ Normal(0, 1)
43+
Threads.@threads for i in 2:length(x)
44+
x[i] ~ Normal(x[i-1], 1)
45+
end
46+
end
47+
48+
vi = VarInfo()
49+
wthreads(x)(vi)
50+
lp_w_threads = getlogp(vi)
51+
52+
println("With `@threads`:")
53+
println(" default:")
54+
@time wthreads(x)(vi)
55+
56+
# Ensure that we use `ThreadSafeVarInfo`.
57+
@test getlogp(vi) lp_w_threads
58+
DynamicPPL.evaluate_multithreaded(wthreads(x), vi, SampleFromPrior(),
59+
DefaultContext())
60+
61+
println(" evaluate_multithreaded:")
62+
@time DynamicPPL.evaluate_multithreaded(wthreads(x), vi, SampleFromPrior(),
63+
DefaultContext())
64+
65+
@model function wothreads(x)
66+
x[1] ~ Normal(0, 1)
67+
for i in 2:length(x)
68+
x[i] ~ Normal(x[i-1], 1)
69+
end
70+
end
71+
72+
vi = VarInfo()
73+
wothreads(x)(vi)
74+
lp_wo_threads = getlogp(vi)
75+
76+
println("Without `@threads`:")
77+
println(" default:")
78+
@time wothreads(x)(vi)
79+
80+
@test lp_w_threads lp_wo_threads
81+
82+
# Ensure that we use `VarInfo`.
83+
DynamicPPL.evaluate_singlethreaded(wothreads(x), vi, SampleFromPrior(),
84+
DefaultContext())
85+
@test getlogp(vi) lp_w_threads
86+
87+
println(" evaluate_singlethreaded:")
88+
@time DynamicPPL.evaluate_singlethreaded(wothreads(x), vi, SampleFromPrior(),
89+
DefaultContext())
90+
end
91+
end

0 commit comments

Comments
 (0)