Skip to content

Commit 16fd886

Browse files
committed
Add more tests of AD and threading, independently of Turing
1 parent bb6501f commit 16fd886

File tree

5 files changed

+171
-122
lines changed

5 files changed

+171
-122
lines changed

test/compat/ad.jl

Lines changed: 25 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,67 +1,27 @@
1-
using DynamicPPL
2-
using Distributions
3-
4-
using ForwardDiff
5-
using Zygote
6-
using Tracker
7-
8-
@testset "logp" begin
9-
@model function admodel()
10-
s ~ InverseGamma(2, 3)
11-
m ~ Normal(0, sqrt(s))
12-
1.5 ~ Normal(m, sqrt(s))
13-
2.0 ~ Normal(m, sqrt(s))
14-
return s, m
15-
end
16-
17-
model = admodel()
18-
vi = VarInfo(model)
19-
model(vi, SampleFromPrior())
20-
x = [vi[@varname(s)], vi[@varname(m)]]
21-
22-
dist_s = InverseGamma(2,3)
23-
24-
# Hand-written log probabilities for vector `x = [s, m]`.
25-
function logp_manual(x)
26-
s = x[1]
27-
m = x[2]
28-
dist = Normal(m, sqrt(s))
29-
30-
return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m) +
31-
logpdf(dist, 1.5) + logpdf(dist, 2.0)
32-
end
33-
34-
# Log probabilities for vector `x = [s, m]` using the model.
35-
function logp_model(x)
36-
new_vi = VarInfo(vi, SampleFromPrior(), x)
37-
model(new_vi, SampleFromPrior())
38-
return getlogp(new_vi)
1+
@testset "ad.jl" begin
2+
@testset "logp" begin
3+
# Hand-written log probabilities for vector `x = [s, m]`.
4+
function logp_gdemo_default(x)
5+
s = x[1]
6+
m = x[2]
7+
dist = Normal(m, sqrt(s))
8+
9+
return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m) +
10+
logpdf(dist, 1.5) + logpdf(dist, 2.0)
11+
end
12+
13+
test_model_ad(gdemo_default, logp_gdemo_default)
14+
15+
@model function wishart_ad()
16+
v ~ Wishart(7, [1 0.5; 0.5 1])
17+
end
18+
19+
# Hand-written log probabilities for `x = [v]`.
20+
function logp_wishart_ad(x)
21+
dist = Wishart(7, [1 0.5; 0.5 1])
22+
return logpdf(dist, reshape(x, 2, 2))
23+
end
24+
25+
test_model_ad(wishart_ad(), logp_wishart_ad)
3926
end
40-
41-
# Check that both functions return the same values.
42-
lp = logp_manual(x)
43-
@test logp_model(x) lp
44-
45-
# Gradients based on the manual implementation.
46-
grad = ForwardDiff.gradient(logp_manual, x)
47-
48-
y, back = Tracker.forward(logp_manual, x)
49-
@test Tracker.data(y) lp
50-
@test Tracker.data(back(1)[1]) grad
51-
52-
y, back = Zygote.pullback(logp_manual, x)
53-
@test y lp
54-
@test back(1)[1] grad
55-
56-
# Gradients based on the model.
57-
@test ForwardDiff.gradient(logp_model, x) grad
58-
59-
y, back = Tracker.forward(logp_model, x)
60-
@test Tracker.data(y) lp
61-
@test Tracker.data(back(1)[1]) grad
62-
63-
y, back = Zygote.pullback(logp_model, x)
64-
@test y lp
65-
@test back(1)[1] grad
6627
end
67-

test/compiler.jl

