Skip to content
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ This modularity means that different HMC variants can be easily constructed by c
- Unit metric: `UnitEuclideanMetric(dim)`
- Diagonal metric: `DiagEuclideanMetric(dim)`
- Dense metric: `DenseEuclideanMetric(dim)`
- Dense Riemannian metric: `DenseRiemannianMetric(size, G, ∂G∂θ)`

where `dim` is the dimensionality of the sampling space.

Expand Down
21 changes: 19 additions & 2 deletions src/AdvancedHMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,19 @@ module AdvancedHMC

using Statistics: mean, var, middle
using LinearAlgebra:
Symmetric, UpperTriangular, mul!, ldiv!, dot, I, diag, cholesky, UniformScaling
Symmetric,
UpperTriangular,
mul!,
ldiv!,
dot,
I,
diag,
cholesky,
UniformScaling,
logdet,
tr,
eigen,
diagm
using StatsFuns: logaddexp, logsumexp, loghalf
using Random: Random, AbstractRNG
using ProgressMeter: ProgressMeter
Expand Down Expand Up @@ -40,7 +52,7 @@ struct GaussianKinetic <: AbstractKinetic end
export GaussianKinetic

include("metric.jl")
export UnitEuclideanMetric, DiagEuclideanMetric, DenseEuclideanMetric
export UnitEuclideanMetric, DiagEuclideanMetric, DenseEuclideanMetric, DenseRiemannianMetric

include("hamiltonian.jl")
export Hamiltonian
Expand All @@ -50,6 +62,11 @@ export Leapfrog, JitteredLeapfrog, TemperedLeapfrog
include("riemannian/integrator.jl")
export GeneralizedLeapfrog

include("riemannian/metric.jl")
export IdentityMap, SoftAbsMap, DenseRiemannianMetric

include("riemannian/hamiltonian.jl")

include("trajectory.jl")
export Trajectory,
HMCKernel,
Expand Down
298 changes: 33 additions & 265 deletions src/riemannian/hamiltonian.jl
Original file line number Diff line number Diff line change
@@ -1,257 +1,16 @@
using Random

### integrator.jl

import AdvancedHMC: ∂H∂θ, ∂H∂r, DualValue, PhasePoint, phasepoint, step
using AdvancedHMC: TYPEDEF, TYPEDFIELDS, AbstractScalarOrVec, AbstractLeapfrog, step_size

"""
$(TYPEDEF)

Generalized leapfrog integrator with fixed step size `ϵ`.

# Fields

$(TYPEDFIELDS)
"""
struct GeneralizedLeapfrog{T<:AbstractScalarOrVec{<:AbstractFloat}} <: AbstractLeapfrog{T}
    "Step size."
    ϵ::T
    "Number of fixed-point iterations used to approximate each implicit update."
    n::Int
end
# One-line display: the step size rounded to three significant digits, followed by `n`.
function Base.show(io::IO, l::GeneralizedLeapfrog)
    ϵ_rounded = round.(l.ϵ; sigdigits=3)
    print(io, "GeneralizedLeapfrog(ϵ=", ϵ_rounded, ", n=", l.n, ")")
    return nothing
end

# Fallback to ignore return_cache & cache kwargs for other ∂H∂θ
#
# Evaluates `∂H∂θ(h, θ, r)` without any caching. `cache` is accepted only so callers can
# use a uniform call signature; it is never read. With `return_cache=true` the result is
# the tuple `(dv, nothing)`, otherwise just `dv`.
function ∂H∂θ_cache(h, θ, r; return_cache=false, cache=nothing)
    dv = ∂H∂θ(h, θ, r)
    return return_cache ? (dv, nothing) : dv
end

