Skip to content

Commit cd5136f

Browse files
torfjeldeJaimeRZPgithub-actions[bot]
authored
Minor improvements to constructors and related tests (#336)
* made constructor tests a bit nicer * added proper promotion in HMCDA constructor * renamed get_type_of_spl to sampler_eltype * further improvements to tests * format Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: Jaime RZ <[email protected]> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 00ac8b8 commit cd5136f

File tree

3 files changed

+77
-90
lines changed

3 files changed

+77
-90
lines changed

src/abstractmcmc.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,12 @@ struct HMCState{
2626
adaptor::TAdapt
2727
end
2828

29+
getadaptor(state::HMCState) = state.adaptor
30+
getmetric(state::HMCState) = state.metric
31+
32+
getintegrator(state::HMCState) = state.κ.τ.integrator
33+
getintegrator(state::HMCState) = state.κ.τ.integrator
34+
2935
"""
3036
$(TYPEDSIGNATURES)
3137
@@ -248,14 +254,14 @@ end
248254
### Utils ###
249255
#############
250256

251-
function get_type_of_spl(::AbstractHMCSampler{T}) where {T<:Real}
257+
function sampler_eltype(::AbstractHMCSampler{T}) where {T<:Real}
252258
return T
253259
end
254260

255261
#########
256262

257263
function make_init_params(spl::AbstractHMCSampler, logdensity, init_params)
258-
T = get_type_of_spl(spl)
264+
T = sampler_eltype(spl)
259265
if init_params == nothing
260266
d = LogDensityProblems.dimension(logdensity)
261267
init_params = randn(rng, d)
@@ -274,7 +280,7 @@ function make_step_size(
274280
ϵ = spl.init_ϵ
275281
if iszero(ϵ)
276282
ϵ = find_good_stepsize(rng, hamiltonian, init_params)
277-
T = get_type_of_spl(spl)
283+
T = sampler_eltype(spl)
278284
ϵ = T(ϵ)
279285
@info string("Found initial step size ", ϵ)
280286
end
@@ -312,7 +318,7 @@ make_metric(i::Val{:dense}, T::Type, d::Int) = DenseEuclideanMetric(T, d)
312318

313319
function make_metric(spl::AbstractHMCSampler, logdensity)
314320
d = LogDensityProblems.dimension(logdensity)
315-
T = get_type_of_spl(spl)
321+
T = sampler_eltype(spl)
316322
return make_metric(spl.metric, T, d)
317323
end
318324

src/constructors.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -152,10 +152,8 @@ struct HMCDA{T<:Real} <: AbstractHMCSampler{T}
152152
metric::Union{Symbol,AbstractMetric}
153153
end
154154

155-
function HMCDA(δ, λ; init_ϵ = 0.0, integrator = :leapfrog, metric = :diagonal)
156-
if typeof(δ) != typeof(λ)
157-
@warn "typeof(δ) != typeof(λ) --> using typeof(δ)"
158-
end
155+
function HMCDA(δ, λ; init_ϵ = 0, integrator = :leapfrog, metric = :diagonal)
156+
δ, λ = promote(δ, λ)
159157
T = typeof(δ)
160158
return HMCDA(δ, T(λ), T(init_ϵ), integrator, metric)
161159
end

test/constructors.jl

Lines changed: 65 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -1,91 +1,74 @@
11
using AdvancedHMC, AbstractMCMC, Random
22
include("common.jl")
33

4-
# Initalize samplers
5-
nuts = NUTS(0.8)
6-
nuts_32 = NUTS(0.8f0)
7-
hmc = HMC(0.1, 25)
8-
hmcda = HMCDA(0.8, 1.0)
9-
hmcda_32 = HMCDA(0.8f0, 1.0)
10-
11-
integrator = Leapfrog(1e-3)
12-
kernel = HMCKernel(Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn()))
13-
metric = DiagEuclideanMetric(2)
14-
adaptor = AdvancedHMC.make_adaptor(nuts, metric, integrator)
15-
custom = HMCSampler(kernel, metric, adaptor)
16-
17-
# Check that everything is initalized correctly
184
@testset "Constructors" begin
19-
# Types
20-
@test typeof(nuts) == NUTS{Float64}
21-
@test typeof(nuts_32) == NUTS{Float32}
22-
@test typeof(hmc) == HMC{Float64}
23-
@test typeof(hmcda) == HMCDA{Float64}
24-
@test typeof(nuts) <: AdvancedHMC.AbstractHMCSampler
25-
@test typeof(nuts) <: AbstractMCMC.AbstractSampler
26-
27-
# NUTS
28-
@test nuts.δ == 0.8
29-
@test nuts.max_depth == 10
30-
@test nuts.Δ_max == 1000.0
31-
@test nuts.init_ϵ == 0.0
32-
@test nuts.integrator == :leapfrog
33-
@test nuts.metric == :diagonal
34-
35-
# NUTS Float32
36-
@test nuts_32.δ == 0.8f0
37-
@test nuts_32.max_depth == 10
38-
@test nuts_32.Δ_max == 1000.0f0
39-
@test nuts_32.init_ϵ == 0.0f0
40-
41-
# HMC
42-
@test hmc.n_leapfrog == 25
43-
@test hmc.init_ϵ == 0.1
44-
@test hmc.integrator == :leapfrog
45-
@test hmc.metric == :diagonal
46-
47-
# HMCDA
48-
@test hmcda.δ == 0.8
49-
@test hmcda.λ == 1.0
50-
@test hmcda.init_ϵ == 0.0
51-
@test hmcda.integrator == :leapfrog
52-
@test hmcda.metric == :diagonal
53-
54-
# HMCDA Float32
55-
@test hmcda_32.δ == 0.8f0
56-
@test hmcda_32.λ == 1.0f0
57-
@test hmcda_32.init_ϵ == 0.0f0
58-
end
59-
60-
@testset "First step" begin
61-
rng = MersenneTwister(0)
62-
θ_init = randn(rng, 2)
63-
logdensitymodel = AbstractMCMC.LogDensityModel(ℓπ_gdemo)
64-
_, nuts_state =
65-
AbstractMCMC.step(rng, logdensitymodel, nuts; n_adapts = 0, init_params = θ_init)
66-
_, hmc_state =
67-
AbstractMCMC.step(rng, logdensitymodel, hmc; n_adapts = 0, init_params = θ_init)
68-
_, nuts_32_state =
69-
AbstractMCMC.step(rng, logdensitymodel, nuts_32; n_adapts = 0, init_params = θ_init)
70-
_, custom_state =
71-
AbstractMCMC.step(rng, logdensitymodel, custom; n_adapts = 0, init_params = θ_init)
5+
θ_init = randn(2)
6+
model = AbstractMCMC.LogDensityModel(ℓπ_gdemo)
727

73-
# Metric
74-
@test typeof(nuts_state.metric) == DiagEuclideanMetric{Float64,Vector{Float64}}
75-
@test typeof(nuts_32_state.metric) == DiagEuclideanMetric{Float32,Vector{Float32}}
76-
@test custom_state.metric == metric
8+
@testset "$T" for T in [Float32, Float64]
9+
@testset "$(nameof(typeof(sampler)))" for (sampler, expected) in [
10+
(
11+
HMC(T(0.1), 25),
12+
(
13+
adaptor_type = NoAdaptation,
14+
metric_type = DiagEuclideanMetric{T},
15+
integrator_type = Leapfrog{T},
16+
),
17+
),
18+
# This should peform the correct promotion for the 2nd argument.
19+
(
20+
HMCDA(T(0.1), 1),
21+
(
22+
adaptor_type = StanHMCAdaptor,
23+
metric_type = DiagEuclideanMetric{T},
24+
integrator_type = Leapfrog{T},
25+
),
26+
),
27+
(
28+
NUTS(T(0.8)),
29+
(
30+
adaptor_type = StanHMCAdaptor,
31+
metric_type = DiagEuclideanMetric{T},
32+
integrator_type = Leapfrog{T},
33+
),
34+
),
35+
(
36+
NUTS(T(0.8); metric = :unit),
37+
(
38+
adaptor_type = StanHMCAdaptor,
39+
metric_type = UnitEuclideanMetric{T},
40+
integrator_type = Leapfrog{T},
41+
),
42+
),
43+
(
44+
NUTS(T(0.8); metric = :dense),
45+
(
46+
adaptor_type = StanHMCAdaptor,
47+
metric_type = DenseEuclideanMetric{T},
48+
integrator_type = Leapfrog{T},
49+
),
50+
),
51+
]
52+
# Make sure the sampler element type is preserved.
53+
@test AdvancedHMC.sampler_eltype(sampler) == T
7754

78-
# Integrator
79-
@test typeof(nuts_state.κ.τ.integrator) == Leapfrog{Float64}
80-
@test typeof(nuts_32_state.κ.τ.integrator) == Leapfrog{Float32}
81-
@test custom_state.κ.τ.integrator == integrator
55+
# Step.
56+
rng = Random.default_rng()
57+
transition, state =
58+
AbstractMCMC.step(rng, model, sampler; n_adapts = 0, init_params = θ_init)
8259

83-
# Kernel
84-
@test nuts_state.κ == AdvancedHMC.make_kernel(nuts, nuts_state.κ.τ.integrator)
85-
@test custom_state.κ == kernel
60+
# Verify that the types are preserved in the transition.
61+
@test eltype(transition.z.θ) == T
62+
@test eltype(transition.z.r) == T
63+
@test eltype(transition.z.ℓπ.value) == T
64+
@test eltype(transition.z.ℓπ.gradient) == T
65+
@test eltype(transition.z.ℓκ.value) == T
66+
@test eltype(transition.z.ℓκ.gradient) == T
8667

87-
# Adaptor
88-
@test typeof(nuts_state.adaptor) <: StanHMCAdaptor
89-
@test hmc_state.adaptor == NoAdaptation()
90-
@test custom_state.adaptor == adaptor
68+
# Verify that the state is what we expect.
69+
@test AdvancedHMC.getmetric(state) isa expected.metric_type
70+
@test AdvancedHMC.getintegrator(state) isa expected.integrator_type
71+
@test AdvancedHMC.getadaptor(state) isa expected.adaptor_type
72+
end
73+
end
9174
end

0 commit comments

Comments
 (0)