Skip to content

Commit dca3aa0

Browse files
authored
Add tests
1 parent 05f91cf commit dca3aa0

File tree

1 file changed

+74
-46
lines changed

1 file changed

+74
-46
lines changed

test/sample.jl

Lines changed: 74 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -103,61 +103,59 @@
103103
end
104104
end
105105

106-
if VERSION v"1.3"
107-
@testset "Multithreaded sampling" begin
108-
if Threads.nthreads() == 1
109-
warnregex = r"^Only a single thread available"
110-
@test_logs (:warn, warnregex) sample(MyModel(), MySampler(), MCMCThreads(),
111-
10, 10)
112-
end
106+
@testset "Multithreaded sampling" begin
107+
if Threads.nthreads() == 1
108+
warnregex = r"^Only a single thread available"
109+
@test_logs (:warn, warnregex) sample(MyModel(), MySampler(), MCMCThreads(),
110+
10, 10)
111+
end
113112

114-
# No dedicated chains type
115-
N = 10_000
116-
chains = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000)
117-
@test chains isa Vector{<:Vector{<:MySample}}
118-
@test length(chains) == 1000
119-
@test all(length(x) == N for x in chains)
113+
# No dedicated chains type
114+
N = 10_000
115+
chains = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000)
116+
@test chains isa Vector{<:Vector{<:MySample}}
117+
@test length(chains) == 1000
118+
@test all(length(x) == N for x in chains)
120119

121-
Random.seed!(1234)
122-
chains = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000;
123-
chain_type = MyChain)
120+
Random.seed!(1234)
121+
chains = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000;
122+
chain_type = MyChain)
124123

125-
# test output type and size
126-
@test chains isa Vector{<:MyChain}
127-
@test length(chains) == 1000
128-
@test all(x -> length(x.as) == length(x.bs) == N, chains)
124+
# test output type and size
125+
@test chains isa Vector{<:MyChain}
126+
@test length(chains) == 1000
127+
@test all(x -> length(x.as) == length(x.bs) == N, chains)
129128

130-
# test some statistical properties
131-
@test all(x -> isapprox(mean(@view x.as[2:end]), 0.5; atol=5e-2), chains)
132-
@test all(x -> isapprox(var(@view x.as[2:end]), 1 / 12; atol=5e-3), chains)
133-
@test all(x -> isapprox(mean(@view x.bs[2:end]), 0; atol=5e-2), chains)
134-
@test all(x -> isapprox(var(@view x.bs[2:end]), 1; atol=5e-2), chains)
129+
# test some statistical properties
130+
@test all(x -> isapprox(mean(@view x.as[2:end]), 0.5; atol=5e-2), chains)
131+
@test all(x -> isapprox(var(@view x.as[2:end]), 1 / 12; atol=5e-3), chains)
132+
@test all(x -> isapprox(mean(@view x.bs[2:end]), 0; atol=5e-2), chains)
133+
@test all(x -> isapprox(var(@view x.bs[2:end]), 1; atol=5e-2), chains)
135134

136-
# test reproducibility
137-
Random.seed!(1234)
138-
chains2 = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000;
139-
chain_type = MyChain)
135+
# test reproducibility
136+
Random.seed!(1234)
137+
chains2 = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000;
138+
chain_type = MyChain)
140139

141-
@test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains, chains2), i in 1:N)
142-
@test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N)
140+
@test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains, chains2), i in 1:N)
141+
@test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N)
143142

144-
# Unexpected order of arguments.
145-
str = "Number of chains (10) is greater than number of samples per chain (5)"
146-
@test_logs (:warn, str) match_mode=:any sample(MyModel(), MySampler(),
147-
MCMCThreads(), 5, 10;
148-
chain_type = MyChain)
143+
# Unexpected order of arguments.
144+
str = "Number of chains (10) is greater than number of samples per chain (5)"
145+
@test_logs (:warn, str) match_mode=:any sample(MyModel(), MySampler(),
146+
MCMCThreads(), 5, 10;
147+
chain_type = MyChain)
149148

150-
# Suppress output.
151-
logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do
152-
sample(MyModel(), MySampler(), MCMCThreads(), 10_000, 1000;
153-
progress = false, chain_type = MyChain)
154-
end
155-
@test all(l.level > Logging.LogLevel(-1) for l in logs)
149+
# Suppress output.
150+
logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do
151+
sample(MyModel(), MySampler(), MCMCThreads(), 10_000, 1000;
152+
progress = false, chain_type = MyChain)
153+
end
154+
@test all(l.level > Logging.LogLevel(-1) for l in logs)
156155

157-
# Smoke test for nchains < nthreads
158-
if Threads.nthreads() == 2
159-
sample(MyModel(), MySampler(), MCMCThreads(), N, 1)
160-
end
156+
# Smoke test for nchains < nthreads
157+
if Threads.nthreads() == 2
158+
sample(MyModel(), MySampler(), MCMCThreads(), N, 1)
161159
end
162160
end
163161

@@ -271,6 +269,36 @@
271269
@test all(l.level > Logging.LogLevel(-1) for l in logs)
272270
end
273271

272+
@testset "Ensemble sampling: Reproducibility" begin
273+
N = 1_000
274+
nchains = 10
275+
276+
# Serial sampling
277+
Random.seed!(1234)
278+
chains_serial = sample(
279+
MyModel(), MySampler(), MCMCSerial(), N, nchains;
280+
progress=false, chain_type=MyChain
281+
)
282+
283+
# Multi-threaded sampling
284+
Random.seed!(1234)
285+
chains_threads = sample(
286+
MyModel(), MySampler(), MCMCThreads(), N, nchains;
287+
progress=false, chain_type=MyChain
288+
)
289+
@test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains_serial, chains_threads), i in 1:N)
290+
@test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains_serial, chains_threads), i in 1:N)
291+
292+
# Multi-core sampling
293+
Random.seed!(1234)
294+
chains_distributed = sample(
295+
MyModel(), MySampler(), MCMCDistributed(), N, nchains;
296+
progress=false, chain_type=MyChain
297+
)
298+
@test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains_serial, chains_distributed), i in 1:N)
299+
@test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains_serial, chains_distributed), i in 1:N)
300+
end
301+
274302
@testset "Chain constructors" begin
275303
chain1 = sample(MyModel(), MySampler(), 100; sleepy = true)
276304
chain2 = sample(MyModel(), MySampler(), 100; sleepy = true, chain_type = MyChain)

0 commit comments

Comments
 (0)