# TODO Make sure vectorization works
# TODO Check if tempering is valid
#
# Simulate `abs(n_steps)` generalized leapfrog steps starting from the phase point `z`.
#
# Every step applies the three updates of Eqs (16)-(18) of Girolami & Calderhead (2011):
# a half-step for the momentum and a full step for the position, each approximated by
# `lf.n` fixed-point iterations, followed by an explicit half-step for the momentum.
#
# Arguments
# - `lf`: integrator providing the step size (through `step_size`) and the iteration count `n`.
# - `h`: Hamiltonian handed to `∂H∂θ`, `∂H∂θ_cache`, `∂H∂r` and `phasepoint`.
# - `z`: starting phase point; its cached `z.ℓπ` seeds the first momentum iteration.
# - `n_steps`: number of steps; only its absolute value drives the loop.
# - `fwd`: when `false` the sign of the step size is flipped.
# - `full_trajectory`: `Val(true)` collects every visited phase point.
#
# Returns the last phase point, or (with `Val(true)`) a `Vector` of the visited phase points.
# The loop stops at the first phase point for which `isfinite` is false; that point is
# still part of the result.
function step(
lf::GeneralizedLeapfrog{T},
h::Hamiltonian,
z::P,
n_steps::Int=1;
fwd::Bool=n_steps > 0, # simulate hamiltonian backward when n_steps < 0
full_trajectory::Val{FullTraj}=Val(false),
) where {T<:AbstractScalarOrVec{<:AbstractFloat},P<:PhasePoint,FullTraj}
n_steps = abs(n_steps) # to support `n_steps < 0` cases

ϵ = fwd ? step_size(lf) : -step_size(lf)
# NOTE(review): the adjoint presumably turns a vector of step sizes into a row so that it
# broadcasts across columns in matrix mode — confirm.
ϵ = ϵ'

# With a full trajectory, `res` is a preallocated vector filled step by step; otherwise
# it just tracks the most recent phase point.
res = if FullTraj
Vector{P}(undef, n_steps)
else
z
end

for i in 1:n_steps
θ_init, r_init = z.θ, z.r
# Tempering
#r = temper(lf, r, (i=i, is_half=true), n_steps)
#! Eq (16) of Girolami & Calderhead (2011)
# Fixed-point iteration for the momentum half-step: r_half appears on both sides.
r_half = copy(r_init)
# Declared here so the cache produced at j == 2 is visible to later iterations.
local cache
for j in 1:(lf.n)
# Reuse cache for the first iteration
if j == 1
(; value, gradient) = z.ℓπ
elseif j == 2 # cache intermediate values that depend on θ only (which are unchanged)
retval, cache = ∂H∂θ_cache(h, θ_init, r_half; return_cache=true)
(; value, gradient) = retval
else # reuse cache
(; value, gradient) = ∂H∂θ_cache(h, θ_init, r_half; cache=cache)
end
r_half = r_init - ϵ / 2 * gradient
end
#! Eq (17) of Girolami & Calderhead (2011)
# Fixed-point iteration for the position full step: θ_full appears on both sides.
θ_full = copy(θ_init)
term_1 = ∂H∂r(h, θ_init, r_half) # unchanged across the loop
for j in 1:(lf.n)
θ_full = θ_init + ϵ / 2 * (term_1 + ∂H∂r(h, θ_full, r_half))
end
#! Eq (18) of Girolami & Calderhead (2011)
# Explicit momentum half-step at the new position.
(; value, gradient) = ∂H∂θ(h, θ_full, r_half)
r_full = r_half - ϵ / 2 * gradient
# Tempering
#r = temper(lf, r, (i=i, is_half=false), n_steps)
# Create a new phase point by caching the logdensity and gradient
z = phasepoint(h, θ_full, r_full; ℓπ=DualValue(value, gradient))
# Update result
if FullTraj
res[i] = z
else
res = z
end
# Stop integrating as soon as the phase point is no longer finite.
if !isfinite(z)
# Remove undef
if FullTraj
res = res[isassigned.(Ref(res), 1:n_steps)]
end
break
end
end
return res
end

# TODO Make the order of θ and r consistent with neg_energy
#
# Generic three-argument fallbacks: the gradient with respect to `θ` ignores `r`, and
# the gradient with respect to `r` ignores `θ`; each forwards to its two-argument method.
function ∂H∂θ(h::Hamiltonian, θ::AbstractVecOrMat, r::AbstractVecOrMat)
    return ∂H∂θ(h, θ)
end

function ∂H∂r(h::Hamiltonian, θ::AbstractVecOrMat, r::AbstractVecOrMat)
    return ∂H∂r(h, r)
end

### hamiltonian.jl

import AdvancedHMC: refresh, phasepoint
using AdvancedHMC: FullMomentumRefreshment, PartialMomentumRefreshment, AbstractMetric

# To change L180 of hamiltonian.jl
#
# Build a phase point at position `θ` with a momentum freshly drawn by `rand_momentum`,
# which also receives `θ`.
function phasepoint(
    rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}},
    θ::AbstractVecOrMat{T},
    h::Hamiltonian,
) where {T<:Real}
    r = rand_momentum(rng, h.metric, h.kinetic, θ)
    return phasepoint(h, θ, r)
