Skip to content

Commit bb4b534

Browse files
Change default adaptor for HMCDA (#338)
* Update README.md * Update abstractmcmc.jl * Update README.md * Update constructors.jl * Update src/abstractmcmc.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update abstractmcmc.jl --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent e302429 commit bb4b534

File tree

3 files changed

+14
-13
lines changed

3 files changed

+14
-13
lines changed

README.md

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ Users can also use the `AbstractMCMC` interface to sample, which is also used in
125125
In order to show how this is done let us start from our previous example where we defined a `LogTargetDensity`, `ℓπ`.
126126

127127
```julia
128+
using AbstractMCMC, LogDensityProblemsAD
128129
# Wrap the previous LogTargetDensity as LogDensityModel
129130
# where ℓπ::LogTargetDensity
130131
model = AdvancedHMC.LogDensityModel(LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓπ))
@@ -134,7 +135,7 @@ D = 10; initial_θ = rand(D)
134135
n_samples, n_adapts, δ = 1_000, 2_000, 0.8
135136
sampler = HMCSampler(kernel, metric, adaptor)
136137

137-
# Now just sample
138+
# Now sample
138139
samples = AbstractMCMC.sample(
139140
model,
140141
sampler,
@@ -152,8 +153,8 @@ In the previous examples, we built the sampler by manually specifying the integr
152153
```julia
153154
# HMC Sampler
154155
# step size, number of leapfrog steps
155-
n_leapfrogs, lf_integrator = 0.25, Leapfrog(0.1)
156-
hmc = HMC(n_leapfrogs, integrator = lf_integrator)
156+
n_leapfrog, lf_integrator = 25, Leapfrog(0.1)
157+
hmc = HMC(n_leapfrog, integrator = lf_integrator)
157158
```
158159

159160
Equivalent to:
@@ -204,7 +205,7 @@ In the previous examples, we built the sampler by manually specifying the integr
204205
initial_ϵ = find_good_stepsize(hamiltonian, initial_θ)
205206
integrator = Leapfrog(initial_ϵ)
206207
kernel = HMCKernel(Trajectory{EndPointTS}(integrator, FixedIntegrationTime(λ)))
207-
adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(δ, integrator))
208+
adaptor = StepSizeAdaptor(δ, initial_ϵ)
208209
hmcda = HMCSampler(kernel, metric, adaptor)
209210
```
210211

@@ -217,14 +218,14 @@ This can be done as follows:
217218
nuts = NUTS(δ, metric = :dense) #metric = DenseEuclideanMetric(D)
218219
# Provide your own AbstractMetric
219220
metric = DiagEuclideanMetric(10)
220-
nuts = NUTS(n_adapt, δ, metric = metric)
221+
nuts = NUTS(δ, metric = metric)
221222

222223
nuts = NUTS(δ, integrator = :leapfrog) #integrator = Leapfrog(ϵ) (Default!)
223224
nuts = NUTS(δ, integrator = :jitteredleapfrog) #integrator = JitteredLeapfrog(ϵ, 0.1ϵ)
224225
nuts = NUTS(δ, integrator = :temperedleapfrog) #integrator = TemperedLeapfrog(ϵ, 1.0)
225226

226227
# Provide your own AbstractIntegrator
227-
integrator = JitteredLeapfrog(ϵ, 0.2ϵ)
228+
integrator = JitteredLeapfrog(0.1, 0.2)
228229
nuts = NUTS(δ, integrator = integrator)
229230
```
230231

@@ -237,7 +238,7 @@ A small working example can be found at `test/cuda.jl`.
237238
## API and supported HMC algorithms
238239

239240
An important design goal of AdvancedHMC.jl is modularity; we would like to support algorithmic research on HMC.
240-
This modularity means that different HMC variants can be easily constructed by composing various components, such as preconditioning metric (i.e., mass matrix), leapfrog integrators, trajectories (static or dynamic), and adaption schemes, etc.
241+
This modularity means that different HMC variants can be easily constructed by composing various components, such as preconditioning metric (i.e., mass matrix), leapfrog integrators, trajectories (static or dynamic), adaption schemes, etc.
241242
The minimal example above can be modified to suit particular inference problems by picking components from the list below.
242243

243244
### Hamiltonian mass matrix (`metric`)

src/abstractmcmc.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -338,14 +338,14 @@ end
338338

339339
#########
340340

341-
function make_adaptor(
342-
spl::Union{NUTS,HMCDA},
343-
metric::AbstractMetric,
344-
integrator::AbstractIntegrator,
345-
)
341+
function make_adaptor(spl::NUTS, metric::AbstractMetric, integrator::AbstractIntegrator)
346342
return StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(spl.δ, integrator))
347343
end
348344

345+
function make_adaptor(spl::HMCDA, metric::AbstractMetric, integrator::AbstractIntegrator)
346+
return StepSizeAdaptor(spl.δ, integrator)
347+
end
348+
349349
function make_adaptor(spl::HMC, metric::AbstractMetric, integrator::AbstractIntegrator)
350350
return NoAdaptation()
351351
end

test/constructors.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ include("common.jl")
1919
(
2020
HMCDA(T(0.8), 1, integrator = Leapfrog(T(0.1))),
2121
(
22-
adaptor_type = StanHMCAdaptor,
22+
adaptor_type = NesterovDualAveraging,
2323
metric_type = DiagEuclideanMetric{T},
2424
integrator_type = Leapfrog{T},
2525
),

0 commit comments

Comments
 (0)