Skip to content

Commit 8429077

Browse files
torfjeldeJaimeRZPgithub-actions[bot]yebai
authored
Further improvements to recently introduced constructors (#340)
* remove type parameter from AbstractHMCSampler, and added eltype for metrics * Update src/constructors.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * added determine_sampler_eltype to unify handling of different argument types * fixed issues with conversion of arguments * added test for type promotion in the case of HMCDA * removed unnecessary float calls * make sampler types concretely typed * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * remove now-redundant type parameter from HMC * removed unused argument to HMCDA * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/constructors.jl Co-authored-by: Hong Ge <[email protected]> --------- Co-authored-by: Jaime RZ <[email protected]> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Hong Ge <[email protected]>
1 parent 37481ac commit 8429077

File tree

5 files changed

+98
-38
lines changed

5 files changed

+98
-38
lines changed

src/AdvancedHMC.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ using AbstractMCMC: LogDensityModel
2626

2727
import StatsBase: sample
2828

29+
const DEFAULT_FLOAT_TYPE = typeof(float(0))
30+
2931
include("utilities.jl")
3032

3133
# Notations

src/abstractmcmc.jl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -251,13 +251,6 @@ end
251251
#############
252252
### Utils ###
253253
#############
254-
255-
function sampler_eltype(::AbstractHMCSampler{T}) where {T<:Real}
256-
return T
257-
end
258-
259-
#########
260-
261254
function make_init_params(spl::AbstractHMCSampler, logdensity, init_params)
262255
T = sampler_eltype(spl)
263256
if init_params == nothing

src/constructors.jl

Lines changed: 59 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,34 @@
1-
abstract type AbstractHMCSampler{T<:Real} <: AbstractMCMC.AbstractSampler end
1+
"""
2+
determine_sampler_eltype(xs...)
3+
4+
Determine the element type to use for the given arguments.
5+
6+
Symbols are either resolved to the default float type or simply dropped
7+
in favour of determined types from the other arguments.
8+
"""
9+
determine_sampler_eltype(xs...) = float(_determine_sampler_eltype(xs...))
10+
# NOTE: We want to defer conversion to `float` until the very "end" of the
11+
# process, so as to allow `promote_type` to do it's job properly.
12+
# For example, in the scenario `determine_sampler_eltype(::Int64, ::Float32)`
13+
# we want to return `Float32`, not `Float64`. The latter would occur
14+
# if we did `float(eltype(x))` instead of just `eltype(x)`.
15+
_determine_sampler_eltype(x) = eltype(x)
16+
_determine_sampler_eltype(x::AbstractIntegrator) = integrator_eltype(x)
17+
_determine_sampler_eltype(::Symbol) = DEFAULT_FLOAT_TYPE
18+
function _determine_sampler_eltype(xs...)
19+
xs_not_symbol = filter(!Base.Fix2(isa, Symbol), xs)
20+
isempty(xs_not_symbol) && return DEFAULT_FLOAT_TYPE
21+
return promote_type(map(_determine_sampler_eltype, xs_not_symbol)...)
22+
end
23+
24+
abstract type AbstractHMCSampler <: AbstractMCMC.AbstractSampler end
25+
26+
"""
27+
sampler_eltype(sampler)
28+
29+
Return the element type of the sampler.
30+
"""
31+
function sampler_eltype end
232

333
##############
434
### Custom ###
@@ -21,19 +51,17 @@ and `adaptor` after sampling.
2151
2252
To access the updated fields use the resulting [`HMCState`](@ref).
2353
"""
24-
struct HMCSampler{T<:Real} <: AbstractHMCSampler{T}
54+
struct HMCSampler{K<:AbstractMCMCKernel,M<:AbstractMetric,A<:AbstractAdaptor} <:
55+
AbstractHMCSampler
2556
"[`AbstractMCMCKernel`](@ref)."
26-
κ::AbstractMCMCKernel
57+
κ::K
2758
"Choice of initial metric [`AbstractMetric`](@ref). The metric type will be preserved during adaption."
28-
metric::AbstractMetric
59+
metric::M
2960
"[`AbstractAdaptor`](@ref)."
30-
adaptor::AbstractAdaptor
61+
adaptor::A
3162
end
3263

33-
function HMCSampler(κ, metric, adaptor)
34-
T = collect(typeof(metric).parameters)[1]
35-
return HMCSampler{T}(κ, metric, adaptor)
36-
end
64+
sampler_eltype(sampler::HMCSampler) = eltype(sampler.metric)
3765

3866
############
3967
### NUTS ###
@@ -53,24 +81,27 @@ $(FIELDS)
5381
NUTS(δ=0.65) # Use target accept ratio 0.65.
5482
```
5583
"""
56-
struct NUTS{T<:Real} <: AbstractHMCSampler{T}
84+
struct NUTS{T<:Real,I<:Union{Symbol,AbstractIntegrator},M<:Union{Symbol,AbstractMetric}} <:
85+
AbstractHMCSampler
5786
"Target acceptance rate for dual averaging."
5887
δ::T
5988
"Maximum doubling tree depth."
6089
max_depth::Int
6190
"Maximum divergence during doubling tree."
6291
Δ_max::T
6392
"Choice of integrator, specified either using a `Symbol` or [`AbstractIntegrator`](@ref)"
64-
integrator::Union{Symbol,AbstractIntegrator}
93+
integrator::I
6594
"Choice of initial metric; `Symbol` means it is automatically initialised. The metric type will be preserved during automatic initialisation and adaption."
66-
metric::Union{Symbol,AbstractMetric}
95+
metric::M
6796
end
6897

6998
function NUTS(δ; max_depth = 10, Δ_max = 1000.0, integrator = :leapfrog, metric = :diagonal)
70-
T = typeof)
71-
return NUTS(δ, max_depth, T(Δ_max), integrator, metric)
99+
T = determine_sampler_eltype(δ, integrator, metric)
100+
return NUTS(T(δ), max_depth, T(Δ_max), integrator, metric)
72101
end
73102

