|
1 | 1 | using AdvancedHMC, AbstractMCMC, Random
|
2 | 2 | include("common.jl")
|
3 | 3 |
|
4 |
| -# Initalize samplers |
5 |
| -nuts = NUTS(0.8) |
6 |
| -nuts_32 = NUTS(0.8f0) |
7 |
| -hmc = HMC(0.1, 25) |
8 |
| -hmcda = HMCDA(0.8, 1.0) |
9 |
| -hmcda_32 = HMCDA(0.8f0, 1.0) |
10 |
| - |
11 |
| -integrator = Leapfrog(1e-3) |
12 |
| -kernel = HMCKernel(Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn())) |
13 |
| -metric = DiagEuclideanMetric(2) |
14 |
| -adaptor = AdvancedHMC.make_adaptor(nuts, metric, integrator) |
15 |
| -custom = HMCSampler(kernel, metric, adaptor) |
16 |
| - |
17 |
| -# Check that everything is initalized correctly |
18 | 4 | @testset "Constructors" begin
|
19 |
| - # Types |
20 |
| - @test typeof(nuts) == NUTS{Float64} |
21 |
| - @test typeof(nuts_32) == NUTS{Float32} |
22 |
| - @test typeof(hmc) == HMC{Float64} |
23 |
| - @test typeof(hmcda) == HMCDA{Float64} |
24 |
| - @test typeof(nuts) <: AdvancedHMC.AbstractHMCSampler |
25 |
| - @test typeof(nuts) <: AbstractMCMC.AbstractSampler |
26 |
| - |
27 |
| - # NUTS |
28 |
| - @test nuts.δ == 0.8 |
29 |
| - @test nuts.max_depth == 10 |
30 |
| - @test nuts.Δ_max == 1000.0 |
31 |
| - @test nuts.init_ϵ == 0.0 |
32 |
| - @test nuts.integrator == :leapfrog |
33 |
| - @test nuts.metric == :diagonal |
34 |
| - |
35 |
| - # NUTS Float32 |
36 |
| - @test nuts_32.δ == 0.8f0 |
37 |
| - @test nuts_32.max_depth == 10 |
38 |
| - @test nuts_32.Δ_max == 1000.0f0 |
39 |
| - @test nuts_32.init_ϵ == 0.0f0 |
40 |
| - |
41 |
| - # HMC |
42 |
| - @test hmc.n_leapfrog == 25 |
43 |
| - @test hmc.init_ϵ == 0.1 |
44 |
| - @test hmc.integrator == :leapfrog |
45 |
| - @test hmc.metric == :diagonal |
46 |
| - |
47 |
| - # HMCDA |
48 |
| - @test hmcda.δ == 0.8 |
49 |
| - @test hmcda.λ == 1.0 |
50 |
| - @test hmcda.init_ϵ == 0.0 |
51 |
| - @test hmcda.integrator == :leapfrog |
52 |
| - @test hmcda.metric == :diagonal |
53 |
| - |
54 |
| - # HMCDA Float32 |
55 |
| - @test hmcda_32.δ == 0.8f0 |
56 |
| - @test hmcda_32.λ == 1.0f0 |
57 |
| - @test hmcda_32.init_ϵ == 0.0f0 |
58 |
| -end |
59 |
| - |
60 |
| -@testset "First step" begin |
61 |
| - rng = MersenneTwister(0) |
62 |
| - θ_init = randn(rng, 2) |
63 |
| - logdensitymodel = AbstractMCMC.LogDensityModel(ℓπ_gdemo) |
64 |
| - _, nuts_state = |
65 |
| - AbstractMCMC.step(rng, logdensitymodel, nuts; n_adapts = 0, init_params = θ_init) |
66 |
| - _, hmc_state = |
67 |
| - AbstractMCMC.step(rng, logdensitymodel, hmc; n_adapts = 0, init_params = θ_init) |
68 |
| - _, nuts_32_state = |
69 |
| - AbstractMCMC.step(rng, logdensitymodel, nuts_32; n_adapts = 0, init_params = θ_init) |
70 |
| - _, custom_state = |
71 |
| - AbstractMCMC.step(rng, logdensitymodel, custom; n_adapts = 0, init_params = θ_init) |
| 5 | + θ_init = randn(2) |
| 6 | + model = AbstractMCMC.LogDensityModel(ℓπ_gdemo) |
72 | 7 |
|
73 |
| - # Metric |
74 |
| - @test typeof(nuts_state.metric) == DiagEuclideanMetric{Float64,Vector{Float64}} |
75 |
| - @test typeof(nuts_32_state.metric) == DiagEuclideanMetric{Float32,Vector{Float32}} |
76 |
| - @test custom_state.metric == metric |
| 8 | + @testset "$T" for T in [Float32, Float64] |
| 9 | + @testset "$(nameof(typeof(sampler)))" for (sampler, expected) in [ |
| 10 | + ( |
| 11 | + HMC(T(0.1), 25), |
| 12 | + ( |
| 13 | + adaptor_type = NoAdaptation, |
| 14 | + metric_type = DiagEuclideanMetric{T}, |
| 15 | + integrator_type = Leapfrog{T}, |
| 16 | + ), |
| 17 | + ), |
| 18 | + # This should peform the correct promotion for the 2nd argument. |
| 19 | + ( |
| 20 | + HMCDA(T(0.1), 1), |
| 21 | + ( |
| 22 | + adaptor_type = StanHMCAdaptor, |
| 23 | + metric_type = DiagEuclideanMetric{T}, |
| 24 | + integrator_type = Leapfrog{T}, |
| 25 | + ), |
| 26 | + ), |
| 27 | + ( |
| 28 | + NUTS(T(0.8)), |
| 29 | + ( |
| 30 | + adaptor_type = StanHMCAdaptor, |
| 31 | + metric_type = DiagEuclideanMetric{T}, |
| 32 | + integrator_type = Leapfrog{T}, |
| 33 | + ), |
| 34 | + ), |
| 35 | + ( |
| 36 | + NUTS(T(0.8); metric = :unit), |
| 37 | + ( |
| 38 | + adaptor_type = StanHMCAdaptor, |
| 39 | + metric_type = UnitEuclideanMetric{T}, |
| 40 | + integrator_type = Leapfrog{T}, |
| 41 | + ), |
| 42 | + ), |
| 43 | + ( |
| 44 | + NUTS(T(0.8); metric = :dense), |
| 45 | + ( |
| 46 | + adaptor_type = StanHMCAdaptor, |
| 47 | + metric_type = DenseEuclideanMetric{T}, |
| 48 | + integrator_type = Leapfrog{T}, |
| 49 | + ), |
| 50 | + ), |
| 51 | + ] |
| 52 | + # Make sure the sampler element type is preserved. |
| 53 | + @test AdvancedHMC.sampler_eltype(sampler) == T |
77 | 54 |
|
78 |
| - # Integrator |
79 |
| - @test typeof(nuts_state.κ.τ.integrator) == Leapfrog{Float64} |
80 |
| - @test typeof(nuts_32_state.κ.τ.integrator) == Leapfrog{Float32} |
81 |
| - @test custom_state.κ.τ.integrator == integrator |
| 55 | + # Step. |
| 56 | + rng = Random.default_rng() |
| 57 | + transition, state = |
| 58 | + AbstractMCMC.step(rng, model, sampler; n_adapts = 0, init_params = θ_init) |
82 | 59 |
|
83 |
| - # Kernel |
84 |
| - @test nuts_state.κ == AdvancedHMC.make_kernel(nuts, nuts_state.κ.τ.integrator) |
85 |
| - @test custom_state.κ == kernel |
| 60 | + # Verify that the types are preserved in the transition. |
| 61 | + @test eltype(transition.z.θ) == T |
| 62 | + @test eltype(transition.z.r) == T |
| 63 | + @test eltype(transition.z.ℓπ.value) == T |
| 64 | + @test eltype(transition.z.ℓπ.gradient) == T |
| 65 | + @test eltype(transition.z.ℓκ.value) == T |
| 66 | + @test eltype(transition.z.ℓκ.gradient) == T |
86 | 67 |
|
87 |
| - # Adaptor |
88 |
| - @test typeof(nuts_state.adaptor) <: StanHMCAdaptor |
89 |
| - @test hmc_state.adaptor == NoAdaptation() |
90 |
| - @test custom_state.adaptor == adaptor |
| 68 | + # Verify that the state is what we expect. |
| 69 | + @test AdvancedHMC.getmetric(state) isa expected.metric_type |
| 70 | + @test AdvancedHMC.getintegrator(state) isa expected.integrator_type |
| 71 | + @test AdvancedHMC.getadaptor(state) isa expected.adaptor_type |
| 72 | + end |
| 73 | + end |
91 | 74 | end
|
0 commit comments