
Commit 419c7d9

Add riemannian manifold HMC
1 parent e3a56ed commit 419c7d9

File tree

Project.toml
src/AdvancedHMC.jl
src/riemannian/hamiltonian.jl
src/riemannian/metric.jl
test/riemannian.jl

5 files changed: +318 −2 lines changed

Project.toml

Lines changed: 6 additions & 0 deletions
@@ -6,6 +6,7 @@ version = "0.7.1"
 AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
 ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
 DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
+ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
 LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
@@ -15,6 +16,7 @@ Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
+VecTargets = "8a639fad-7908-4fe4-8003-906e9297f002"
 
 [weakdeps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -23,6 +25,9 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
 OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
 
+[sources]
+VecTargets = {rev = "main", url = "https://github.com/chalk-lab/VecTargets.jl"}
+
 [extensions]
 AdvancedHMCADTypesExt = "ADTypes"
 AdvancedHMCComponentArraysExt = "ComponentArrays"
@@ -37,6 +42,7 @@ ArgCheck = "1, 2"
 ComponentArrays = "0.15"
 CUDA = "3, 4, 5"
 DocStringExtensions = "0.8, 0.9"
+ForwardDiff = "0.10.38"
 LinearAlgebra = "<0.1, 1"
 LogDensityProblems = "2"
 LogDensityProblemsAD = "1"

src/AdvancedHMC.jl

Lines changed: 11 additions & 2 deletions
@@ -2,7 +2,7 @@ module AdvancedHMC
 
 using Statistics: mean, var, middle
 using LinearAlgebra:
-    Symmetric, UpperTriangular, mul!, ldiv!, dot, I, diag, cholesky, UniformScaling
+    Symmetric, UpperTriangular, mul!, ldiv!, dot, I, diag, diagm, cholesky, UniformScaling, logdet, tr
 using StatsFuns: logaddexp, logsumexp, loghalf
 using Random: Random, AbstractRNG
 using ProgressMeter: ProgressMeter
@@ -19,8 +19,12 @@ using LogDensityProblemsAD: LogDensityProblemsAD
 
 using AbstractMCMC: AbstractMCMC, LogDensityModel
 
+using VecTargets: VecTargets
+
 import StatsBase: sample
 
+using ForwardDiff: ForwardDiff
+
 const DEFAULT_FLOAT_TYPE = typeof(float(0))
 
 include("utilities.jl")
@@ -40,7 +44,7 @@ struct GaussianKinetic <: AbstractKinetic end
 export GaussianKinetic
 
 include("metric.jl")
-export UnitEuclideanMetric, DiagEuclideanMetric, DenseEuclideanMetric
+export UnitEuclideanMetric, DiagEuclideanMetric, DenseEuclideanMetric, DenseRiemannianMetric
 
 include("hamiltonian.jl")
 export Hamiltonian
@@ -50,6 +54,11 @@ export Leapfrog, JitteredLeapfrog, TemperedLeapfrog
 include("riemannian/integrator.jl")
 export GeneralizedLeapfrog
 
+include("riemannian/metric.jl")
+export IdentityMap, SoftAbsMap, DenseRiemannianMetric
+
+include("riemannian/hamiltonian.jl")
+
 include("trajectory.jl")
 export Trajectory,
     HMCKernel,

src/riemannian/hamiltonian.jl

Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
#! Eq (14) of Girolami & Calderhead (2011)
function ∂H∂r(
    h::Hamiltonian{<:DenseRiemannianMetric,<:GaussianKinetic}, θ::AbstractVecOrMat, r::AbstractVecOrMat
)
    H = h.metric.G(θ)
    G = h.metric.map(H)
    return G \ r # NOTE it's actually pretty weird that ∂H∂θ returns DualValue but ∂H∂r doesn't
end

function ∂H∂θ(
    h::Hamiltonian{<:DenseRiemannianMetric{T,<:IdentityMap},<:GaussianKinetic},
    θ::AbstractVecOrMat{T},
    r::AbstractVecOrMat{T},
) where {T}
    ℓπ, ∂ℓπ∂θ = h.∂ℓπ∂θ(θ)
    G = h.metric.map(h.metric.G(θ))
    invG = inv(G)
    ∂G∂θ = h.metric.∂G∂θ(θ)
    d = length(∂ℓπ∂θ)
    return DualValue(
        ℓπ,
        #! Eq (15) of Girolami & Calderhead (2011)
        -mapreduce(vcat, 1:d) do i
            ∂G∂θᵢ = ∂G∂θ[:, :, i]
            ∂ℓπ∂θ[i] - 1 / 2 * tr(invG * ∂G∂θᵢ) + 1 / 2 * r' * invG * ∂G∂θᵢ * invG * r
            # Gr = G \ r
            # ∂ℓπ∂θ[i] - 1 / 2 * tr(G \ ∂G∂θᵢ) + 1 / 2 * Gr' * ∂G∂θᵢ * Gr
            # 1 / 2 * tr(invG * ∂G∂θᵢ)
            # 1 / 2 * r' * invG * ∂G∂θᵢ * invG * r
        end,
    )
end
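Editorial aside (not part of the diff), for cross-checking against the paper: the quantities implemented by `∂H∂r`, the `∂H∂θ` method above, and `neg_energy` below are, in the notation of Girolami & Calderhead (2011),

\[
\begin{aligned}
H(\theta, r) &= -\log \pi(\theta) + \tfrac{1}{2} \log\!\big( (2\pi)^D \det G(\theta) \big) + \tfrac{1}{2} r^\top G(\theta)^{-1} r , && \text{(Eq. 13)} \\
\frac{\partial H}{\partial r} &= G(\theta)^{-1} r , && \text{(Eq. 14)} \\
\frac{\partial H}{\partial \theta_i} &= -\frac{\partial \log \pi(\theta)}{\partial \theta_i}
  + \tfrac{1}{2} \operatorname{tr}\!\Big( G^{-1} \frac{\partial G}{\partial \theta_i} \Big)
  - \tfrac{1}{2} r^\top G^{-1} \frac{\partial G}{\partial \theta_i} G^{-1} r . && \text{(Eq. 15)}
\end{aligned}
\]

The summand inside `-mapreduce(...)` is the bracketed term of Eq. (15) before negation, so the negated sum reproduces Eq. (15) up to the sign convention carried by `DualValue`.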

# Ref: https://www.wolframalpha.com/input?i=derivative+of+x+*+coth%28a+*+x%29
#! Based on the middle of the right column of Page 3 of Betancourt (2012): "Note that when λᵢ = λⱼ, such as for the diagonal elements or degenerate eigenvalues, this becomes the derivative"
dsoftabsdλ(α, λ) = coth(α * λ) + λ * α * -csch(λ * α)^2

#! J as defined in the middle of the right column of Page 3 of Betancourt (2012)
function make_J(λ::AbstractVector{T}, α::T) where {T<:AbstractFloat}
    d = length(λ)
    J = Matrix{T}(undef, d, d)
    for i in 1:d, j in 1:d
        J[i, j] = if (λ[i] == λ[j])
            dsoftabsdλ(α, λ[i])
        else
            ((λ[i] * coth(α * λ[i]) - λ[j] * coth(α * λ[j])) / (λ[i] - λ[j]))
        end
    end
    return J
end
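As a reading aid (also not in the diff), the SoftAbs eigenvalue map and the matrix built by `make_J` are, following Betancourt (2012),

\[
\tilde{\lambda}_i = \lambda_i \coth(\alpha \lambda_i),
\qquad
J_{ij} =
\begin{cases}
\dfrac{\lambda_i \coth(\alpha \lambda_i) - \lambda_j \coth(\alpha \lambda_j)}{\lambda_i - \lambda_j}, & \lambda_i \neq \lambda_j, \\[1ex]
\coth(\alpha \lambda_i) - \alpha \lambda_i \operatorname{csch}^2(\alpha \lambda_i), & \lambda_i = \lambda_j,
\end{cases}
\]

where the \(\lambda_i\) are the eigenvalues of the Hessian; the equal-eigenvalue branch is \(\mathrm{d}\big(\lambda \coth(\alpha\lambda)\big)/\mathrm{d}\lambda\), which is exactly what `dsoftabsdλ` computes.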

function ∂H∂θ(
    h::Hamiltonian{<:DenseRiemannianMetric{T,<:SoftAbsMap},<:GaussianKinetic},
    θ::AbstractVecOrMat{T},
    r::AbstractVecOrMat{T},
) where {T}
    return ∂H∂θ_cache(h, θ, r)
end
function ∂H∂θ_cache(
    h::Hamiltonian{<:DenseRiemannianMetric{T,<:SoftAbsMap},<:GaussianKinetic},
    θ::AbstractVecOrMat{T},
    r::AbstractVecOrMat{T};
    return_cache=false,
    cache=nothing,
) where {T}
    # Terms that depend only on θ can be cached across iterations in which θ is unchanged
    if isnothing(cache)
        ℓπ, ∂ℓπ∂θ = h.∂ℓπ∂θ(θ)
        H = h.metric.G(θ)
        ∂H∂θ = h.metric.∂G∂θ(θ)

        G, Q, λ, softabsλ = softabs(H, h.metric.map.α)

        R = diagm(1 ./ softabsλ)

        # softabsΛ = diagm(softabsλ)
        # M = inv(softabsΛ) * Q' * r
        # M = R * Q' * r # equiv to above but avoids inv

        J = make_J(λ, h.metric.map.α)

        #! Based on the two equations from the right column of Page 3 of Betancourt (2012)
        term_1_cached = Q * (R .* J) * Q'
    else
        ℓπ, ∂ℓπ∂θ, ∂H∂θ, Q, softabsλ, J, term_1_cached = cache
    end
    d = length(∂ℓπ∂θ)
    D = diagm((Q' * r) ./ softabsλ)
    term_2_cached = Q * D * J * D * Q'
    # NOTE `isdiag` appears intended as a Bool flag for a diagonal-∂H∂θ fast path;
    # the general `mapreduce` branch below is the one exercised in practice
    g =
        isdiag ?
        -(∂ℓπ∂θ - 1 / 2 * diag(term_1_cached * ∂H∂θ) + 1 / 2 * diag(term_2_cached * ∂H∂θ)) :
        -mapreduce(vcat, 1:d) do i
            ∂H∂θᵢ = ∂H∂θ[:, :, i]
            # ∂ℓπ∂θ[i] - 1 / 2 * tr(term_1_cached * ∂H∂θᵢ) + 1 / 2 * M' * (J .* (Q' * ∂H∂θᵢ * Q)) * M # (v1)
            # NOTE Some further optimization can be done here: cache the 1st product altogether
            ∂ℓπ∂θ[i] - 1 / 2 * tr(term_1_cached * ∂H∂θᵢ) + 1 / 2 * tr(term_2_cached * ∂H∂θᵢ) # (v2) cache friendly
        end

    dv = DualValue(ℓπ, g)
    return return_cache ? (dv, (; ℓπ, ∂ℓπ∂θ, ∂H∂θ, Q, softabsλ, J, term_1_cached)) : dv
end

# QUES Do we want to change everything to position dependent by default?
# Add θ to ∂H∂r for DenseRiemannianMetric
function phasepoint(
    h::Hamiltonian{<:DenseRiemannianMetric},
    θ::T,
    r::T;
    ℓπ=∂H∂θ(h, θ),
    ℓκ=DualValue(neg_energy(h, r, θ), ∂H∂r(h, θ, r)),
) where {T<:AbstractVecOrMat}
    return PhasePoint(θ, r, ℓπ, ℓκ)
end

#! Eq (13) of Girolami & Calderhead (2011)
function neg_energy(
    h::Hamiltonian{<:DenseRiemannianMetric,<:GaussianKinetic}, r::T, θ::T
) where {T<:AbstractVecOrMat}
    G = h.metric.map(h.metric.G(θ))
    D = size(G, 1)
    # Need to account for the normalizing term, as it is no longer the same for different θs
    logZ = 1 / 2 * (D * log(2π) + logdet(G)) # it is the user's responsibility to ensure G is SPD so that logdet(G) is defined
    mul!(h.metric._temp, inv(G), r)
    return -logZ - dot(r, h.metric._temp) / 2
end

src/riemannian/metric.jl

Lines changed: 112 additions & 0 deletions
@@ -0,0 +1,112 @@
abstract type AbstractRiemannianMetric <: AbstractMetric end

abstract type AbstractHessianMap end

struct IdentityMap <: AbstractHessianMap end

(::IdentityMap)(x) = x

struct SoftAbsMap{T} <: AbstractHessianMap
    α::T
end

function softabs(X, α=20.0)
    F = eigen(X) # ReverseDiff cannot diff through `eigen`
    Q = hcat(F.vectors)
    λ = F.values
    softabsλ = λ .* coth.(α * λ)
    return Q * diagm(softabsλ) * Q', Q, λ, softabsλ
end

(map::SoftAbsMap)(x) = softabs(x, map.α)[1]
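A quick standalone sanity check of the map (an editorial sketch, not part of the diff; `softabs_sketch` simply mirrors `softabs` above): an indefinite Hessian is sent to a symmetric positive-definite matrix whose eigenvalues are smoothed absolute values, λ coth(αλ) ≈ |λ|.

# Editorial sketch: the SoftAbs map regularizes an indefinite matrix into an SPD one.
using LinearAlgebra

function softabs_sketch(X, α=20.0)
    F = eigen(Symmetric(X))
    Q, λ = F.vectors, F.values
    return Q * diagm(λ .* coth.(α .* λ)) * Q'   # eigenvalues mapped to λ coth(αλ) > 0
end

H = [1.0 0.9; 0.9 -0.5]         # an indefinite "Hessian"
G = softabs_sketch(H, 20.0)
@assert isposdef(Symmetric(G))  # G is SPD, with eigenvalues ≈ |λ| for |αλ| ≫ 1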

# TODO Register softabs with ReverseDiff
#! The definition of SoftAbs from Page 3 of Betancourt (2012)
struct DenseRiemannianMetric{
    T,
    TM<:AbstractHessianMap,
    A<:Union{Tuple{Int},Tuple{Int,Int}},
    AV<:AbstractVecOrMat{T},
    TG,
    T∂G∂θ,
} <: AbstractRiemannianMetric
    size::A
    G::TG # TODO store G⁻¹ here instead
    ∂G∂θ::T∂G∂θ
    map::TM
    _temp::AV
end

# TODO Make dense mass matrix support matrix-mode parallel
function DenseRiemannianMetric(size, G, ∂G∂θ, map=IdentityMap())
    _temp = Vector{Float64}(undef, first(size))
    return DenseRiemannianMetric(size, G, ∂G∂θ, map, _temp)
end

# Convenient constructor
function DenseRiemannianMetric(size, ℓπ, initial_θ, λ, map=IdentityMap())
    _Hfunc = VecTargets.gen_hess(x -> -ℓπ(x), initial_θ) # x -> (value, gradient, hessian)
    Hfunc = x -> copy.(_Hfunc(x)) # _Hfunc computes in place; copy to avoid aliasing issues

    fstabilize = H -> H + λ * I
    Gfunc = x -> begin
        H = fstabilize(Hfunc(x)[3])
        all(isfinite, H) ? H : diagm(ones(length(x)))
    end
    _∂G∂θfunc = gen_∂G∂θ_fwd(x -> -ℓπ(x), initial_θ; f=fstabilize)
    ∂G∂θfunc = x -> reshape_∂G∂θ(_∂G∂θfunc(x))

    _temp = Vector{Float64}(undef, first(size))

    return DenseRiemannianMetric(size, Gfunc, ∂G∂θfunc, map, _temp)
end

function gen_hess_fwd(func, x::AbstractVector)
    function hess(x::AbstractVector)
        return nothing, nothing, ForwardDiff.hessian(func, x)
    end
    return hess
end

#= possibly integrate DI for an AD-backend-independent Fisher information metric
function gen_∂G∂θ_rev(Vfunc, x; f=identity)
    _Hfunc = VecTargets.gen_hess(Vfunc, ReverseDiff.track.(x))
    Hfunc = x -> _Hfunc(x)[3]
    # QUES What's the best output format of this function?
    return x -> ReverseDiff.jacobian(x -> f(Hfunc(x)), x) # default output shape [∂H∂x₁; ∂H∂x₂; ...]
end
=#

# Fisher information metric
function gen_∂G∂θ_fwd(Vfunc, x; f=identity)
    _Hfunc = gen_hess_fwd(Vfunc, x)
    Hfunc = x -> _Hfunc(x)[3]
    # QUES What's the best output format of this function?
    cfg = ForwardDiff.JacobianConfig(Hfunc, x)
    d = length(x)
    out = zeros(eltype(x), d^2, d)
    return x -> ForwardDiff.jacobian!(out, Hfunc, x, cfg) # default output shape [∂H∂x₁; ∂H∂x₂; ...]
end

function reshape_∂G∂θ(H)
    d = size(H, 2)
    return cat((H[((i - 1) * d + 1):(i * d), :] for i in 1:d)...; dims=3)
end

Base.size(e::DenseRiemannianMetric) = e.size
Base.size(e::DenseRiemannianMetric, dim::Int) = e.size[dim]
Base.show(io::IO, drm::DenseRiemannianMetric) = print(io, "DenseRiemannianMetric$(drm.size) with $(drm.map) metric")

function rand_momentum(
    rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}},
    metric::DenseRiemannianMetric{T},
    kinetic,
    θ::AbstractVecOrMat,
) where {T}
    r = _randn(rng, T, size(metric)...)
    G⁻¹ = inv(metric.map(metric.G(θ)))
    chol = cholesky(Symmetric(G⁻¹))
    ldiv!(chol.U, r)
    return r
end
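One detail of `rand_momentum` worth spelling out (editorial note, not in the diff): with the Gaussian kinetic energy the momentum must be drawn as r ∼ N(0, G(θ)). Writing U = `chol.U` for the upper Cholesky factor of G(θ)⁻¹, so that G(θ)⁻¹ = UᵀU, and z ∼ N(0, I) for the draw produced by `_randn`, the in-place `ldiv!(chol.U, r)` computes r = U⁻¹ z, and

\[
\operatorname{Cov}\big(U^{-1} z\big) = U^{-1} U^{-\top} = \big(U^\top U\big)^{-1} = \big(G(\theta)^{-1}\big)^{-1} = G(\theta),
\]

so the returned momentum has the required position-dependent covariance.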

test/riemannian.jl

Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
using ReTest, Random
using AdvancedHMC, ForwardDiff, AbstractMCMC
using LinearAlgebra

@testset "Multivariate Normal with Riemannian HMC" begin
    # Set the number of samples to draw
    n_samples = 2_000
    rng = MersenneTwister(1110)
    initial_θ = rand(rng, D)
    λ = 1e-2
    # Define a Hamiltonian system
    metric = DenseRiemannianMetric((D,), ℓπ, initial_θ, λ)
    kinetic = GaussianKinetic()
    hamiltonian = Hamiltonian(metric, kinetic, ℓπ, ∇ℓπ)

    # Define a generalized leapfrog integrator, with the initial step size chosen heuristically
    initial_ϵ = 0.01
    integrator = GeneralizedLeapfrog(initial_ϵ, 6)

    # Define an HMC kernel that runs the integrator for a fixed number of steps
    # and proposes the trajectory end point
    kernel = HMCKernel(Trajectory{EndPointTS}(integrator, FixedNSteps(8)))

    # Run the sampler to draw samples from the specified Gaussian, where
    # - `samples` will store the samples
    # - `stats` will store diagnostic statistics for each sample
    samples, stats = sample(
        rng, hamiltonian, kernel, initial_θ, n_samples; progress=true
    )
    @test length(samples) == n_samples
    @test length(stats) == n_samples
end

@testset "Multivariate Normal with Riemannian HMC (SoftAbs metric)" begin
    # Set the number of samples to draw
    n_samples = 2_000
    rng = MersenneTwister(1110)
    initial_θ = rand(rng, D)
    λ = 1e-2 # Hessian regularization (same value as in the first testset)

    # Define a Hamiltonian system
    metric = DenseRiemannianMetric((D,), ℓπ, initial_θ, λ, SoftAbsMap(20.0))
    kinetic = GaussianKinetic()
    hamiltonian = Hamiltonian(metric, kinetic, ℓπ, ∇ℓπ)

    # Define a generalized leapfrog integrator, with the initial step size chosen heuristically
    initial_ϵ = 0.01
    integrator = GeneralizedLeapfrog(initial_ϵ, 6)

    # Define an HMC kernel that runs the integrator for a fixed number of steps
    # and proposes the trajectory end point
    kernel = HMCKernel(Trajectory{EndPointTS}(integrator, FixedNSteps(8)))

    # Run the sampler to draw samples from the specified Gaussian, where
    # - `samples` will store the samples
    # - `stats` will store diagnostic statistics for each sample
    samples, stats = sample(
        rng, hamiltonian, kernel, initial_θ, n_samples; progress=true
    )
    @test length(samples) == n_samples
    @test length(stats) == n_samples
end
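Both testsets rely on `D`, `ℓπ`, and `∇ℓπ` being provided by the surrounding test setup, which is not part of this diff. A hypothetical stand-in (the names, the target distribution, and the assumption that the explicit-gradient `Hamiltonian` constructor takes a function returning a `(log density, gradient)` tuple are all editorial, not the repository's actual fixtures) that would make the snippets above self-contained:

# Hypothetical fixtures: a D-dimensional standard normal target.
using ForwardDiff

const D = 10
ℓπ(θ) = -sum(abs2, θ) / 2 - D * log(2π) / 2      # log density of N(0, I)
∇ℓπ(θ) = (ℓπ(θ), ForwardDiff.gradient(ℓπ, θ))    # assumed (value, gradient) convention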