end

# To change L191 of hamiltonian.jl
#
# Full refreshment: keep the position of `z` and replace its momentum with a new draw
# from `rand_momentum` evaluated at that position.
function refresh(
    rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}},
    ::FullMomentumRefreshment,
    h::Hamiltonian,
    z::PhasePoint,
)
    θ = z.θ
    r_new = rand_momentum(rng, h.metric, h.kinetic, θ)
    return phasepoint(h, θ, r_new)
end

# To change L215 of hamiltonian.jl
#
# Partial refreshment: keep the position of `z` and mix its current momentum with a new
# draw, weighting them by `ref.α` and `sqrt(1 - ref.α^2)` respectively.
function refresh(
    rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}},
    ref::PartialMomentumRefreshment,
    h::Hamiltonian,
    z::PhasePoint,
)
    α = ref.α
    r_fresh = rand_momentum(rng, h.metric, h.kinetic, z.θ)
    r_mixed = α * z.r + sqrt(1 - α^2) * r_fresh
    return phasepoint(h, z.θ, r_mixed)
end

### metric.jl

import AdvancedHMC: _rand
using AdvancedHMC: AbstractMetric
using LinearAlgebra: eigen, cholesky, Symmetric

# Supertype of metrics whose matrix is evaluated at the current position.
abstract type AbstractRiemannianMetric <: AbstractMetric end

# Supertype of callable maps applied to the matrix returned by a metric's `G(θ)`.
abstract type AbstractHessianMap end

# Map that leaves its argument unchanged.
struct IdentityMap <: AbstractHessianMap end

(::IdentityMap)(x) = x

# Map parameterised by `α`; calling it forwards to `softabs(x, α)` and keeps the first
# returned element.
struct SoftAbsMap{T} <: AbstractHessianMap
α::T
end

# TODO Register softabs with ReverseDiff
#! The definition of SoftAbs from Page 3 of Betancourt (2012)
"""
    softabs(X, α=20.0)

Apply the SoftAbs map to the matrix `X`: every eigenvalue `λ` of `X` is replaced by
`λ * coth(α * λ)` while the eigenvectors are kept.

A zero eigenvalue is mapped to the analytic limit `1 / α`; evaluating the formula
directly there would give `0 * Inf = NaN` and contaminate the whole result.

# Returns
A tuple `(M, Q, λ, softabsλ)`: the transformed matrix `Q * diagm(softabsλ) * Q'`, the
eigenvectors `Q`, the eigenvalues `λ` of `X`, and the transformed eigenvalues.
"""
function softabs(X, α=20.0)
    F = eigen(X) # ReverseDiff cannot diff through `eigen`
    Q = F.vectors
    λ = F.values
    softabsλ = map(λ) do λᵢ
        x = α * λᵢ
        # λ coth(αλ) → 1/α as λ → 0; `one(x) / α` keeps the element type of the other branch
        iszero(x) ? one(x) / α : λᵢ * coth(x)
    end
    return Q * diagm(softabsλ) * Q', Q, λ, softabsλ
end

# Calling a `SoftAbsMap` keeps only the transformed matrix, the first element of the
# tuple returned by `softabs`.
function (map::SoftAbsMap)(x)
    transformed, _ = softabs(x, map.α)
    return transformed
end

# Dense metric whose matrix is obtained by evaluating `G` at the current position and
# passing the result through `map`.
struct DenseRiemannianMetric{
T,
TM<:AbstractHessianMap,
A<:Union{Tuple{Int},Tuple{Int,Int}},
AV<:AbstractVecOrMat{T},
TG,
T∂G∂θ,
} <: AbstractRiemannianMetric
# Dimensions reported by `Base.size(metric)`.
size::A
# Callable evaluated at a position, `G(θ)`.
G::TG # TODO store G⁻¹ here instead
# NOTE(review): presumably the derivative of `G` with respect to `θ`; its use is not
# visible in this part of the file — confirm.
∂G∂θ::T∂G∂θ
# Map applied to the result of `G(θ)`.
map::TM
# Scratch buffer; `neg_energy` writes `inv(G) * r` into it with `mul!`.
_temp::AV
end

