Skip to content

Commit 2dd0cea

Browse files
committed
Merge commit 'f400a07c6e3bd6e7a834e2ad451069594c30747e' into qqy/RMHMC
2 parents 0777d49 + f400a07 commit 2dd0cea

File tree

1 file changed

+358
-0
lines changed

1 file changed

+358
-0
lines changed

src/riemannian/riemannian_hmc.jl

Lines changed: 358 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,358 @@
1+
using Random
2+
3+
### integrator.jl
4+
5+
import AdvancedHMC: ∂H∂θ, ∂H∂r, DualValue, PhasePoint, phasepoint, step
6+
using AdvancedHMC: TYPEDEF, TYPEDFIELDS, AbstractScalarOrVec, AbstractLeapfrog, step_size
7+
8+
"""
9+
$(TYPEDEF)
10+
11+
Generalized leapfrog integrator with fixed step size `ϵ`.
12+
13+
# Fields
14+
15+
$(TYPEDFIELDS)
16+
"""
17+
struct GeneralizedLeapfrog{T<:AbstractScalarOrVec{<:AbstractFloat}} <: AbstractLeapfrog{T}
    "Step size."
    ϵ::T
    "Number of fixed-point iterations used for each implicit update in `step`."
    n::Int
end
22+
# One-line display; the step size is rounded to three significant digits.
function Base.show(io::IO, l::GeneralizedLeapfrog)
    ϵ_rounded = round.(l.ϵ; sigdigits=3)
    n_iterations = l.n
    return print(io, "GeneralizedLeapfrog(ϵ=$(ϵ_rounded), n=$(n_iterations))")
end
25+
26+
# Fallback to ignore return_cache & cache kwargs for other ∂H∂θ
27+
"""
    ∂H∂θ_cache(h, θ, r; return_cache=false, cache=nothing)

Generic fallback that forwards to `∂H∂θ(h, θ, r)`.

# Keywords
- `return_cache`: when `true`, return `(dv, nothing)` instead of `dv`, so callers can
  always destructure a `(value, cache)` pair.
- `cache`: accepted for interface compatibility and ignored by this fallback.
"""
function ∂H∂θ_cache(h, θ, r; return_cache=false, cache=nothing)
    dv = ∂H∂θ(h, θ, r)
    return return_cache ? (dv, nothing) : dv
