@@ -750,32 +750,35 @@ function find_good_stepsize(
750
750
rng:: AbstractRNG , h:: Hamiltonian , θ:: AbstractVector{T} ; max_n_iters:: Int = 100
751
751
) where {T<: Real }
752
752
# Initialize searching parameters
753
- ϵ′ = ϵ = T (0.1 )
754
- a_min, a_cross, a_max = T (0.25 ), T (0.5 ), T (0.75 ) # minimal, crossing, maximal accept ratio
755
- d = T (2.0 )
753
+ ϵ′ = ϵ = T (1 // 10 )
754
+ # minimal, crossing, maximal log accept ratio
755
+ log_a_min = 2 * T (loghalf)
756
+ log_a_cross = T (loghalf)
757
+ log_a_max = log (T (3 // 4 ))
758
+ d = T (2 )
759
+ invd = inv (d)
756
760
# Create starting phase point
757
761
r = rand_momentum (rng, h. metric, h. kinetic, θ) # sample momentum variable
758
762
z = phasepoint (h, θ, r)
759
763
H = energy (z)
760
764
761
765
# Make a proposal phase point to decide direction
762
- z′ , H′ = A (h, z, ϵ)
766
+ _ , H′ = A (h, z, ϵ)
763
767
ΔH = H - H′ # compute the energy difference; `exp(ΔH)` is the MH accept ratio
764
- direction = ΔH > log (a_cross) ? 1 : - 1
768
+ ratio_too_high = ΔH > log_a_cross
765
769
766
770
# Crossing step: increase/decrease ϵ until accept ratio cross a_cross.
767
771
for _ in 1 : max_n_iters
768
- # `direction ` being `1 ` means MH ratio too high
772
+ # `ratio_too_high ` being `true ` means MH ratio too high
769
773
# - this means our step size is too small, thus we increase
770
- # `direction ` being `-1 ` means MH ratio too small
771
- # - this means our step szie is too large, thus we decrease
772
- ϵ′ = direction == 1 ? d * ϵ : 1 / d * ϵ
773
- z′ , H′ = A (h, z, ϵ)
774
+ # `ratio_too_high ` being `false ` means MH ratio too small
775
+ # - this means our step size is too large, thus we decrease
776
+ ϵ′ = ratio_too_high ? d * ϵ : invd * ϵ
777
+ _ , H′ = A (h, z, ϵ)
774
778
ΔH = H - H′
775
779
@debug " Crossing step" direction H′ ϵ α = min (1 , exp (ΔH))
776
- if (direction == 1 ) && ! (ΔH > log (a_cross))
777
- break
778
- elseif (direction == - 1 ) && ! (ΔH < log (a_cross))
780
+ # stop if there is no crossing; otherwise, continue to half or double stepsize.
781
+ if xor (ratio_too_high, ΔH > log_a_cross)
779
782
break
780
783
else
781
784
ϵ = ϵ′
@@ -787,19 +790,19 @@ function find_good_stepsize(
787
790
# Bisection step: ensure final accept ratio: a_min < a < a_max.
788
791
# See https://en.wikipedia.org/wiki/Bisection_method
789
792
790
- ϵ, ϵ′ = ϵ < ϵ′ ? (ϵ, ϵ′) : (ϵ′, ϵ) # ensure ϵ < ϵ′;
793
+ ϵ, ϵ′ = minmax (ϵ, ϵ′) # ensure ϵ < ϵ′;
791
794
# Here we want to use a value between these two given the
792
795
# criteria that this value also gives us a MH ratio between `a_min` and `a_max`.
793
796
# This condition is quite mild and only intended to avoid cases where
794
797
# the middle value of `ϵ` and `ϵ′` is too extreme.
795
798
for _ in 1 : max_n_iters
796
799
ϵ_mid = middle (ϵ, ϵ′)
797
- z′ , H′ = A (h, z, ϵ_mid)
800
+ _ , H′ = A (h, z, ϵ_mid)
798
801
ΔH = H - H′
799
802
@debug " Bisection step" H′ ϵ_mid α = min (1 , exp (ΔH))
800
- if ( exp (ΔH) > a_max)
803
+ if ΔH > log_a_max
801
804
ϵ = ϵ_mid
802
- elseif ( exp (ΔH) < a_min)
805
+ elseif ΔH < log_a_min
803
806
ϵ′ = ϵ_mid
804
807
else
805
808
ϵ = ϵ_mid
0 commit comments