Lines changed: 0 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -580,59 +580,4 @@ end
580580
model = demo()
581581
@test all(iszero(model()) for _ in 1:1000)
582582
end
583-
@testset "threading" begin
584-
println("Peforming threading tests with $(Threads.nthreads()) threads")
585-
586-
x = rand(10_000)
587-
588-
@model function wthreads(x)
589-
x[1] ~ Normal(0, 1)
590-
Threads.@threads for i in 2:length(x)
591-
x[i] ~ Normal(x[i-1], 1)
592-
end
593-
end
594-
595-
vi = VarInfo()
596-
wthreads(x)(vi)
597-
lp_w_threads = getlogp(vi)
598-
599-
println("With `@threads`:")
600-
println(" default:")
601-
@time wthreads(x)(vi)
602-
603-
# Ensure that we use `ThreadSafeVarInfo`.
604-
@test getlogp(vi) lp_w_threads
605-
DynamicPPL.evaluate_multithreaded(wthreads(x), vi, SampleFromPrior(),
606-
DefaultContext())
607-
608-
println(" evaluate_multithreaded:")
609-
@time DynamicPPL.evaluate_multithreaded(wthreads(x), vi, SampleFromPrior(),
610-
DefaultContext())
611-
612-
@model function wothreads(x)
613-
x[1] ~ Normal(0, 1)
614-
for i in 2:length(x)
615-
x[i] ~ Normal(x[i-1], 1)
616-
end
617-
end
618-
619-
vi = VarInfo()
620-
wothreads(x)(vi)
621-
lp_wo_threads = getlogp(vi)
622-
623-
println("Without `@threads`:")
624-
println(" default:")
625-
@time wothreads(x)(vi)
626-
627-
@test lp_w_threads lp_wo_threads
628-
629-
# Ensure that we use `VarInfo`.
630-
DynamicPPL.evaluate_singlethreaded(wothreads(x), vi, SampleFromPrior(),
631-
DefaultContext())
632-
@test getlogp(vi) lp_w_threads
633-
634-
println(" evaluate_singlethreaded:")
635-
@time DynamicPPL.evaluate_singlethreaded(wothreads(x), vi, SampleFromPrior(),
636-
DefaultContext())
637-
end
638583
end

test/runtests.jl

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,32 @@
1-
using Test, DynamicPPL
1+
using DynamicPPL
2+
using Distributions
3+
using ForwardDiff
4+
using Tracker
5+
using Zygote
6+
7+
using Random
8+
using Test
9+
210
dir = splitdir(splitdir(pathof(DynamicPPL))[1])[1]
311
include(dir*"/test/Turing/Turing.jl")
412
using .Turing
513

614
turnprogress(false)
715

16+
include("test_util.jl")
17+
818
@testset "DynamicPPL.jl" begin
919
include("utils.jl")
1020
include("compiler.jl")
11-
include("compat/ad.jl")
1221
include("varinfo.jl")
1322
include("sampler.jl")
1423
include("prob_macro.jl")
1524
include("independence.jl")
1625
include("distribution_wrappers.jl")
26+
27+
include("threadsafe.jl")
28+
29+
@testset "compat" begin
30+
include(joinpath("compat", "ad.jl"))
31+
end
1732
end

test/test_util.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
function test_model_ad(model, logp_manual)
2+
vi = VarInfo(model)
3+
model(vi, SampleFromPrior())
4+
x = DynamicPPL.getall(vi)
5+
6+
# Log probabilities using the model.
7+
function logp_model(x)
8+
new_vi = VarInfo(vi, SampleFromPrior(), x)
9+
model(new_vi, SampleFromPrior())
10+
return getlogp(new_vi)
11+
end
12+
13+
# Check that both functions return the same values.
14+
lp = logp_manual(x)
15+
@test logp_model(x) lp
16+
17+
# Gradients based on the manual implementation.
18+
grad = ForwardDiff.gradient(logp_manual, x)
19+
20+
y, back = Tracker.forward(logp_manual, x)
21+
@test Tracker.data(y) lp
22+
@test Tracker.data(back(1)[1]) grad
23+
24+
y, back = Zygote.pullback(logp_manual, x)
25+
@test y lp
26+
@test back(1)[1] grad
27+
28+
# Gradients based on the model.
29+
@test ForwardDiff.gradient(logp_model, x) grad
30+
31+
y, back = Tracker.forward(logp_model, x)
32+
@test Tracker.data(y) lp
33+
@test Tracker.data(back(1)[1]) grad
34+
35+
y, back = Zygote.pullback(logp_model, x)
36+
@test y lp
37+
@test back(1)[1] grad
38+
end