end
31+
32+
# TODO Make sure vectorization works
33+
# TODO Check if tempering is valid
34+
"""
    step(lf::GeneralizedLeapfrog, h::Hamiltonian, z::PhasePoint, n_steps::Int=1;
         fwd::Bool=n_steps > 0, full_trajectory::Val=Val(false))

Simulate `abs(n_steps)` generalized leapfrog steps starting from `z`.

Each step performs, in order:
1. an implicit half step for the momentum (Eq 16 of Girolami & Calderhead, 2011),
   approximated by `lf.n` fixed-point iterations;
2. an implicit full step for the position (Eq 17), also `lf.n` fixed-point iterations;
3. an explicit half step for the momentum (Eq 18).

# Arguments
- `lf`: integrator providing the step size and the number of fixed-point iterations.
- `h`: Hamiltonian passed to `∂H∂θ`, `∂H∂θ_cache`, `∂H∂r` and `phasepoint`.
- `z`: initial phase point.
- `n_steps`: number of steps; only its absolute value is used as the step count.

# Keywords
- `fwd`: integrate with `+ϵ` when `true` and `-ϵ` when `false`.
- `full_trajectory`: `Val(true)` returns every visited phase point, `Val(false)` only
  the last one.

# Returns
The last phase point, or a `Vector` of phase points when `full_trajectory` is
`Val(true)`. Integration stops at the first non-finite phase point; in full-trajectory
mode the entries that were never assigned are dropped.
"""
function step(
    lf::GeneralizedLeapfrog{T},
    h::Hamiltonian,
    z::P,
    n_steps::Int=1;
    fwd::Bool=n_steps > 0, # simulate hamiltonian backward when n_steps < 0
    full_trajectory::Val{FullTraj}=Val(false),
) where {T<:AbstractScalarOrVec{<:AbstractFloat},P<:PhasePoint,FullTraj}
    n_steps = abs(n_steps) # to support `n_steps < 0` cases

    ϵ = fwd ? step_size(lf) : -step_size(lf)
    # A vector of step sizes becomes a row here; presumably one entry per column of
    # θ / r (vectorized mode) — see the vectorization TODO, not yet verified.
    ϵ = ϵ'

    res = if FullTraj
        Vector{P}(undef, n_steps)
    else
        z
    end

    for i in 1:n_steps
        θ_init, r_init = z.θ, z.r
        # TODO Tempering is currently disabled:
        # r = temper(lf, r, (i=i, is_half=true), n_steps)
        #! Eq (16) of Girolami & Calderhead (2011)
        r_half = copy(r_init)
        local cache
        for j in 1:(lf.n)
            if j == 1
                # Reuse the gradient cached in the phase point for the first iteration
                (; value, gradient) = z.ℓπ
            elseif j == 2
                # Cache intermediate values that depend on θ only (which is unchanged)
                retval, cache = ∂H∂θ_cache(h, θ_init, r_half; return_cache=true)
                (; value, gradient) = retval
            else
                # Reuse the cache built in the second iteration
                (; value, gradient) = ∂H∂θ_cache(h, θ_init, r_half; cache=cache)
            end
            # Broadcast so that a row of step sizes scales column-wise instead of being
            # treated as a matrix product; identical to `*` for a scalar step size.
            r_half = r_init - ϵ / 2 .* gradient
        end
        #! Eq (17) of Girolami & Calderhead (2011)
        θ_full = copy(θ_init)
        term_1 = ∂H∂r(h, θ_init, r_half) # unchanged across the loop
        for j in 1:(lf.n)
            θ_full = θ_init + ϵ / 2 .* (term_1 + ∂H∂r(h, θ_full, r_half))
        end
        #! Eq (18) of Girolami & Calderhead (2011)
        (; value, gradient) = ∂H∂θ(h, θ_full, r_half)
        r_full = r_half - ϵ / 2 .* gradient
        # TODO Tempering is currently disabled:
        # r = temper(lf, r, (i=i, is_half=false), n_steps)
        # Create a new phase point by caching the logdensity and gradient
        z = phasepoint(h, θ_full, r_full; ℓπ=DualValue(value, gradient))
        # Update result
        if FullTraj
            res[i] = z
        else
            res = z
        end
        if !isfinite(z)
            # Remove undef
            if FullTraj
                res = res[isassigned.(Ref(res), 1:n_steps)]
            end
            break
        end
    end
    return res
end
105+
106+
# TODO Make the order of θ and r consistent with neg_energy
107+
# Fallback for three-argument calls: `r` is dropped and the two-argument method is used.
function ∂H∂θ(h::Hamiltonian, θ::AbstractVecOrMat, r::AbstractVecOrMat)
    return ∂H∂θ(h, θ)
end
108+
# Fallback for three-argument calls: `θ` is dropped and the two-argument method is used.
function ∂H∂r(h::Hamiltonian, θ::AbstractVecOrMat, r::AbstractVecOrMat)
    return ∂H∂r(h, r)
end
109+
110+
### hamiltonian.jl
111+
112+
import AdvancedHMC: refresh, phasepoint
113+
using AdvancedHMC: FullMomentumRefreshment, PartialMomentumRefreshment, AbstractMetric
114+
115+
# To change L180 of hamiltonian.jl
116+
# Build a phase point at `θ` with a freshly drawn momentum; `θ` is forwarded to
# `rand_momentum` so that the draw can depend on the position.
function phasepoint(
    rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}},
    θ::AbstractVecOrMat{T},
    h::Hamiltonian,
) where {T<:Real}
    r = rand_momentum(rng, h.metric, h.kinetic, θ)
    return phasepoint(h, θ, r)
end
123+
124+
# To change L191 of hamiltonian.jl
125+
# Full refreshment: keep the position of `z` and replace its momentum with a new draw
# made at that position.
function refresh(
    rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}},
    ::FullMomentumRefreshment,
    h::Hamiltonian,
    z::PhasePoint,
)
    θ = z.θ
    r_new = rand_momentum(rng, h.metric, h.kinetic, θ)
    return phasepoint(h, θ, r_new)
