|
251 | 251 | )
|
252 | 252 | end
|
253 | 253 |
|
| 254 | +@testset "Equivalence of RepeatSampler and repeating Sampler" begin |
| 255 | + sampler1 = Gibbs(@varname(s) => RepeatSampler(MH(), 3), @varname(m) => ESS()) |
| 256 | + sampler2 = Gibbs( |
| 257 | + @varname(s) => MH(), @varname(s) => MH(), @varname(s) => MH(), @varname(m) => ESS() |
| 258 | + ) |
| 259 | + Random.seed!(23) |
| 260 | + chain1 = sample(gdemo_default, sampler1, 10) |
| 261 | + Random.seed!(23) |
| 262 | + chain2 = sample(gdemo_default, sampler1, 10) |
| 263 | + @test chain1.value == chain2.value |
| 264 | +end |
| 265 | + |
254 | 266 | @testset "Testing gibbs.jl with $adbackend" for adbackend in ADUtils.adbackends
|
255 | 267 | @testset "Deprecated Gibbs constructors" begin
|
256 | 268 | N = 10
|
|
302 | 314 | vnm = @varname(m)
|
303 | 315 | Gibbs(vns => hmc, vns => hmc, vns => hmc, vnm => pg, vnm => pg)
|
304 | 316 | end
|
305 |
| - for s in (s1, s2, s3, s4, s5, s6, s7, s8) |
| 317 | + # Same thing but using RepeatSampler. |
| 318 | + s9 = Gibbs( |
| 319 | + @varname(s) => RepeatSampler(HMC(0.1, 5; adtype=adbackend), 3), |
| 320 | + @varname(m) => RepeatSampler(PG(10), 2), |
| 321 | + ) |
| 322 | + for s in (s1, s2, s3, s4, s5, s6, s7, s8, s9) |
306 | 323 | @test DynamicPPL.alg_str(Turing.Sampler(s, gdemo_default)) == "Gibbs"
|
307 | 324 | end
|
308 | 325 |
|
|
314 | 331 | sample(gdemo_default, s6, N)
|
315 | 332 | sample(gdemo_default, s7, N)
|
316 | 333 | sample(gdemo_default, s8, N)
|
| 334 | + sample(gdemo_default, s9, N) |
317 | 335 |
|
318 | 336 | g = Turing.Sampler(s3, gdemo_default)
|
319 | 337 | @test sample(gdemo_default, g, N) isa MCMCChains.Chains
|
|
355 | 373 | @varname(s) => MH(),
|
356 | 374 | (@varname(s), @varname(m)) => MH(),
|
357 | 375 | @varname(m) => ESS(),
|
358 |
| - @varname(s) => MH(), |
| 376 | + @varname(s) => RepeatSampler(MH(), 3), |
359 | 377 | @varname(m) => HMC(0.2, 4; adtype=adbackend),
|
360 | 378 | (@varname(m), @varname(s)) => HMC(0.2, 4; adtype=adbackend),
|
361 | 379 | )
|
|
367 | 385 | (@varname(z1), @varname(z2), @varname(z3), @varname(z4)) => PG(15),
|
368 | 386 | (@varname(z1), @varname(z2)) => PG(15),
|
369 | 387 | (@varname(mu1), @varname(mu2)) => HMC(0.15, 3; adtype=adbackend),
|
370 |
| - (@varname(z3), @varname(z4)) => PG(15), |
| 388 | + (@varname(z3), @varname(z4)) => RepeatSampler(PG(15), 2), |
371 | 389 | (@varname(mu1)) => ESS(),
|
372 | 390 | (@varname(mu2)) => ESS(),
|
373 | 391 | (@varname(z1), @varname(z2)) => PG(15),
|
|
0 commit comments