Skip to content

Commit 57a8ef1

Browse files
devmotionyebaigithub-actions[bot]
authored
Simplify find_good_stepsize (#394)
* Simplify `find_good_stepsize` * Fix typo * Fix merge * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Hong Ge <[email protected]> --------- Co-authored-by: Hong Ge <[email protected]> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent f539bd1 commit 57a8ef1

File tree

2 files changed

+21
-18
lines changed

2 files changed

+21
-18
lines changed

src/AdvancedHMC.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module AdvancedHMC
33
using Statistics: mean, var, middle
44
using LinearAlgebra:
55
Symmetric, UpperTriangular, mul!, ldiv!, dot, I, diag, cholesky, UniformScaling
6-
using StatsFuns: logaddexp, logsumexp
6+
using StatsFuns: logaddexp, logsumexp, loghalf
77
using Random: Random, AbstractRNG
88
using ProgressMeter: ProgressMeter
99

src/trajectory.jl

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -750,32 +750,35 @@ function find_good_stepsize(
750750
rng::AbstractRNG, h::Hamiltonian, θ::AbstractVector{T}; max_n_iters::Int=100
751751
) where {T<:Real}
752752
# Initialize searching parameters
753-
ϵ′ = ϵ = T(0.1)
754-
a_min, a_cross, a_max = T(0.25), T(0.5), T(0.75) # minimal, crossing, maximal accept ratio
755-
d = T(2.0)
753+
ϵ′ = ϵ = T(1//10)
754+
# minimal, crossing, maximal log accept ratio
755+
log_a_min = 2 * T(loghalf)
756+
log_a_cross = T(loghalf)
757+
log_a_max = log(T(3//4))
758+
d = T(2)
759+
invd = inv(d)
756760
# Create starting phase point
757761
r = rand_momentum(rng, h.metric, h.kinetic, θ) # sample momentum variable
758762
z = phasepoint(h, θ, r)
759763
H = energy(z)
760764

761765
# Make a proposal phase point to decide direction
762-
z′, H′ = A(h, z, ϵ)
766+
_, H′ = A(h, z, ϵ)
763767
ΔH = H - H′ # compute the energy difference; `exp(ΔH)` is the MH accept ratio
764-
direction = ΔH > log(a_cross) ? 1 : -1
768+
ratio_too_high = ΔH > log_a_cross
765769

766770
# Crossing step: increase/decrease ϵ until accept ratio cross a_cross.
767771
for _ in 1:max_n_iters
768-
# `direction` being `1` means MH ratio too high
772+
# `ratio_too_high` being `true` means MH ratio too high
769773
# - this means our step size is too small, thus we increase
770-
# `direction` being `-1` means MH ratio too small
771-
# - this means our step szie is too large, thus we decrease
772-
ϵ′ = direction == 1 ? d * ϵ : 1 / d * ϵ
773-
z′, H′ = A(h, z, ϵ)
774+
# `ratio_too_high` being `false` means MH ratio too small
775+
# - this means our step size is too large, thus we decrease
776+
ϵ′ = ratio_too_high ? d * ϵ : invd * ϵ
777+
_, H′ = A(h, z, ϵ)
774778
ΔH = H - H′
775779
@debug "Crossing step" direction H′ ϵ α = min(1, exp(ΔH))
776-
if (direction == 1) && !(ΔH > log(a_cross))
777-
break
778-
elseif (direction == -1) && !(ΔH < log(a_cross))
780+
# stop if there is no crossing; otherwise, continue to half or double stepsize.
781+
if xor(ratio_too_high, ΔH > log_a_cross)
779782
break
780783
else
781784
ϵ = ϵ′
@@ -787,19 +790,19 @@ function find_good_stepsize(
787790
# Bisection step: ensure final accept ratio: a_min < a < a_max.
788791
# See https://en.wikipedia.org/wiki/Bisection_method
789792

790-
ϵ, ϵ′ = ϵ < ϵ′ ? (ϵ, ϵ′) : (ϵ′, ϵ) # ensure ϵ < ϵ′;
793+
ϵ, ϵ′ = minmax(ϵ, ϵ′) # ensure ϵ < ϵ′;
791794
# Here we want to use a value between these two given the
792795
# criteria that this value also gives us a MH ratio between `a_min` and `a_max`.
793796
# This condition is quite mild and only intended to avoid cases where
794797
# the middle value of `ϵ` and `ϵ′` is too extreme.
795798
for _ in 1:max_n_iters
796799
ϵ_mid = middle(ϵ, ϵ′)
797-
z′, H′ = A(h, z, ϵ_mid)
800+
_, H′ = A(h, z, ϵ_mid)
798801
ΔH = H - H′
799802
@debug "Bisection step" H′ ϵ_mid α = min(1, exp(ΔH))
800-
if (exp(ΔH) > a_max)
803+
if ΔH > log_a_max
801804
ϵ = ϵ_mid
802-
elseif (exp(ΔH) < a_min)
805+
elseif ΔH < log_a_min
803806
ϵ′ = ϵ_mid
804807
else
805808
ϵ = ϵ_mid

0 commit comments

Comments
 (0)