end
133+
134+
# To change L215 of hamiltonian.jl
135+
# Partial refreshment: keep the position of `z` and mix its momentum with a new draw,
# weighting the old momentum by `ref.α` and the new one by `sqrt(1 - ref.α^2)`.
function refresh(
    rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}},
    ref::PartialMomentumRefreshment,
    h::Hamiltonian,
    z::PhasePoint,
)
    θ = z.θ
    r_mixed = ref.α * z.r + sqrt(1 - ref.α^2) * rand_momentum(rng, h.metric, h.kinetic, θ)
    return phasepoint(h, θ, r_mixed)
end
147+
148+
### metric.jl
149+
150+
import AdvancedHMC: _rand
151+
using AdvancedHMC: AbstractMetric
152+
using LinearAlgebra: eigen, cholesky, Symmetric
153+
154+
"""
Abstract supertype for Riemannian metrics. `DenseRiemannianMetric` is the concrete
subtype defined in this file.
"""
abstract type AbstractRiemannianMetric <: AbstractMetric end
155+
156+
"""
Abstract supertype for callable maps that a metric applies to the matrix returned by
its `G(θ)` function before that matrix is used as the metric tensor.
"""
abstract type AbstractHessianMap end

"Map that returns its argument unchanged."
struct IdentityMap <: AbstractHessianMap end

function (::IdentityMap)(x)
    return x
end

"Map that applies `softabs` with parameter `α` to its argument."
struct SoftAbsMap{T} <: AbstractHessianMap
    α::T
end
165+
166+
# TODO Register softabs with ReverseDiff
167+
#! The definition of SoftAbs from Page 3 of Betancourt (2012)
168+
"""
    softabs(X, α=20.0)

Apply the SoftAbs map to the matrix `X`: eigendecompose `X = Q * diagm(λ) * Q'` and
replace each eigenvalue `λᵢ` by `λᵢ * coth(α * λᵢ)`.

At `λᵢ == 0` the expression evaluates to `0 * Inf = NaN`, so the limit `1 / α` is used
there instead.

# Returns
A tuple `(G, Q, λ, softabsλ)`: the transformed matrix, the eigenvectors, the original
eigenvalues and the transformed eigenvalues.
"""
function softabs(X, α=20.0)
    F = eigen(X) # ReverseDiff cannot diff through `eigen`
    Q = hcat(F.vectors)
    λ = F.values
    softabsλ = map(λ) do λᵢ
        x = α * λᵢ
        # λ * coth(α * λ) → 1 / α as λ → 0
        iszero(x) ? one(x) / α : λᵢ * coth(x)
    end
    return Q * diagm(softabsλ) * Q', Q, λ, softabsλ
end
175+
176+
# Calling the map returns only the transformed matrix, i.e. the first element of the
# tuple produced by `softabs`.
function (map::SoftAbsMap)(x)
    return first(softabs(x, map.α))
end
177+
178+
"""
    DenseRiemannianMetric{T,TM,A,AV,TG,T∂G∂θ} <: AbstractRiemannianMetric

Position-dependent dense metric. The metric tensor at `θ` is obtained as
`map(G(θ))`.

# Fields
- `size`: shape used when drawing momenta, a 1- or 2-tuple of `Int`.
- `G`: callable evaluated as `G(θ)`; its result is passed through `map`.
- `∂G∂θ`: callable evaluated as `∂G∂θ(θ)`; its result is indexed as `[:, :, i]`.
- `map`: `AbstractHessianMap` applied to `G(θ)`.
- `_temp`: scratch buffer whose element type fixes `T`.
"""
struct DenseRiemannianMetric{
    T,
    TM<:AbstractHessianMap,
    A<:Union{Tuple{Int},Tuple{Int,Int}},
    AV<:AbstractVecOrMat{T},
    TG,
    T∂G∂θ,
} <: AbstractRiemannianMetric
    size::A
    G::TG # TODO store G⁻¹ here instead
    ∂G∂θ::T∂G∂θ
    map::TM
    _temp::AV
end
192+
193+
# TODO Make dense mass matrix support matrix-mode parallel
194+
"""
    DenseRiemannianMetric(size, G, ∂G∂θ, map=IdentityMap())

Construct a `DenseRiemannianMetric` with an uninitialized `Float64` scratch vector of
length `size[1]`.

# Arguments
- `size`: shape used when drawing momenta; only `size[1]` is used for the buffer.
- `G`: callable returning the matrix for a given `θ`.
- `∂G∂θ`: callable returning the derivatives of that matrix for a given `θ`.
- `map`: map applied to `G(θ)`; defaults to `IdentityMap()`.
"""
function DenseRiemannianMetric(size, G, ∂G∂θ, map=IdentityMap())
    _temp = Vector{Float64}(undef, size[1])
    return DenseRiemannianMetric(size, G, ∂G∂θ, map, _temp)
