|
103 | 103 | end
|
104 | 104 | end
|
105 | 105 |
|
106 |
| - if VERSION ≥ v"1.3" |
107 |
| - @testset "Multithreaded sampling" begin |
108 |
| - if Threads.nthreads() == 1 |
109 |
| - warnregex = r"^Only a single thread available" |
110 |
| - @test_logs (:warn, warnregex) sample(MyModel(), MySampler(), MCMCThreads(), |
111 |
| - 10, 10) |
112 |
| - end |
| 106 | + @testset "Multithreaded sampling" begin |
| 107 | + if Threads.nthreads() == 1 |
| 108 | + warnregex = r"^Only a single thread available" |
| 109 | + @test_logs (:warn, warnregex) sample(MyModel(), MySampler(), MCMCThreads(), |
| 110 | + 10, 10) |
| 111 | + end |
113 | 112 |
|
114 |
| - # No dedicated chains type |
115 |
| - N = 10_000 |
116 |
| - chains = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000) |
117 |
| - @test chains isa Vector{<:Vector{<:MySample}} |
118 |
| - @test length(chains) == 1000 |
119 |
| - @test all(length(x) == N for x in chains) |
| 113 | + # No dedicated chains type |
| 114 | + N = 10_000 |
| 115 | + chains = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000) |
| 116 | + @test chains isa Vector{<:Vector{<:MySample}} |
| 117 | + @test length(chains) == 1000 |
| 118 | + @test all(length(x) == N for x in chains) |
120 | 119 |
|
121 |
| - Random.seed!(1234) |
122 |
| - chains = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000; |
123 |
| - chain_type = MyChain) |
| 120 | + Random.seed!(1234) |
| 121 | + chains = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000; |
| 122 | + chain_type = MyChain) |
124 | 123 |
|
125 |
| - # test output type and size |
126 |
| - @test chains isa Vector{<:MyChain} |
127 |
| - @test length(chains) == 1000 |
128 |
| - @test all(x -> length(x.as) == length(x.bs) == N, chains) |
| 124 | + # test output type and size |
| 125 | + @test chains isa Vector{<:MyChain} |
| 126 | + @test length(chains) == 1000 |
| 127 | + @test all(x -> length(x.as) == length(x.bs) == N, chains) |
129 | 128 |
|
130 |
| - # test some statistical properties |
131 |
| - @test all(x -> isapprox(mean(@view x.as[2:end]), 0.5; atol=5e-2), chains) |
132 |
| - @test all(x -> isapprox(var(@view x.as[2:end]), 1 / 12; atol=5e-3), chains) |
133 |
| - @test all(x -> isapprox(mean(@view x.bs[2:end]), 0; atol=5e-2), chains) |
134 |
| - @test all(x -> isapprox(var(@view x.bs[2:end]), 1; atol=5e-2), chains) |
| 129 | + # test some statistical properties |
| 130 | + @test all(x -> isapprox(mean(@view x.as[2:end]), 0.5; atol=5e-2), chains) |
| 131 | + @test all(x -> isapprox(var(@view x.as[2:end]), 1 / 12; atol=5e-3), chains) |
| 132 | + @test all(x -> isapprox(mean(@view x.bs[2:end]), 0; atol=5e-2), chains) |
| 133 | + @test all(x -> isapprox(var(@view x.bs[2:end]), 1; atol=1e-1), chains) |
135 | 134 |
|
136 |
| - # test reproducibility |
137 |
| - Random.seed!(1234) |
138 |
| - chains2 = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000; |
139 |
| - chain_type = MyChain) |
| 135 | + # test reproducibility |
| 136 | + Random.seed!(1234) |
| 137 | + chains2 = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000; |
| 138 | + chain_type = MyChain) |
140 | 139 |
|
141 |
| - @test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains, chains2), i in 1:N) |
142 |
| - @test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N) |
| 140 | + @test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains, chains2), i in 1:N) |
| 141 | + @test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N) |
143 | 142 |
|
144 |
| - # Unexpected order of arguments. |
145 |
| - str = "Number of chains (10) is greater than number of samples per chain (5)" |
146 |
| - @test_logs (:warn, str) match_mode=:any sample(MyModel(), MySampler(), |
147 |
| - MCMCThreads(), 5, 10; |
148 |
| - chain_type = MyChain) |
| 143 | + # Unexpected order of arguments. |
| 144 | + str = "Number of chains (10) is greater than number of samples per chain (5)" |
| 145 | + @test_logs (:warn, str) match_mode=:any sample(MyModel(), MySampler(), |
| 146 | + MCMCThreads(), 5, 10; |
| 147 | + chain_type = MyChain) |
149 | 148 |
|
150 |
| - # Suppress output. |
151 |
| - logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do |
152 |
| - sample(MyModel(), MySampler(), MCMCThreads(), 10_000, 1000; |
153 |
| - progress = false, chain_type = MyChain) |
154 |
| - end |
155 |
| - @test all(l.level > Logging.LogLevel(-1) for l in logs) |
| 149 | + # Suppress output. |
| 150 | + logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do |
| 151 | + sample(MyModel(), MySampler(), MCMCThreads(), 10_000, 1000; |
| 152 | + progress = false, chain_type = MyChain) |
| 153 | + end |
| 154 | + @test all(l.level > Logging.LogLevel(-1) for l in logs) |
156 | 155 |
|
157 |
| - # Smoke test for nchains < nthreads |
158 |
| - if Threads.nthreads() == 2 |
159 |
| - sample(MyModel(), MySampler(), MCMCThreads(), N, 1) |
160 |
| - end |
| 156 | + # Smoke test for nchains < nthreads |
| 157 | + if Threads.nthreads() == 2 |
| 158 | + sample(MyModel(), MySampler(), MCMCThreads(), N, 1) |
161 | 159 | end
|
162 | 160 | end
|
163 | 161 |
|
|
201 | 199 | @test all(x -> isapprox(mean(@view x.as[2:end]), 0.5; atol=5e-2), chains)
|
202 | 200 | @test all(x -> isapprox(var(@view x.as[2:end]), 1 / 12; atol=5e-3), chains)
|
203 | 201 | @test all(x -> isapprox(mean(@view x.bs[2:end]), 0; atol=5e-2), chains)
|
204 |
| - @test all(x -> isapprox(var(@view x.bs[2:end]), 1; atol=5e-2), chains) |
| 202 | + @test all(x -> isapprox(var(@view x.bs[2:end]), 1; atol=1e-1), chains) |
205 | 203 |
|
206 | 204 | # Test reproducibility.
|
207 | 205 | Random.seed!(1234)
|
|
247 | 245 | @test all(x -> isapprox(mean(@view x.as[2:end]), 0.5; atol=5e-2), chains)
|
248 | 246 | @test all(x -> isapprox(var(@view x.as[2:end]), 1 / 12; atol=5e-3), chains)
|
249 | 247 | @test all(x -> isapprox(mean(@view x.bs[2:end]), 0; atol=5e-2), chains)
|
250 |
| - @test all(x -> isapprox(var(@view x.bs[2:end]), 1; atol=5e-2), chains) |
| 248 | + @test all(x -> isapprox(var(@view x.bs[2:end]), 1; atol=1e-1), chains) |
251 | 249 |
|
252 | 250 | # Test reproducibility.
|
253 | 251 | Random.seed!(1234)
|
|
271 | 269 | @test all(l.level > Logging.LogLevel(-1) for l in logs)
|
272 | 270 | end
|
273 | 271 |
|
| 272 | + @testset "Ensemble sampling: Reproducibility" begin |
| 273 | + N = 1_000 |
| 274 | + nchains = 10 |
| 275 | + |
| 276 | + # Serial sampling |
| 277 | + Random.seed!(1234) |
| 278 | + chains_serial = sample( |
| 279 | + MyModel(), MySampler(), MCMCSerial(), N, nchains; |
| 280 | + progress=false, chain_type=MyChain |
| 281 | + ) |
| 282 | + |
| 283 | + # Multi-threaded sampling |
| 284 | + Random.seed!(1234) |
| 285 | + chains_threads = sample( |
| 286 | + MyModel(), MySampler(), MCMCThreads(), N, nchains; |
| 287 | + progress=false, chain_type=MyChain |
| 288 | + ) |
| 289 | + @test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains_serial, chains_threads), i in 1:N) |
| 290 | + @test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains_serial, chains_threads), i in 1:N) |
| 291 | + |
| 292 | + # Multi-core sampling |
| 293 | + Random.seed!(1234) |
| 294 | + chains_distributed = sample( |
| 295 | + MyModel(), MySampler(), MCMCDistributed(), N, nchains; |
| 296 | + progress=false, chain_type=MyChain |
| 297 | + ) |
| 298 | + @test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains_serial, chains_distributed), i in 1:N) |
| 299 | + @test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains_serial, chains_distributed), i in 1:N) |
| 300 | + end |
| 301 | + |
274 | 302 | @testset "Chain constructors" begin
|
275 | 303 | chain1 = sample(MyModel(), MySampler(), 100; sleepy = true)
|
276 | 304 | chain2 = sample(MyModel(), MySampler(), 100; sleepy = true, chain_type = MyChain)
|
|
0 commit comments