103+
sampler_eltype(::NUTS{T}) where {T} = T
104+
74105
###########
75106
### HMC ###
76107
###########
@@ -89,23 +120,20 @@ $(FIELDS)
89120
HMC(10, integrator = Leapfrog(0.05), metric = :diagonal)
90121
```
91122
"""
92-
struct HMC{T<:Real} <: AbstractHMCSampler{T}
123+
struct HMC{I<:Union{Symbol,AbstractIntegrator},M<:Union{Symbol,AbstractMetric}} <:
124+
AbstractHMCSampler
93125
"Number of leapfrog steps."
94126
n_leapfrog::Int
95127
"Choice of integrator, specified either using a `Symbol` or [`AbstractIntegrator`](@ref)"
96-
integrator::Union{Symbol,AbstractIntegrator}
128+
integrator::I
97129
"Choice of initial metric; `Symbol` means it is automatically initialised. The metric type will be preserved during automatic initialisation and adaption."
98-
metric::Union{Symbol,AbstractMetric}
130+
metric::M
99131
end
100132

101-
function HMC(n_leapfrog; integrator = :leapfrog, metric = :diagonal)
102-
if integrator isa Symbol
103-
T = typeof(0.0) # current default float type
104-
else
105-
T = integrator_eltype(integrator)
106-
end
107-
return HMC{T}(n_leapfrog, integrator, metric)
108-
end
133+
HMC(n_leapfrog; integrator = :leapfrog, metric = :diagonal) =
134+
HMC(n_leapfrog, integrator, metric)
135+
136+
sampler_eltype(sampler::HMC) = determine_sampler_eltype(sampler.metric, sampler.integrator)
109137

110138
#############
111139
### HMCDA ###
@@ -131,7 +159,7 @@ For more information, please view the following paper ([arXiv link](https://arxi
131159
setting path lengths in Hamiltonian Monte Carlo." Journal of Machine Learning
132160
Research 15, no. 1 (2014): 1593-1623.
133161
"""
134-
struct HMCDA{T<:Real} <: AbstractHMCSampler{T}
162+
struct HMCDA{T<:Real} <: AbstractHMCSampler
135163
"Target acceptance rate for dual averaging."
136164
δ::T
137165
"Target leapfrog length."
@@ -142,8 +170,9 @@ struct HMCDA{T<:Real} <: AbstractHMCSampler{T}
142170
metric::Union{Symbol,AbstractMetric}
143171
end
144172

145-
function HMCDA(δ, λ; init_ϵ = 0, integrator = :leapfrog, metric = :diagonal)
146-
δ, λ = promote(δ, λ)
147-
T = typeof(δ)
148-
return HMCDA(δ, T(λ), integrator, metric)
173+
function HMCDA(δ, λ; integrator = :leapfrog, metric = :diagonal)
174+
T = determine_sampler_eltype(δ, λ, integrator, metric)
175+
return HMCDA(T(δ), T(λ), integrator, metric)
149176
end
177+
178+
sampler_eltype(::HMCDA{T}) where {T} = T

src/metric.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ UnitEuclideanMetric(dim::Int) = UnitEuclideanMetric(Float64, (dim,))
2323

2424
renew(ue::UnitEuclideanMetric, M⁻¹) = UnitEuclideanMetric(M⁻¹, ue.size)
2525

26+
Base.eltype(::UnitEuclideanMetric{T}) where {T} = T
2627
Base.size(e::UnitEuclideanMetric) = e.size
2728
Base.size(e::UnitEuclideanMetric, dim::Int) = e.size[dim]
2829
Base.show(io::IO, uem::UnitEuclideanMetric) =
@@ -47,6 +48,7 @@ DiagEuclideanMetric(dim::Int) = DiagEuclideanMetric(Float64, dim)
4748

4849
renew(ue::DiagEuclideanMetric, M⁻¹) = DiagEuclideanMetric(M⁻¹)
4950

51+
Base.eltype(::DiagEuclideanMetric{T}) where {T} = T
5052
Base.size(e::DiagEuclideanMetric, dim...) = size(e.M⁻¹, dim...)
5153
Base.show(io::IO, dem::DiagEuclideanMetric) =
5254
print(io, "DiagEuclideanMetric($(_string_M⁻¹(dem.M⁻¹)))")
@@ -80,6 +82,7 @@ DenseEuclideanMetric(sz::Tuple{Int}) = DenseEuclideanMetric(Float64, sz)
8082

8183
renew(ue::DenseEuclideanMetric, M⁻¹) = DenseEuclideanMetric(M⁻¹)
8284

85+
Base.eltype(::DenseEuclideanMetric{T}) where {T} = T
8386
Base.size(e::DenseEuclideanMetric, dim...) = size(e._temp, dim...)
8487
Base.show(io::IO, dem::DenseEuclideanMetric) =
8588
print(io, "DenseEuclideanMetric(diag=$(_string_M⁻¹(dem.M⁻¹)))")

test/constructors.jl

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@ using AdvancedHMC, AbstractMCMC, Random
22
include("common.jl")
33

44
@testset "Constructors" begin
5-
θ_init = randn(2)
5+
d = 2
6+
θ_init = randn(d)
67
model = AbstractMCMC.LogDensityModel(ℓπ_gdemo)
78

89
@testset "$T" for T in [Float32, Float64]
@@ -15,6 +16,38 @@ include("common.jl")
1516
integrator_type = Leapfrog{T},
1617
),
1718
),
19+
(
20+
HMC(25, metric = DiagEuclideanMetric(ones(T, 2))),
21+
(
22+
adaptor_type = NoAdaptation,
23+
metric_type = DiagEuclideanMetric{T},
24+
integrator_type = Leapfrog{T},
25+
),
26+
),
27+
(
28+
HMC(25, integrator = Leapfrog(T(0.1)), metric = :unit),
29+
(
30+
adaptor_type = NoAdaptation,
31+
metric_type = UnitEuclideanMetric{T},
32+
integrator_type = Leapfrog{T},
33+
),
34+
),
35+
(
36+
HMC(25, integrator = Leapfrog(T(0.1)), metric = :dense),
37+
(
38+
adaptor_type = NoAdaptation,
39+
metric_type = DenseEuclideanMetric{T},
40+
integrator_type = Leapfrog{T},
41+
),
42+
),
43+
(
44+
HMCDA(T(0.8), one(T), integrator = Leapfrog(T(0.1))),
45+
(
46+
adaptor_type = NesterovDualAveraging,
47+
metric_type = DiagEuclideanMetric{T},
48+
integrator_type = Leapfrog{T},
49+
),
50+
),
1851
# This should perform the correct promotion for the 2nd argument.
1952
(
2053
HMCDA(T(0.8), 1, integrator = Leapfrog(T(0.1))),

0 commit comments

Comments
 (0)