test/threadsafe.jl

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
@testset "threadsafe.jl" begin
2+
@testset "constructor" begin
3+
vi = VarInfo(gdemo_default)
4+
threadsafe_vi = @inferred DynamicPPL.ThreadSafeVarInfo(vi)
5+
6+
@test threadsafe_vi.varinfo === vi
7+
@test threadsafe_vi.logps isa Vector{typeof(Ref(getlogp(vi)))}
8+
@test length(threadsafe_vi.logps) == Threads.nthreads()
9+
@test all(iszero(x[]) for x in threadsafe_vi.logps)
10+
end
11+
12+
# TODO: Add more tests of the public API
13+
@testset "API" begin
14+
vi = VarInfo(gdemo_default)
15+
threadsafe_vi = DynamicPPL.ThreadSafeVarInfo(vi)
16+
17+
lp = getlogp(vi)
18+
@test getlogp(threadsafe_vi) == lp
19+
20+
acclogp!(threadsafe_vi, 42)
21+
@test threadsafe_vi.logps[Threads.threadid()][] == 42
22+
@test getlogp(vi) == lp
23+
@test getlogp(threadsafe_vi) == lp + 42
24+
25+
resetlogp!(threadsafe_vi)
26+
@test iszero(getlogp(vi))
27+
@test iszero(getlogp(threadsafe_vi))
28+
@test all(iszero(x[]) for x in threadsafe_vi.logps)
29+
30+
setlogp!(threadsafe_vi, 42)
31+
@test getlogp(vi) == 42
32+
@test getlogp(threadsafe_vi) == 42
33+
@test all(iszero(x[]) for x in threadsafe_vi.logps)
34+
end
35+
36+
@testset "model" begin
37+
println("Peforming threading tests with $(Threads.nthreads()) threads")
38+
39+
x = rand(10_000)
40+
41+
@model function wthreads(x)
42+
x[1] ~ Normal(0, 1)
43+
Threads.@threads for i in 2:length(x)
44+
x[i] ~ Normal(x[i-1], 1)
45+
end
46+
end
47+
48+
vi = VarInfo()
49+
wthreads(x)(vi)
50+
lp_w_threads = getlogp(vi)
51+
52+
println("With `@threads`:")
53+
println(" default:")
54+
@time wthreads(x)(vi)
55+
56+
# Ensure that we use `ThreadSafeVarInfo`.
57+
@test getlogp(vi) lp_w_threads
58+
DynamicPPL.evaluate_multithreaded(wthreads(x), vi, SampleFromPrior(),
59+
DefaultContext())
60+
61+
println(" evaluate_multithreaded:")
62+
@time DynamicPPL.evaluate_multithreaded(wthreads(x), vi, SampleFromPrior(),
63+
DefaultContext())
64+
65+
@model function wothreads(x)
66+
x[1] ~ Normal(0, 1)
67+
for i in 2:length(x)
68+
x[i] ~ Normal(x[i-1], 1)
69+
end
70+
end
71+
72+
vi = VarInfo()
73+
wothreads(x)(vi)
74+
lp_wo_threads = getlogp(vi)
75+
76+
println("Without `@threads`:")
77+
println(" default:")
78+
@time wothreads(x)(vi)
79+
80+
@test lp_w_threads lp_wo_threads
81+
82+
# Ensure that we use `VarInfo`.
83+
DynamicPPL.evaluate_singlethreaded(wothreads(x), vi, SampleFromPrior(),
84+
DefaultContext())
85+
@test getlogp(vi) lp_w_threads
86+
87+
println(" evaluate_singlethreaded:")
88+
@time DynamicPPL.evaluate_singlethreaded(wothreads(x), vi, SampleFromPrior(),
89+
DefaultContext())
90+
end
91+
end

0 commit comments

Comments
 (0)