|
1 | 1 | using AdvancedMH |
| 2 | +using AbstractMCMC |
2 | 3 | using DiffResults |
3 | 4 | using Distributions |
4 | 5 | using ForwardDiff |
@@ -33,6 +34,25 @@ include("util.jl") |
33 | 34 | LogDensityProblems.logdensity(::typeof(density), θ) = density(θ) |
34 | 35 | LogDensityProblems.dimension(::typeof(density)) = 2 |
35 | 36 |
|
| 37 | + @testset "getparams/setparams!! (AbstractMCMC interface)" begin |
| 38 | + t1, _ = AbstractMCMC.step(Random.default_rng(), model, StaticMH([Normal(0, 1), Normal(0, 1)])) |
| 39 | + t2, _ = AbstractMCMC.step(Random.default_rng(), model, MALA(x -> MvNormal(x, I)); initial_params=ones(2)) |
| 40 | + for t in [t1, t2] |
| 41 | + @test AbstractMCMC.getparams(model, t) == t.params |
| 42 | + |
| 43 | + new_transition = AbstractMCMC.setparams!!(model, t, AbstractMCMC.getparams(model, t)) |
| 44 | + @test new_transition.lp == t.lp |
| 45 | + @test new_transition.accepted == t.accepted |
| 46 | + @test new_transition.params == t.params |
| 47 | + if hasfield(typeof(t), :gradient) |
| 48 | + @test new_transition.gradient == t.gradient |
| 49 | + end |
| 50 | + |
| 51 | + t_replaced = AbstractMCMC.setparams!!(model, t, [1.0, 2.0]) |
| 52 | + @test t_replaced.params == [1.0, 2.0] |
| 53 | + end |
| 54 | + end |
| 55 | + |
36 | 56 | @testset "StaticMH" begin |
37 | 57 | # Set up our sampler with initial parameters. |
38 | 58 | spl1 = StaticMH([Normal(0,1), Normal(0, 1)]) |
@@ -69,7 +89,7 @@ include("util.jl") |
69 | 89 | @test mean(chain1.σ) ≈ 1.0 atol=0.1 |
70 | 90 | @test mean(chain2.μ) ≈ 0.0 atol=0.1 |
71 | 91 | @test mean(chain2.σ) ≈ 1.0 atol=0.1 |
72 | | - @test mean(chain3.μ) ≈ 0.0 atol=0.1 |
| 92 | + @test mean(chain3.μ) ≈ 0.0 atol=0.15 |
73 | 93 | @test mean(chain3.σ) ≈ 1.0 atol=0.1 |
74 | 94 | end |
75 | 95 |
|
|
0 commit comments