Skip to content

Commit da87eb4

Browse files
authored
kwarg nadapt (#332)
* kwarg nadapt * tests up * format * bug * bug
1 parent a7edfa9 commit da87eb4

File tree

6 files changed

+28
-35
lines changed

6 files changed

+28
-35
lines changed

src/abstractmcmc.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ function AbstractMCMC.sample(
3737
model::LogDensityModel,
3838
sampler::AbstractHMCSampler,
3939
N::Integer;
40+
n_adapts::Int = min(div(N, 10), 1_000),
4041
progress = true,
4142
verbose = false,
4243
callback = nothing,
@@ -52,6 +53,7 @@ function AbstractMCMC.sample(
5253
model,
5354
sampler,
5455
N;
56+
n_adapts = n_adapts,
5557
progress = progress,
5658
verbose = verbose,
5759
callback = callback,
@@ -66,6 +68,7 @@ function AbstractMCMC.sample(
6668
parallel::AbstractMCMC.AbstractMCMCEnsemble,
6769
N::Integer,
6870
nchains::Integer;
71+
n_adapts::Int = min(div(N, 10), 1_000),
6972
progress = true,
7073
verbose = false,
7174
callback = nothing,
@@ -84,6 +87,7 @@ function AbstractMCMC.sample(
8487
parallel,
8588
N,
8689
nchains;
90+
n_adapts = n_adapts,
8791
progress = progress,
8892
verbose = verbose,
8993
callback = callback,
@@ -150,7 +154,7 @@ function AbstractMCMC.step(
150154

151155
# Adapt h and spl.
152156
tstat = stat(t)
153-
n_adapts = get_nadapts(spl)
157+
n_adapts = kwargs[:n_adapts]
154158
h, κ, isadapted = adapt!(h, κ, adaptor, i, n_adapts, t.z.θ, tstat.acceptance_rate)
155159
tstat = merge(tstat, (is_adapt = isadapted,))
156160

@@ -336,11 +340,6 @@ end
336340

337341
#########
338342

339-
get_nadapts(spl::Union{HMCSampler,NUTS,HMCDA}) = spl.n_adapts
340-
get_nadapts(spl::HMC) = 0
341-
342-
#########
343-
344343
function make_kernel(spl::NUTS, integrator::AbstractIntegrator)
345344
return HMCKernel(Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn()))
346345
end

src/constructors.jl

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,6 @@ NUTS(n_adapts=1000, δ=0.65) # Use 1000 adaption steps, and target accept ratio
5656
```
5757
"""
5858
struct NUTS{T<:Real} <: AbstractHMCSampler{T}
59-
"Number of adaptation steps."
60-
n_adapts::Int
6159
"Target acceptance rate for dual averaging."
6260
δ::T
6361
"Maximum doubling tree depth."
@@ -73,7 +71,6 @@ struct NUTS{T<:Real} <: AbstractHMCSampler{T}
7371
end
7472

7573
function NUTS(
76-
n_adapts,
7774
δ;
7875
max_depth = 10,
7976
Δ_max = 1000.0,
@@ -82,7 +79,7 @@ function NUTS(
8279
metric = :diagonal,
8380
)
8481
T = typeof(δ)
85-
return NUTS(n_adapts, δ, max_depth, T(Δ_max), T(init_ϵ), integrator, metric)
82+
return NUTS(δ, max_depth, T(Δ_max), T(init_ϵ), integrator, metric)
8683
end
8784

8885
###########
@@ -143,8 +140,6 @@ For more information, please view the following paper ([arXiv link](https://arxi
143140
Research 15, no. 1 (2014): 1593-1623.
144141
"""
145142
struct HMCDA{T<:Real} <: AbstractHMCSampler{T}
146-
"`Number of adaptation steps."
147-
n_adapts::Int
148143
"Target acceptance rate for dual averaging."
149144
δ::T
150145
"Target leapfrog length."
@@ -157,10 +152,10 @@ struct HMCDA{T<:Real} <: AbstractHMCSampler{T}
157152
metric::Union{Symbol,AbstractMetric}
158153
end
159154

160-
function HMCDA(n_adapts, δ, λ; init_ϵ = 0.0, integrator = :leapfrog, metric = :diagonal)
155+
function HMCDA(δ, λ; init_ϵ = 0.0, integrator = :leapfrog, metric = :diagonal)
161156
if typeof(δ) != typeof(λ)
162157
@warn "typeof(δ) != typeof(λ) --> using typeof(δ)"
163158
end
164159
T = typeof(δ)
165-
return HMCDA(n_adapts, δ, T(λ), T(init_ϵ), integrator, metric)
160+
return HMCDA(δ, T(λ), T(init_ϵ), integrator, metric)
166161
end

test/abstractmcmc.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@ include("common.jl")
88
n_adapts = 5_000
99
θ_init = randn(rng, 2)
1010

11-
nuts = NUTS(n_adapts, 0.8)
11+
nuts = NUTS(0.8)
1212
hmc = HMC(0.05, 100)
13-
hmcda = HMCDA(n_adapts, 0.8, 0.1)
13+
hmcda = HMCDA(0.8, 0.1)
1414

1515
integrator = Leapfrog(1e-3)
1616
κ = AdvancedHMC.make_kernel(nuts, integrator)
@@ -27,7 +27,7 @@ include("common.jl")
2727
model,
2828
nuts,
2929
n_adapts + n_samples;
30-
nadapts = n_adapts,
30+
n_adapts = n_adapts,
3131
init_params = θ_init,
3232
progress = false,
3333
verbose = false,
@@ -50,7 +50,7 @@ include("common.jl")
5050
model,
5151
hmc,
5252
n_adapts + n_samples;
53-
nadapts = n_adapts,
53+
n_adapts = n_adapts,
5454
init_params = θ_init,
5555
progress = false,
5656
verbose = false,
@@ -73,7 +73,7 @@ include("common.jl")
7373
model,
7474
custom,
7575
n_adapts + n_samples;
76-
nadapts = n_adapts,
76+
n_adapts = 0,
7777
init_params = θ_init,
7878
progress = false,
7979
verbose = false,
@@ -99,7 +99,7 @@ include("common.jl")
9999
model,
100100
custom,
101101
10;
102-
nadapts = 0,
102+
n_adapts = 0,
103103
init_params = θ_init,
104104
progress = false,
105105
verbose = false,
@@ -109,7 +109,7 @@ include("common.jl")
109109
model,
110110
custom,
111111
10;
112-
nadapts = 0,
112+
n_adapts = 0,
113113
init_params = θ_init,
114114
progress = false,
115115
verbose = false,

test/adaptation.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ function runnuts(ℓπ, metric; n_samples = 3_000)
88
θ_init = rand(D)
99
rng = MersenneTwister(0)
1010

11-
nuts = NUTS(n_adapts, 0.8)
11+
nuts = NUTS(0.8)
1212
h = Hamiltonian(metric, ℓπ, ForwardDiff)
1313
step_size = AdvancedHMC.make_step_size(rng, nuts, h, θ_init)
1414
integrator = AdvancedHMC.make_integrator(nuts, step_size)

test/constructors.jl

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@ using AdvancedHMC, AbstractMCMC, Random
22
include("common.jl")
33

44
# Initalize samplers
5-
nuts = NUTS(1000, 0.8)
6-
nuts_32 = NUTS(1000, 0.8f0)
5+
nuts = NUTS(0.8)
6+
nuts_32 = NUTS(0.8f0)
77
hmc = HMC(0.1, 25)
8-
hmcda = HMCDA(1000, 0.8, 1.0)
9-
hmcda_32 = HMCDA(1000, 0.8f0, 1.0)
8+
hmcda = HMCDA(0.8, 1.0)
9+
hmcda_32 = HMCDA(0.8f0, 1.0)
1010

1111
integrator = Leapfrog(1e-3)
1212
kernel = HMCKernel(Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn()))
@@ -25,7 +25,6 @@ custom = HMCSampler(kernel, metric, adaptor)
2525
@test typeof(nuts) <: AbstractMCMC.AbstractSampler
2626

2727
# NUTS
28-
@test nuts.n_adapts == 1000
2928
@test nuts.δ == 0.8
3029
@test nuts.max_depth == 10
3130
@test nuts.Δ_max == 1000.0
@@ -34,7 +33,6 @@ custom = HMCSampler(kernel, metric, adaptor)
3433
@test nuts.metric == :diagonal
3534

3635
# NUTS Float32
37-
@test nuts_32.n_adapts == 1000
3836
@test nuts_32.δ == 0.8f0
3937
@test nuts_32.max_depth == 10
4038
@test nuts_32.Δ_max == 1000.0f0
@@ -47,15 +45,13 @@ custom = HMCSampler(kernel, metric, adaptor)
4745
@test hmc.metric == :diagonal
4846

4947
# HMCDA
50-
@test hmcda.n_adapts == 1000
5148
@test hmcda.δ == 0.8
5249
@test hmcda.λ == 1.0
5350
@test hmcda.init_ϵ == 0.0
5451
@test hmcda.integrator == :leapfrog
5552
@test hmcda.metric == :diagonal
5653

5754
# HMCDA Float32
58-
@test hmcda_32.n_adapts == 1000
5955
@test hmcda_32.δ == 0.8f0
6056
@test hmcda_32.λ == 1.0f0
6157
@test hmcda_32.init_ϵ == 0.0f0
@@ -65,11 +61,14 @@ end
6561
rng = MersenneTwister(0)
6662
θ_init = randn(rng, 2)
6763
logdensitymodel = AbstractMCMC.LogDensityModel(ℓπ_gdemo)
68-
_, nuts_state = AbstractMCMC.step(rng, logdensitymodel, nuts; init_params = θ_init)
69-
_, hmc_state = AbstractMCMC.step(rng, logdensitymodel, hmc; init_params = θ_init)
64+
_, nuts_state =
65+
AbstractMCMC.step(rng, logdensitymodel, nuts; n_adapts = 0, init_params = θ_init)
66+
_, hmc_state =
67+
AbstractMCMC.step(rng, logdensitymodel, hmc; n_adapts = 0, init_params = θ_init)
7068
_, nuts_32_state =
71-
AbstractMCMC.step(rng, logdensitymodel, nuts_32; init_params = θ_init)
72-
_, custom_state = AbstractMCMC.step(rng, logdensitymodel, custom; init_params = θ_init)
69+
AbstractMCMC.step(rng, logdensitymodel, nuts_32; n_adapts = 0, init_params = θ_init)
70+
_, custom_state =
71+
AbstractMCMC.step(rng, logdensitymodel, custom; n_adapts = 0, init_params = θ_init)
7372

7473
# Metric
7574
@test typeof(nuts_state.metric) == DiagEuclideanMetric{Float64,Vector{Float64}}

test/sampler.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ end
159159
end
160160
end
161161
@testset "drop_warmup" begin
162-
nuts = NUTS(n_adapts, 0.8)
162+
nuts = NUTS(0.8)
163163
metric = DiagEuclideanMetric(D)
164164
h = Hamiltonian(metric, ℓπ, ∂ℓπ∂θ)
165165
integrator = Leapfrog(ϵ)

0 commit comments

Comments
 (0)