Skip to content

Commit e302429

Browse files
JaimeRZPyebaitorfjeldegithub-actions[bot]
authored
ReadMe for consctructors (#329)
* bug + docs * read me * format * Update abstractmcmc.jl * Apply suggestions from code review Co-authored-by: Hong Ge <[email protected]> * API * Apply suggestions from code review Co-authored-by: Tor Erlend Fjelde <[email protected]> * Apply suggestions from code review Co-authored-by: Hong Ge <[email protected]> * Hong s comments * Fix typo and simplify arguments (#331) * Removed `n_adapts` from sampler constructors and some fixes. (#333) * Update README.md * Update README.md Co-authored-by: Tor Erlend Fjelde <[email protected]> --------- Co-authored-by: Jaime RZ <[email protected]> Co-authored-by: Tor Erlend Fjelde <[email protected]> * Minor tweaks to the metric field comments. * Removed redundant make_metric function * Fix typos in constructor tests * More fixes. * Typofix. * More test fixes. * Update test/constructors.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * no init_e (#335) * no init_e * format * bug * Bugfix. (#337) * Bugfix. * Update src/constructors.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update test/abstractmcmc.jl * Update src/abstractmcmc.jl Co-authored-by: Hong Ge <[email protected]> --------- Co-authored-by: Hong Ge <[email protected]> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * format * docs update for init_e * Update constructors.jl * Update README.md * Update constructors.jl * More bugfixes. * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/constructors.jl * Update src/constructors.jl * Update README.md --------- Co-authored-by: Hong Ge <[email protected]> Co-authored-by: Tor Erlend Fjelde <[email protected]> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Hong Ge <[email protected]>
1 parent cd5136f commit e302429

File tree

6 files changed

+214
-84
lines changed

6 files changed

+214
-84
lines changed

README.md

Lines changed: 140 additions & 31 deletions
Large diffs are not rendered by default.

src/abstractmcmc.jl

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,6 @@ end
2828

2929
getadaptor(state::HMCState) = state.adaptor
3030
getmetric(state::HMCState) = state.metric
31-
32-
getintegrator(state::HMCState) = state.κ.τ.integrator
3331
getintegrator(state::HMCState) = state.κ.τ.integrator
3432

3533
"""
@@ -271,47 +269,63 @@ end
271269

272270
#########
273271

272+
function make_step_size(
273+
rng::Random.AbstractRNG,
274+
spl::HMCSampler,
275+
hamiltonian::Hamiltonian,
276+
init_params,
277+
)
278+
return spl.κ.τ.integrator.ϵ
279+
end
280+
274281
function make_step_size(
275282
rng::Random.AbstractRNG,
276283
spl::AbstractHMCSampler,
277284
hamiltonian::Hamiltonian,
278285
init_params,
279286
)
280-
ϵ = spl.init_ϵ
281-
if iszero(ϵ)
282-
ϵ = find_good_stepsize(rng, hamiltonian, init_params)
283-
T = sampler_eltype(spl)
284-
ϵ = T(ϵ)
285-
@info string("Found initial step size ", ϵ)
286-
end
287-
return ϵ
287+
T = sampler_eltype(spl)
288+
return make_step_size(rng, spl.integrator, T, hamiltonian, init_params)
289+
288290
end
289291

290292
function make_step_size(
291293
rng::Random.AbstractRNG,
292-
spl::HMCSampler,
294+
integrator::AbstractIntegrator,
295+
T::Type,
293296
hamiltonian::Hamiltonian,
294297
init_params,
295298
)
296-
return spl.κ.τ.integrator.ϵ
299+
return integrator.ϵ
300+
end
301+
302+
function make_step_size(
303+
rng::Random.AbstractRNG,
304+
integrator::Symbol,
305+
T::Type,
306+
hamiltonian::Hamiltonian,
307+
init_params,
308+
)
309+
ϵ = find_good_stepsize(rng, hamiltonian, init_params)
310+
@info string("Found initial step size ", ϵ)
311+
return T(ϵ)
297312
end
298313

299314
make_integrator(spl::HMCSampler, ϵ::Real) = spl.κ.τ.integrator
300315
make_integrator(spl::AbstractHMCSampler, ϵ::Real) = make_integrator(spl.integrator, ϵ)
301316
make_integrator(i::AbstractIntegrator, ϵ::Real) = i
302-
make_integrator(i::Type{<:AbstractIntegrator}, ϵ::Real) = i
303317
make_integrator(i::Symbol, ϵ::Real) = make_integrator(Val(i), ϵ)
304-
make_integrator(i...) = error("Integrator $(typeof(i)) not supported.")
318+
make_integrator(@nospecialize(i), ::Real) = error("Integrator $i not supported.")
305319
make_integrator(i::Val{:leapfrog}, ϵ::Real) = Leapfrog(ϵ)
306-
make_integrator(i::Val{:jitteredleapfrog}, ϵ::Real) = JitteredLeapfrog(ϵ)
307-
make_integrator(i::Val{:temperedleapfrog}, ϵ::Real) = TemperedLeapfrog(ϵ)
320+
make_integrator(i::Val{:jitteredleapfrog}, ϵ::T) where {T<:Real} =
321+
JitteredLeapfrog(ϵ, T(0.1ϵ))
322+
make_integrator(i::Val{:temperedleapfrog}, ϵ::T) where {T<:Real} = TemperedLeapfrog(ϵ, T(1))
308323

309324
#########
310325

311-
make_metric(i...) = error("Metric $(typeof(i)) not supported.")
326+
make_metric(@nospecialize(i), T::Type, d::Int) = error("Metric $(typeof(i)) not supported.")
312327
make_metric(i::Symbol, T::Type, d::Int) = make_metric(Val(i), T, d)
313328
make_metric(i::AbstractMetric, T::Type, d::Int) = i
314-
make_metric(i::Type{AbstractMetric}, T::Type, d::Int) = i
315329
make_metric(i::Val{:diagonal}, T::Type, d::Int) = DiagEuclideanMetric(T, d)
316330
make_metric(i::Val{:unit}, T::Type, d::Int) = UnitEuclideanMetric(T, d)
317331
make_metric(i::Val{:dense}, T::Type, d::Int) = DenseEuclideanMetric(T, d)

src/constructors.jl

Lines changed: 20 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -28,20 +28,18 @@ struct HMCSampler{T<:Real} <: AbstractHMCSampler{T}
2828
metric::AbstractMetric
2929
"[`AbstractAdaptor`](@ref)."
3030
adaptor::AbstractAdaptor
31-
"Adaptation steps if any"
32-
n_adapts::Int
3331
end
3432

35-
function HMCSampler(κ, metric, adaptor; n_adapts = 0)
33+
function HMCSampler(κ, metric, adaptor)
3634
T = collect(typeof(metric).parameters)[1]
37-
return HMCSampler{T}(κ, metric, adaptor, n_adapts)
35+
return HMCSampler{T}(κ, metric, adaptor)
3836
end
3937

4038
############
4139
### NUTS ###
4240
############
4341
"""
44-
NUTS(n_adapts::Int, δ::Real; max_depth::Int=10, Δ_max::Real=1000, init_ϵ::Real=0)
42+
NUTS(δ::Real; max_depth::Int=10, Δ_max::Real=1000, init_ϵ::Real=0, integrator = :leapfrog, metric = :diagonal)
4543
4644
No-U-Turn Sampler (NUTS) sampler.
4745
@@ -52,7 +50,7 @@ $(FIELDS)
5250
# Usage:
5351
5452
```julia
55-
NUTS(n_adapts=1000, δ=0.65) # Use 1000 adaption steps, and target accept ratio 0.65.
53+
NUTS(δ=0.65) # Use target accept ratio 0.65.
5654
```
5755
"""
5856
struct NUTS{T<:Real} <: AbstractHMCSampler{T}
@@ -62,24 +60,15 @@ struct NUTS{T<:Real} <: AbstractHMCSampler{T}
6260
max_depth::Int
6361
"Maximum divergence during doubling tree."
6462
Δ_max::T
65-
"Initial step size; 0 means it is automatically chosen."
66-
init_ϵ::T
6763
"Choice of integrator, specified either using a `Symbol` or [`AbstractIntegrator`](@ref)"
6864
integrator::Union{Symbol,AbstractIntegrator}
69-
"Choice of initial metric, specified using a `Symbol` or `AbstractMetric`. The metric type will be preserved during adaption."
65+
"Choice of initial metric; `Symbol` means it is automatically initialised. The metric type will be preserved during automatic initialisation and adaption."
7066
metric::Union{Symbol,AbstractMetric}
7167
end
7268

73-
function NUTS(
74-
δ;
75-
max_depth = 10,
76-
Δ_max = 1000.0,
77-
init_ϵ = 0.0,
78-
integrator = :leapfrog,
79-
metric = :diagonal,
80-
)
69+
function NUTS(δ; max_depth = 10, Δ_max = 1000.0, integrator = :leapfrog, metric = :diagonal)
8170
T = typeof(δ)
82-
return NUTS(δ, max_depth, T(Δ_max), T(init_ϵ), integrator, metric)
71+
return NUTS(δ, max_depth, T(Δ_max), integrator, metric)
8372
end
8473

8574
###########
@@ -97,29 +86,32 @@ $(FIELDS)
9786
# Usage:
9887
9988
```julia
100-
HMC(init_ϵ=0.05, n_leapfrog=10)
89+
HMC(10, integrator = Leapfrog(0.05), metric = :diagonal)
10190
```
10291
"""
10392
struct HMC{T<:Real} <: AbstractHMCSampler{T}
104-
"Initial step size; 0 means automatically searching using a heuristic procedure."
105-
init_ϵ::T
10693
"Number of leapfrog steps."
10794
n_leapfrog::Int
10895
"Choice of integrator, specified either using a `Symbol` or [`AbstractIntegrator`](@ref)"
10996
integrator::Union{Symbol,AbstractIntegrator}
110-
"Choice of initial metric, specified using a `Symbol` or `AbstractMetric`. The metric type will be preserved during adaption."
97+
"Choice of initial metric; `Symbol` means it is automatically initialised. The metric type will be preserved during automatic initialisation and adaption."
11198
metric::Union{Symbol,AbstractMetric}
11299
end
113100

114-
function HMC(init_ϵ, n_leapfrog; integrator = :leapfrog, metric = :diagonal)
115-
return HMC(init_ϵ, n_leapfrog, integrator, metric)
101+
function HMC(n_leapfrog; integrator = :leapfrog, metric = :diagonal)
102+
if integrator isa Symbol
103+
T = typeof(0.0) # current default float type
104+
else
105+
T = integrator_eltype(integrator)
106+
end
107+
return HMC{T}(n_leapfrog, integrator, metric)
116108
end
117109

118110
#############
119111
### HMCDA ###
120112
#############
121113
"""
122-
HMCDA(n_adapts::Int, δ::Real, λ::Real; ϵ::Real=0)
114+
HMCDA(δ::Real, λ::Real; ϵ::Real=0, integrator = :leapfrog, metric = :diagonal)
123115
124116
Hamiltonian Monte Carlo sampler with Dual Averaging algorithm.
125117
@@ -130,7 +122,7 @@ $(FIELDS)
130122
# Usage:
131123
132124
```julia
133-
HMCDA(n_adapts=200, δ=0.65, λ=0.3)
125+
HMCDA(δ=0.65, λ=0.3)
134126
```
135127
136128
For more information, please view the following paper ([arXiv link](https://arxiv.org/abs/1111.4246)):
@@ -144,16 +136,14 @@ struct HMCDA{T<:Real} <: AbstractHMCSampler{T}
144136
δ::T
145137
"Target leapfrog length."
146138
λ::T
147-
"Initial step size; 0 means automatically searching using a heuristic procedure."
148-
init_ϵ::T
149139
"Choice of integrator, specified either using a `Symbol` or [`AbstractIntegrator`](@ref)"
150140
integrator::Union{Symbol,AbstractIntegrator}
151-
"Choice of initial metric, specified using a `Symbol` or `AbstractMetric`. The metric type will be preserved during adaption."
141+
"Choice of initial metric; `Symbol` means it is automatically initialised. The metric type will be preserved during automatic initialisation and adaption."
152142
metric::Union{Symbol,AbstractMetric}
153143
end
154144

155145
function HMCDA(δ, λ; init_ϵ = 0, integrator = :leapfrog, metric = :diagonal)
156146
δ, λ = promote(δ, λ)
157147
T = typeof(δ)
158-
return HMCDA(δ, T(λ), T(init_ϵ), integrator, metric)
148+
return HMCDA(δ, T(λ), integrator, metric)
159149
end

src/integrator.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ struct Leapfrog{T<:AbstractScalarOrVec{<:AbstractFloat}} <: AbstractLeapfrog{T}
7070
ϵ::T
7171
end
7272
Base.show(io::IO, l::Leapfrog) = print(io, "Leapfrog(ϵ=$(round.(l.ϵ; sigdigits=3)))")
73+
integrator_eltype(i::AbstractLeapfrog{T}) where {T<:AbstractFloat} = T
7374

7475
### Jittering
7576

@@ -131,7 +132,7 @@ function _jitter(
131132
lf::JitteredLeapfrog{FT,T},
132133
) where {FT<:AbstractFloat,T<:AbstractScalarOrVec{FT}}
133134
ϵ = lf.ϵ0 .* (1 .+ lf.jitter .* (2 .* rand(rng) .- 1))
134-
return @set lf.ϵ = ϵ
135+
return @set lf.ϵ = FT.(ϵ)
135136
end
136137

137138
jitter(rng::AbstractRNG, lf::JitteredLeapfrog) = _jitter(rng, lf)

test/abstractmcmc.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ include("common.jl")
99
θ_init = randn(rng, 2)
1010

1111
nuts = NUTS(0.8)
12-
hmc = HMC(0.05, 100)
12+
hmc = HMC(100; integrator = Leapfrog(0.05))
1313
hmcda = HMCDA(0.8, 0.1)
1414

1515
integrator = Leapfrog(1e-3)

test/constructors.jl

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,16 @@ include("common.jl")
88
@testset "$T" for T in [Float32, Float64]
99
@testset "$(nameof(typeof(sampler)))" for (sampler, expected) in [
1010
(
11-
HMC(T(0.1), 25),
11+
HMC(25, integrator = Leapfrog(T(0.1))),
1212
(
1313
adaptor_type = NoAdaptation,
1414
metric_type = DiagEuclideanMetric{T},
1515
integrator_type = Leapfrog{T},
1616
),
1717
),
18-
# This should peform the correct promotion for the 2nd argument.
18+
# This should perform the correct promotion for the 2nd argument.
1919
(
20-
HMCDA(T(0.1), 1),
20+
HMCDA(T(0.8), 1, integrator = Leapfrog(T(0.1))),
2121
(
2222
adaptor_type = StanHMCAdaptor,
2323
metric_type = DiagEuclideanMetric{T},
@@ -48,6 +48,22 @@ include("common.jl")
4848
integrator_type = Leapfrog{T},
4949
),
5050
),
51+
(
52+
NUTS(T(0.8); integrator = :jitteredleapfrog),
53+
(
54+
adaptor_type = StanHMCAdaptor,
55+
metric_type = DiagEuclideanMetric{T},
56+
integrator_type = JitteredLeapfrog{T,T},
57+
),
58+
),
59+
(
60+
NUTS(T(0.8); integrator = :temperedleapfrog),
61+
(
62+
adaptor_type = StanHMCAdaptor,
63+
metric_type = DiagEuclideanMetric{T},
64+
integrator_type = TemperedLeapfrog{T,T},
65+
),
66+
),
5167
]
5268
# Make sure the sampler element type is preserved.
5369
@test AdvancedHMC.sampler_eltype(sampler) == T

0 commit comments

Comments
 (0)