Skip to content

Commit 2b3814c

Browse files
xukai92 and yebai authored
Generalized leapfrog integrator (#370)
* feat: port generalized leapfrog Signed-off-by: Kai Xu <[email protected]> * refactor: remove copy Signed-off-by: Kai Xu <[email protected]> * format: add emplty line Signed-off-by: Kai Xu <[email protected]> * Update src/riemannian/integrator.jl Co-authored-by: Hong Ge <[email protected]> * chore: add warning for using generalized leapfrog with vectorization Signed-off-by: Kai Xu <[email protected]> * fix: type order Signed-off-by: Kai Xu <[email protected]> --------- Signed-off-by: Kai Xu <[email protected]> Co-authored-by: Hong Ge <[email protected]>
1 parent cfd7227 commit 2b3814c

File tree

2 files changed

+104
-0
lines changed

2 files changed

+104
-0
lines changed

src/AdvancedHMC.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ export Hamiltonian
5252

5353
include("integrator.jl")
5454
export Leapfrog, JitteredLeapfrog, TemperedLeapfrog
55+
include("riemannian/integrator.jl")
56+
export GeneralizedLeapfrog
5557

5658
include("trajectory.jl")
5759
export Trajectory,

src/riemannian/integrator.jl

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
"""
$(TYPEDEF)

Generalized leapfrog integrator with fixed step size `ϵ`.

The implicit updates of the scheme are approximated by a fixed number `n` of
fixed-point iterations.

# Fields

$(TYPEDFIELDS)


## References

1. Girolami, Mark, and Ben Calderhead. "Riemann manifold Langevin and Hamiltonian Monte Carlo methods." Journal of the Royal Statistical Society Series B: Statistical Methodology 73, no. 2 (2011): 123-214.
"""
struct GeneralizedLeapfrog{T<:AbstractScalarOrVec{<:AbstractFloat}} <: AbstractLeapfrog{T}
    "Step size."
    ϵ::T
    "Number of fixed-point iterations used for each implicit update."
    n::Int
end
20+
# Compact one-line display: the step size is rounded to 3 significant digits
# (element-wise when it is a vector), followed by the iteration count.
function Base.show(io::IO, l::GeneralizedLeapfrog)
    ϵ_rounded = round.(l.ϵ; sigdigits=3)
    return print(io, "GeneralizedLeapfrog(ϵ=$(ϵ_rounded), n=$(l.n))")
end
22+
23+
# Generic fallback for `∂H∂θ` methods that have no caching support: the `cache`
# keyword is accepted but ignored, and `nothing` is handed back in the cache slot
# when `return_cache` is set.
function ∂H∂θ_cache(h, θ, r; return_cache = false, cache = nothing)
    dual = ∂H∂θ(h, θ, r)
    if return_cache
        return (dual, nothing)
    end
    return dual
end
28+
29+
# TODO(Kai) make sure vectorization works
# TODO(Kai) check if tempering is valid
# TODO(Kai) abstract out the 3 main steps and merge with `step` in `integrator.jl`
"""
    step(lf::GeneralizedLeapfrog, h::Hamiltonian, z::PhasePoint, n_steps::Int = 1;
         fwd::Bool = n_steps > 0, full_trajectory::Val = Val(false))

Advance the phase point `z` by `abs(n_steps)` generalized leapfrog steps, following
eqs. (16)-(18) of Girolami & Calderhead (2011).

The momentum half-step (eq. 16) and the position step (eq. 17) are each approximated
with `lf.n` fixed-point iterations; the closing momentum half-step (eq. 18) is a single
explicit update.

# Arguments
- `lf`: integrator supplying the step size (through `step_size(lf)`) and the
  iteration count `lf.n`.
- `h`: Hamiltonian handed to `∂H∂θ`, `∂H∂θ_cache`, `∂H∂r` and `phasepoint`.
- `z`: starting phase point; its `θ`, `r` and cached `ℓπ` fields are read.
- `n_steps`: number of steps; only its absolute value drives the loop.

# Keywords
- `fwd`: when `false` the step size is negated. Defaults to `n_steps > 0`.
- `full_trajectory`: `Val(true)` collects every phase point visited, `Val(false)`
  keeps only the most recent one.

# Returns
With `Val(false)`, the last phase point computed (`z` itself when no step is taken).
With `Val(true)`, a `Vector` of the phase points computed. Integration stops right
after the first non-finite phase point, which is still included in the result.
"""
function step(
    lf::GeneralizedLeapfrog{T},
    h::Hamiltonian,
    z::P,
    n_steps::Int = 1;
    fwd::Bool = n_steps > 0, # simulate hamiltonian backward when n_steps < 0
    full_trajectory::Val{FullTraj} = Val(false),
) where {T<:AbstractScalarOrVec{<:AbstractFloat},TP,P<:PhasePoint{TP},FullTraj}
    n_steps = abs(n_steps) # to support `n_steps < 0` cases

    ϵ = fwd ? step_size(lf) : -step_size(lf)
    # Adjoint: a vector of step sizes becomes a row; a real scalar is unchanged.
    ϵ = ϵ'

    if !(T <: AbstractFloat) || !(TP <: AbstractVector)
        # The condition depends only on type parameters, so it is the same on every
        # call; log once instead of repeating the message on each call.
        @warn "Vectorization is not tested for GeneralizedLeapfrog." maxlog=1
    end

    res = if FullTraj
        Vector{P}(undef, n_steps)
    else
        z
    end

    for i = 1:n_steps
        θ_init, r_init = z.θ, z.r
        # Tempering
        #r = temper(lf, r, (i=i, is_half=true), n_steps)
        # eq (16) of Girolami & Calderhead (2011)
        r_half = r_init
        local cache
        for j = 1:lf.n
            # Reuse cache for the first iteration
            if j == 1
                @unpack value, gradient = z.ℓπ
            elseif j == 2 # cache intermediate values that depends on θ only (which are unchanged)
                retval, cache = ∂H∂θ_cache(h, θ_init, r_half; return_cache = true)
                @unpack value, gradient = retval
            else # reuse cache
                @unpack value, gradient = ∂H∂θ_cache(h, θ_init, r_half; cache = cache)
            end
            r_half = r_init - ϵ / 2 * gradient
        end
        # eq (17) of Girolami & Calderhead (2011)
        θ_full = θ_init
        term_1 = ∂H∂r(h, θ_init, r_half) # unchanged across the loop
        for j = 1:lf.n
            θ_full = θ_init + ϵ / 2 * (term_1 + ∂H∂r(h, θ_full, r_half))
        end
        # eq (18) of Girolami & Calderhead (2011)
        @unpack value, gradient = ∂H∂θ(h, θ_full, r_half)
        r_full = r_half - ϵ / 2 * gradient
        # Tempering
        #r = temper(lf, r, (i=i, is_half=false), n_steps)
        # Create a new phase point by caching the logdensity and gradient
        z = phasepoint(h, θ_full, r_full; ℓπ = DualValue(value, gradient))
        # Update result
        if FullTraj
            res[i] = z
        else
            res = z
        end
        if !isfinite(z)
            # Drop the slots that were never written. Exactly the first `i` entries
            # have been filled at this point; slicing by index (rather than probing
            # with `isassigned`) also works when `P` is a bits type, for which every
            # slot of an `undef` vector reports as assigned.
            if FullTraj
                res = res[1:i]
            end
            break
        end
    end
    return res
end

0 commit comments

Comments
 (0)