|
225 | 225 | @test all(l.level > Logging.LogLevel(-1) for l in logs)
|
226 | 226 | end
|
227 | 227 |
|
| 228 | + @testset "Serial sampling" begin |
| 229 | + # No dedicated chains type |
| 230 | + N = 10_000 |
| 231 | + chains = sample(MyModel(), MySampler(), MCMCSerial(), N, 1000) |
| 232 | + @test chains isa Vector{<:Vector{<:MySample}} |
| 233 | + @test length(chains) == 1000 |
| 234 | + @test all(length(x) == N for x in chains) |
| 235 | + |
| 236 | + Random.seed!(1234) |
| 237 | + chains = sample(MyModel(), MySampler(), MCMCSerial(), N, 1000; |
| 238 | + chain_type = MyChain) |
| 239 | + |
| 240 | + # Test output type and size. |
| 241 | + @test chains isa Vector{<:MyChain} |
| 242 | + @test all(c.as[1] === missing for c in chains) |
| 243 | + @test length(chains) == 1000 |
| 244 | + @test all(x -> length(x.as) == length(x.bs) == N, chains) |
| 245 | + |
| 246 | + # Test some statistical properties. |
| 247 | + @test all(x -> isapprox(mean(@view x.as[2:end]), 0.5; atol=5e-2), chains) |
| 248 | + @test all(x -> isapprox(var(@view x.as[2:end]), 1 / 12; atol=5e-3), chains) |
| 249 | + @test all(x -> isapprox(mean(@view x.bs[2:end]), 0; atol=5e-2), chains) |
| 250 | + @test all(x -> isapprox(var(@view x.bs[2:end]), 1; atol=5e-2), chains) |
| 251 | + |
| 252 | + # Test reproducibility. |
| 253 | + Random.seed!(1234) |
| 254 | + chains2 = sample(MyModel(), MySampler(), MCMCSerial(), N, 1000; |
| 255 | + chain_type = MyChain) |
| 256 | + |
| 257 | + @test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains, chains2), i in 1:N) |
| 258 | + @test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N) |
| 259 | + |
| 260 | + # Unexpected order of arguments. |
| 261 | + str = "Number of chains (10) is greater than number of samples per chain (5)" |
| 262 | + @test_logs (:warn, str) match_mode=:any sample(MyModel(), MySampler(), |
| 263 | + MCMCSerial(), 5, 10; |
| 264 | + chain_type = MyChain) |
| 265 | + |
| 266 | + # Suppress output. |
| 267 | + logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do |
| 268 | + sample(MyModel(), MySampler(), MCMCSerial(), 10_000, 100; |
| 269 | + progress = false, chain_type = MyChain) |
| 270 | + end |
| 271 | + @test all(l.level > Logging.LogLevel(-1) for l in logs) |
| 272 | + end |
| 273 | + |
228 | 274 | @testset "Chain constructors" begin
|
229 | 275 | chain1 = sample(MyModel(), MySampler(), 100; sleepy = true)
|
230 | 276 | chain2 = sample(MyModel(), MySampler(), 100; sleepy = true, chain_type = MyChain)
|
|
0 commit comments