# TODO Make dense mass matrix support matrix-mode parallel
#
# Convenience constructor: allocates the internal scratch vector (length `first(size)`)
# and defaults `map` to `IdentityMap()`.
function DenseRiemannianMetric(size, G, ∂G∂θ, map=IdentityMap())
    # NOTE(review): the buffer element type is fixed to Float64, which also fixes the
    # metric's `T` parameter — confirm this is intended for other float types.
    _temp = Vector{Float64}(undef, first(size))
    return DenseRiemannianMetric(size, G, ∂G∂θ, map, _temp)
end
# DenseEuclideanMetric(::Type{T}, D::Int) where {T} = DenseEuclideanMetric(Matrix{T}(I, D, D))
# DenseEuclideanMetric(D::Int) = DenseEuclideanMetric(Float64, D)
# DenseEuclideanMetric(::Type{T}, sz::Tuple{Int}) where {T} = DenseEuclideanMetric(Matrix{T}(I, first(sz), first(sz)))
# DenseEuclideanMetric(sz::Tuple{Int}) = DenseEuclideanMetric(Float64, sz)

# renew(ue::DenseEuclideanMetric, M⁻¹) = DenseEuclideanMetric(M⁻¹)

# The stored `size` tuple is the metric's size.
function Base.size(e::DenseRiemannianMetric)
    return e.size
end

# Size along a single dimension, looked up in the stored tuple.
function Base.size(e::DenseRiemannianMetric, dim::Int)
    return e.size[dim]
end

# Fixed placeholder display; no field is printed.
function Base.show(io::IO, dem::DenseRiemannianMetric)
    return print(io, "DenseRiemannianMetric(...)")
end

function rand_momentum(
rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}},
metric::DenseRiemannianMetric{T},
kinetic,
#! Eq (14) of Girolami & Calderhead (2011)
function ∂H∂r(

Check warning on line 2 in src/riemannian/hamiltonian.jl

View check run for this annotation

Codecov / codecov/patch

src/riemannian/hamiltonian.jl#L2

Added line #L2 was not covered by tests
h::Hamiltonian{<:DenseRiemannianMetric,<:GaussianKinetic},

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AbstractRiemannianMetric?

This logic is not unique to the dense metric.

θ::AbstractVecOrMat,
) where {T}
r = _randn(rng, T, size(metric)...)
G⁻¹ = inv(metric.map(metric.G(θ)))
chol = cholesky(Symmetric(G⁻¹))
ldiv!(chol.U, r)
return r
end

### hamiltonian.jl

import AdvancedHMC: phasepoint, neg_energy, ∂H∂θ, ∂H∂r
using LinearAlgebra: logabsdet, tr

# QUES Do we want to change everything to position dependent by default?
# Add θ to ∂H∂r for DenseRiemannianMetric
#
# Phase point constructor for Hamiltonians with a `DenseRiemannianMetric`: unless supplied,
# the kinetic part `ℓκ` is evaluated with both `r` and `θ`.
function phasepoint(
    h::Hamiltonian{<:DenseRiemannianMetric},
    θ::T,
    r::T;
    ℓπ=∂H∂θ(h, θ),
    ℓκ=DualValue(neg_energy(h, r, θ), ∂H∂r(h, θ, r)),
) where {T<:AbstractVecOrMat}
    point = PhasePoint(θ, r, ℓπ, ℓκ)
    return point
end

# Negative kinetic energy
#! Eq (13) of Girolami & Calderhead (2011)
function neg_energy(

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this correct? Eq (13) of Girolami & Calderhead (2011) also contains the log-likelihood term, -L(θ). Why is this not included here?

Even if this is correct, we should clarify the naming conventions as it's quite hard to follow.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree - created an issue here: #483

h::Hamiltonian{<:DenseRiemannianMetric}, r::T, θ::T
) where {T<:AbstractVecOrMat}
G = h.metric.map(h.metric.G(θ))
D = size(G, 1)
# Need to consider the normalizing term as it is no longer same for different θs
logZ = 1 / 2 * (D * log(2π) + logdet(G)) # it will be user's responsibility to make sure G is SPD and logdet(G) is defined
mul!(h.metric._temp, inv(G), r)
return -logZ - dot(r, h.metric._temp) / 2
r::AbstractVecOrMat,
)
H = h.metric.G(θ)
G = h.metric.map(H)
return G \ r # NOTE it's actually pretty weird that ∂H∂θ returns DualValue but ∂H∂r doesn't

Check warning on line 9 in src/riemannian/hamiltonian.jl

View check run for this annotation

Codecov / codecov/patch