end
198+
# DenseEuclideanMetric(::Type{T}, D::Int) where {T} = DenseEuclideanMetric(Matrix{T}(I, D, D))
199+
# DenseEuclideanMetric(D::Int) = DenseEuclideanMetric(Float64, D)
200+
# DenseEuclideanMetric(::Type{T}, sz::Tuple{Int}) where {T} = DenseEuclideanMetric(Matrix{T}(I, first(sz), first(sz)))
201+
# DenseEuclideanMetric(sz::Tuple{Int}) = DenseEuclideanMetric(Float64, sz)
202+
203+
# renew(ue::DenseEuclideanMetric, M⁻¹) = DenseEuclideanMetric(M⁻¹)
204+
205+
# Return the stored size tuple.
function Base.size(e::DenseRiemannianMetric)
    return e.size
end
206+
# Return the entry of the stored size tuple for dimension `dim`.
function Base.size(e::DenseRiemannianMetric, dim::Int)
    return e.size[dim]
end
207+
# Fixed placeholder display; the fields are not printed.
function Base.show(io::IO, dem::DenseRiemannianMetric)
    return print(io, "DenseRiemannianMetric(...)")
end
208+
209+
"""
    rand_momentum(rng, metric::DenseRiemannianMetric, kinetic, θ)

Draw a momentum at position `θ`.

A standard normal array `z` of shape `size(metric)` is drawn and solved in place
against the upper Cholesky factor `U` of `inv(map(G(θ)))`. Since `inv(G) = U'U`, the
result `U \\ z` has covariance `G`. `kinetic` is not used by this method.
"""
function rand_momentum(
    rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}},
    metric::DenseRiemannianMetric{T},
    kinetic,
    θ::AbstractVecOrMat,
) where {T}
    z = _randn(rng, T, size(metric)...)
    G = metric.map(metric.G(θ))
    U = cholesky(Symmetric(inv(G))).U
    ldiv!(U, z) # overwrites `z` with `U \ z`
    return z
end
221+
222+
### hamiltonian.jl
223+
224+
import AdvancedHMC: phasepoint, neg_energy, ∂H∂θ, ∂H∂r
225+
using LinearAlgebra: logabsdet, tr
226+
227+
# QUES Do we want to change everything to position dependent by default?
228+
# Add θ to ∂H∂r for DenseRiemannianMetric
229+
"""
    phasepoint(h::Hamiltonian{<:DenseRiemannianMetric}, θ, r; ℓπ, ℓκ)

Build a `PhasePoint` for a Riemannian Hamiltonian.

Unlike the position-independent case, the default kinetic term is evaluated with the
position as well: `neg_energy(h, r, θ)` for the value and `∂H∂r(h, θ, r)` for the
gradient. `ℓπ` defaults to `∂H∂θ(h, θ)`.
"""
function phasepoint(
    h::Hamiltonian{<:DenseRiemannianMetric},
    θ::T,
    r::T;
    ℓπ=∂H∂θ(h, θ),
    ℓκ=DualValue(neg_energy(h, r, θ), ∂H∂r(h, θ, r)),
) where {T<:AbstractVecOrMat}
    return PhasePoint(θ, r, ℓπ, ℓκ)
end
238+
239+
# Negative kinetic energy
240+
#! Eq (13) of Girolami & Calderhead (2011)
241+
"""
    neg_energy(h::Hamiltonian{<:DenseRiemannianMetric}, r, θ)

Negative kinetic energy at momentum `r` and position `θ`:

    -(D * log(2π) + logdet(G)) / 2 - r' * (G \\ r) / 2

with `G = map(G(θ))` and `D = size(G, 1)`.

The normalizing term is kept because it changes with `θ`. It is the user's
responsibility to make sure `G` is SPD so that `logdet(G)` is defined.

The quadratic form is computed with a linear solve rather than by forming `inv(G)` and
writing into the metric's shared scratch buffer, so this function does not mutate `h`
and matches how `∂H∂r` computes `G \\ r`.
"""
function neg_energy(
    h::Hamiltonian{<:DenseRiemannianMetric}, r::T, θ::T
) where {T<:AbstractVecOrMat}
    G = h.metric.map(h.metric.G(θ))
    D = size(G, 1)
    # Need to consider the normalizing term as it is no longer same for different θs
    logZ = 1 / 2 * (D * log(2π) + logdet(G))
    return -logZ - dot(r, G \ r) / 2
end
251+
252+
# QUES L31 of hamiltonian.jl now reads a bit weird (semantically)
253+
"""
    ∂H∂θ(h::Hamiltonian{<:DenseRiemannianMetric{T,<:IdentityMap}}, θ, r)

Log density and gradient of the Hamiltonian with respect to `θ` for a Riemannian
metric used with `IdentityMap`.

# Returns
`DualValue(ℓπ, g)`, where `ℓπ` is the log density returned by `h.∂ℓπ∂θ(θ)` and `g[i]`
is the negative of

    ∂ℓπ∂θ[i] - tr(G⁻¹ ∂G∂θᵢ) / 2 + r' G⁻¹ ∂G∂θᵢ G⁻¹ r / 2

The gradient is always built as a vector with one entry per component of `∂ℓπ∂θ`,
including when there is only one component.
"""
function ∂H∂θ(
    h::Hamiltonian{<:DenseRiemannianMetric{T,<:IdentityMap}},
    θ::AbstractVecOrMat{T},
    r::AbstractVecOrMat{T},
) where {T}
    ℓπ, ∂ℓπ∂θ = h.∂ℓπ∂θ(θ)
    G = h.metric.map(h.metric.G(θ))
    invG = inv(G)
    ∂G∂θ = h.metric.∂G∂θ(θ)
    d = length(∂ℓπ∂θ)
    #! Eq (15) of Girolami & Calderhead (2011)
    # `map` (rather than `mapreduce(vcat, ...)`) keeps the result a vector for d == 1
    # and does not throw for d == 0.
    neg_gradient = map(1:d) do i
        ∂G∂θᵢ = ∂G∂θ[:, :, i]
        ∂ℓπ∂θ[i] - 1 / 2 * tr(invG * ∂G∂θᵢ) + 1 / 2 * r' * invG * ∂G∂θᵢ * invG * r
    end
    return DualValue(ℓπ, -neg_gradient)
end
276+
277+
# Ref: https://www.wolframalpha.com/input?i=derivative+of+x+*+coth%28a+*+x%29
278+
#! Based on middle of the right column of Page 3 of Betancourt (2012) "Note that when λi = λj, such as for the diagonal elements or degenerate eigenvalues, this becomes the derivative"
279+
"""
    dsoftabsdλ(α, λ)

Derivative of `λ * coth(α * λ)` with respect to `λ`, i.e.
`coth(α * λ) - α * λ * csch(α * λ)^2`.

At `λ == 0` the closed form evaluates to `Inf - NaN`; the limit `0` is returned
instead.
"""
function dsoftabsdλ(α, λ)
    x = α * λ
    iszero(x) && return zero(x)
    return coth(x) - x * csch(x)^2
end

#! J as defined in middle of the right column of Page 3 of Betancourt (2012)
"""
    make_J(λ::AbstractVector{T}, α::T) where {T<:AbstractFloat}

Build the `d × d` matrix `J` with

- `J[i, j] = dsoftabsdλ(α, λ[i])` when `λ[i] == λ[j]`, and
- `J[i, j] = (s(λ[i]) - s(λ[j])) / (λ[i] - λ[j])` otherwise,

where `s(λ) = λ * coth(α * λ)`, taken as its limit `1 / α` at `λ == 0`.
"""
function make_J(λ::AbstractVector{T}, α::T) where {T<:AbstractFloat}
    d = length(λ)
    # SoftAbs-transformed eigenvalues, computed once instead of inside the double loop.
    softabsλ = map(λ) do λᵢ
        x = α * λᵢ
        iszero(x) ? inv(α) : λᵢ * coth(x)
    end
    J = Matrix{T}(undef, d, d)
    for j in 1:d, i in 1:d
        J[i, j] = if λ[i] == λ[j]
            dsoftabsdλ(α, λ[i])
        else
            (softabsλ[i] - softabsλ[j]) / (λ[i] - λ[j])
        end
    end
    return J