src/riemannian/hamiltonian.jl#L7-L9

Added lines #L7 - L9 were not covered by tests
end

# QUES L31 of hamiltonian.jl now reads a bit weird (semantically)
function ∂H∂θ(
h::Hamiltonian{<:DenseRiemannianMetric{T,<:IdentityMap}},
h::Hamiltonian{<:DenseRiemannianMetric{T,<:IdentityMap},<:GaussianKinetic},
θ::AbstractVecOrMat{T},
r::AbstractVecOrMat{T},
) where {T}
Expand Down Expand Up @@ -293,14 +52,14 @@
end

function ∂H∂θ(
h::Hamiltonian{<:DenseRiemannianMetric{T,<:SoftAbsMap}},
h::Hamiltonian{<:DenseRiemannianMetric{T,<:SoftAbsMap},<:GaussianKinetic},
θ::AbstractVecOrMat{T},
r::AbstractVecOrMat{T},
) where {T}
return ∂H∂θ_cache(h, θ, r)
end
function ∂H∂θ_cache(
h::Hamiltonian{<:DenseRiemannianMetric{T,<:SoftAbsMap}},
h::Hamiltonian{<:DenseRiemannianMetric{T,<:SoftAbsMap},<:GaussianKinetic},
θ::AbstractVecOrMat{T},
r::AbstractVecOrMat{T};
return_cache=false,
Expand Down Expand Up @@ -342,17 +101,26 @@
return return_cache ? (dv, (; ℓπ, ∂ℓπ∂θ, ∂H∂θ, Q, softabsλ, J, term_1_cached)) : dv
end

#! Eq (14) of Girolami & Calderhead (2011)
function ∂H∂r(
h::Hamiltonian{<:DenseRiemannianMetric}, θ::AbstractVecOrMat, r::AbstractVecOrMat
)
H = h.metric.G(θ)
# if !all(isfinite, H)
# println("θ: ", θ)
# println("H: ", H)
# end
G = h.metric.map(H)
# return inv(G) * r
# println("G \ r: ", G \ r)
return G \ r # NOTE it's actually pretty weird that ∂H∂θ returns DualValue but ∂H∂r doesn't
# QUES Do we want to change everything to position dependent by default?
# Add θ to ∂H∂r for DenseRiemannianMetric
function phasepoint(

Check warning on line 106 in src/riemannian/hamiltonian.jl

View check run for this annotation

Codecov / codecov/patch

src/riemannian/hamiltonian.jl#L106

Added line #L106 was not covered by tests
h::Hamiltonian{<:DenseRiemannianMetric},
θ::T,
r::T;
ℓπ=∂H∂θ(h, θ),
ℓκ=DualValue(neg_energy(h, r, θ), ∂H∂r(h, θ, r)),
) where {T<:AbstractVecOrMat}
return PhasePoint(θ, r, ℓπ, ℓκ)

Check warning on line 113 in src/riemannian/hamiltonian.jl

View check run for this annotation

Codecov / codecov/patch

src/riemannian/hamiltonian.jl#L113

Added line #L113 was not covered by tests
end

#! Eq (13) of Girolami & Calderhead (2011)
function neg_energy(

Check warning on line 117 in src/riemannian/hamiltonian.jl

View check run for this annotation

Codecov / codecov/patch

src/riemannian/hamiltonian.jl#L117

Added line #L117 was not covered by tests
h::Hamiltonian{<:DenseRiemannianMetric,<:GaussianKinetic}, r::T, θ::T
) where {T<:AbstractVecOrMat}
G = h.metric.map(h.metric.G(θ))
D = size(G, 1)

Check warning on line 121 in src/riemannian/hamiltonian.jl

View check run for this annotation

Codecov / codecov/patch

src/riemannian/hamiltonian.jl#L120-L121

Added lines #L120 - L121 were not covered by tests
# Need to consider the normalizing term as it is no longer same for different θs
logZ = 1 / 2 * (D * log(2π) + logdet(G)) # it will be user's responsibility to make sure G is SPD and logdet(G) is defined
mul!(h.metric._temp, inv(G), r)
return -logZ - dot(r, h.metric._temp) / 2

Check warning on line 125 in src/riemannian/hamiltonian.jl

View check run for this annotation

Codecov / codecov/patch

src/riemannian/hamiltonian.jl#L123-L125

Added lines #L123 - L125 were not covered by tests
end
Loading
Loading