end
294+
295+
# SoftAbs variant: delegate to the cache-aware implementation, without a cache and
# without requesting one back.
function ∂H∂θ(
    h::Hamiltonian{<:DenseRiemannianMetric{T,<:SoftAbsMap}},
    θ::AbstractVecOrMat{T},
    r::AbstractVecOrMat{T},
) where {T}
    dv = ∂H∂θ_cache(h, θ, r; return_cache=false, cache=nothing)
    return dv
end
302+
"""
    ∂H∂θ_cache(h::Hamiltonian{<:DenseRiemannianMetric{T,<:SoftAbsMap}}, θ, r;
               return_cache=false, cache=nothing)

Log density and gradient of the Hamiltonian with respect to `θ` for a SoftAbs metric,
with optional reuse of the quantities that depend on `θ` only.

# Keywords
- `return_cache`: when `true`, return `(dv, cache)` where `cache` is the named tuple
  `(; ℓπ, ∂ℓπ∂θ, ∂H∂θ, Q, softabsλ, J, term_1_cached)`.
- `cache`: a cache returned by an earlier call at the same `θ`. When given, the
  log density, eigendecomposition and `J` are not recomputed.

# Returns
`DualValue(ℓπ, g)`, or `(DualValue(ℓπ, g), cache)` when `return_cache` is `true`. `g`
is always built as a vector with one entry per component of `∂ℓπ∂θ`, including when
there is only one component.
"""
function ∂H∂θ_cache(
    h::Hamiltonian{<:DenseRiemannianMetric{T,<:SoftAbsMap}},
    θ::AbstractVecOrMat{T},
    r::AbstractVecOrMat{T};
    return_cache=false,
    cache=nothing,
) where {T}
    # Terms that only depend on θ can be cached in θ-unchanged loops
    if isnothing(cache)
        ℓπ, ∂ℓπ∂θ = h.∂ℓπ∂θ(θ)
        H = h.metric.G(θ)
        ∂H∂θ = h.metric.∂G∂θ(θ)

        # The transformed matrix itself is not needed here, only its eigen-structure.
        _, Q, λ, softabsλ = softabs(H, h.metric.map.α)

        R = diagm(1 ./ softabsλ)

        J = make_J(λ, h.metric.map.α)

        #! Based on the two equations from the right column of Page 3 of Betancourt (2012)
        term_1_cached = Q * (R .* J) * Q'
    else
        # Positional destructuring; the order matches the named tuple returned below.
        ℓπ, ∂ℓπ∂θ, ∂H∂θ, Q, softabsλ, J, term_1_cached = cache
    end
    d = length(∂ℓπ∂θ)
    D = diagm((Q' * r) ./ softabsλ)
    term_2_cached = Q * D * J * D * Q'
    # `map` (rather than `mapreduce(vcat, ...)`) keeps the result a vector for d == 1
    # and does not throw for d == 0.
    neg_gradient = map(1:d) do i
        ∂H∂θᵢ = ∂H∂θ[:, :, i]
        # NOTE Some further optimization can be done here: cache the 1st product all together
        ∂ℓπ∂θ[i] - 1 / 2 * tr(term_1_cached * ∂H∂θᵢ) + 1 / 2 * tr(term_2_cached * ∂H∂θᵢ) # cache friendly
    end
    g = -neg_gradient

    dv = DualValue(ℓπ, g)
    return return_cache ? (dv, (; ℓπ, ∂ℓπ∂θ, ∂H∂θ, Q, softabsλ, J, term_1_cached)) : dv
end
344+
345+
#! Eq (14) of Girolami & Calderhead (2011)
346+
"""
    ∂H∂r(h::Hamiltonian{<:DenseRiemannianMetric}, θ, r)

Gradient of the Hamiltonian with respect to the momentum, `G \\ r` with
`G = map(G(θ))`.

Returns a plain array, not a `DualValue` (unlike `∂H∂θ`).
"""
function ∂H∂r(
    h::Hamiltonian{<:DenseRiemannianMetric}, θ::AbstractVecOrMat, r::AbstractVecOrMat
)
    metric = h.metric
    G = metric.map(metric.G(θ))
    return G \ r
end

0 commit comments

Comments
 (0)