|
| 1 | +using Random |
| 2 | + |
| 3 | +### integrator.jl |
| 4 | + |
| 5 | +import AdvancedHMC: ∂H∂θ, ∂H∂r, DualValue, PhasePoint, phasepoint, step |
| 6 | +using AdvancedHMC: TYPEDEF, TYPEDFIELDS, AbstractScalarOrVec, AbstractLeapfrog, step_size |
| 7 | + |
| 8 | +""" |
| 9 | +$(TYPEDEF) |
| 10 | +
|
| 11 | +Generalized leapfrog integrator with fixed step size `ϵ`. |
| 12 | +
|
| 13 | +# Fields |
| 14 | +
|
| 15 | +$(TYPEDFIELDS) |
| 16 | +""" |
| 17 | +struct GeneralizedLeapfrog{T<:AbstractScalarOrVec{<:AbstractFloat}} <: AbstractLeapfrog{T} |
| 18 | + "Step size." |
| 19 | + ϵ::T |
| 20 | + n::Int |
| 21 | +end |
| 22 | +function Base.show(io::IO, l::GeneralizedLeapfrog) |
| 23 | + return print(io, "GeneralizedLeapfrog(ϵ=$(round.(l.ϵ; sigdigits=3)), n=$(l.n))") |
| 24 | +end |
| 25 | + |
| 26 | +# Fallback to ignore return_cache & cache kwargs for other ∂H∂θ |
| 27 | +function ∂H∂θ_cache(h, θ, r; return_cache=false, cache=nothing) where {T} |
| 28 | + dv = ∂H∂θ(h, θ, r) |
| 29 | + return return_cache ? (dv, nothing) : dv |
| 30 | +end |
| 31 | + |
| 32 | +# TODO Make sure vectorization works |
| 33 | +# TODO Check if tempering is valid |
| 34 | +function step( |
| 35 | + lf::GeneralizedLeapfrog{T}, |
| 36 | + h::Hamiltonian, |
| 37 | + z::P, |
| 38 | + n_steps::Int=1; |
| 39 | + fwd::Bool=n_steps > 0, # simulate hamiltonian backward when n_steps < 0 |
| 40 | + full_trajectory::Val{FullTraj}=Val(false), |
| 41 | +) where {T<:AbstractScalarOrVec{<:AbstractFloat},P<:PhasePoint,FullTraj} |
| 42 | + n_steps = abs(n_steps) # to support `n_steps < 0` cases |
| 43 | + |
| 44 | + ϵ = fwd ? step_size(lf) : -step_size(lf) |
| 45 | + ϵ = ϵ' |
| 46 | + |
| 47 | + res = if FullTraj |
| 48 | + Vector{P}(undef, n_steps) |
| 49 | + else |
| 50 | + z |
| 51 | + end |
| 52 | + |
| 53 | + for i in 1:n_steps |
| 54 | + θ_init, r_init = z.θ, z.r |
| 55 | + # Tempering |
| 56 | + #r = temper(lf, r, (i=i, is_half=true), n_steps) |
| 57 | + #! Eq (16) of Girolami & Calderhead (2011) |
| 58 | + r_half = copy(r_init) |
| 59 | + local cache |
| 60 | + for j in 1:(lf.n) |
| 61 | + # Reuse cache for the first iteration |
| 62 | + if j == 1 |
| 63 | + (; value, gradient) = z.ℓπ |
| 64 | + elseif j == 2 # cache intermediate values that depends on θ only (which are unchanged) |
| 65 | + retval, cache = ∂H∂θ_cache(h, θ_init, r_half; return_cache=true) |
| 66 | + (; value, gradient) = retval |
| 67 | + else # reuse cache |
| 68 | + (; value, gradient) = ∂H∂θ_cache(h, θ_init, r_half; cache=cache) |
| 69 | + end |
| 70 | + r_half = r_init - ϵ / 2 * gradient |
| 71 | + # println("r_half: ", r_half) |
| 72 | + end |
| 73 | + #! Eq (17) of Girolami & Calderhead (2011) |
| 74 | + θ_full = copy(θ_init) |
| 75 | + term_1 = ∂H∂r(h, θ_init, r_half) # unchanged across the loop |
| 76 | + for j in 1:(lf.n) |
| 77 | + θ_full = θ_init + ϵ / 2 * (term_1 + ∂H∂r(h, θ_full, r_half)) |
| 78 | + # println("θ_full :", θ_full) |
| 79 | + end |
| 80 | + #! Eq (18) of Girolami & Calderhead (2011) |
| 81 | + (; value, gradient) = ∂H∂θ(h, θ_full, r_half) |
| 82 | + r_full = r_half - ϵ / 2 * gradient |
| 83 | + # println("r_full: ", r_full) |
| 84 | + # Tempering |
| 85 | + #r = temper(lf, r, (i=i, is_half=false), n_steps) |
| 86 | + # Create a new phase point by caching the logdensity and gradient |
| 87 | + z = phasepoint(h, θ_full, r_full; ℓπ=DualValue(value, gradient)) |
| 88 | + # Update result |
| 89 | + if FullTraj |
| 90 | + res[i] = z |
| 91 | + else |
| 92 | + res = z |
| 93 | + end |
| 94 | + if !isfinite(z) |
| 95 | + # Remove undef |
| 96 | + if FullTraj |
| 97 | + res = res[isassigned.(Ref(res), 1:n_steps)] |
| 98 | + end |
| 99 | + break |
| 100 | + end |
| 101 | + # @assert false |
| 102 | + end |
| 103 | + return res |
| 104 | +end |
| 105 | + |
| 106 | +# TODO Make the order of θ and r consistent with neg_energy |
| 107 | +∂H∂θ(h::Hamiltonian, θ::AbstractVecOrMat, r::AbstractVecOrMat) = ∂H∂θ(h, θ) |
| 108 | +∂H∂r(h::Hamiltonian, θ::AbstractVecOrMat, r::AbstractVecOrMat) = ∂H∂r(h, r) |
| 109 | + |
| 110 | +### hamiltonian.jl |
| 111 | + |
| 112 | +import AdvancedHMC: refresh, phasepoint |
| 113 | +using AdvancedHMC: FullMomentumRefreshment, PartialMomentumRefreshment, AbstractMetric |
| 114 | + |
| 115 | +# To change L180 of hamiltonian.jl |
| 116 | +function phasepoint( |
| 117 | + rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}}, |
| 118 | + θ::AbstractVecOrMat{T}, |
| 119 | + h::Hamiltonian, |
| 120 | +) where {T<:Real} |
| 121 | + return phasepoint(h, θ, rand_momentum(rng, h.metric, h.kinetic, θ)) |
| 122 | +end |
| 123 | + |
| 124 | +# To change L191 of hamiltonian.jl |
| 125 | +function refresh( |
| 126 | + rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}}, |
| 127 | + ::FullMomentumRefreshment, |
| 128 | + h::Hamiltonian, |
| 129 | + z::PhasePoint, |
| 130 | +) |
| 131 | + return phasepoint(h, z.θ, rand_momentum(rng, h.metric, h.kinetic, z.θ)) |
| 132 | +end |
| 133 | + |
| 134 | +# To change L215 of hamiltonian.jl |
| 135 | +function refresh( |
| 136 | + rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}}, |
| 137 | + ref::PartialMomentumRefreshment, |
| 138 | + h::Hamiltonian, |
| 139 | + z::PhasePoint, |
| 140 | +) |
| 141 | + return phasepoint( |
| 142 | + h, |
| 143 | + z.θ, |
| 144 | + ref.α * z.r + sqrt(1 - ref.α^2) * rand_momentum(rng, h.metric, h.kinetic, z.θ), |
| 145 | + ) |
| 146 | +end |
| 147 | + |
| 148 | +### metric.jl |
| 149 | + |
| 150 | +import AdvancedHMC: _rand |
| 151 | +using AdvancedHMC: AbstractMetric |
| 152 | +using LinearAlgebra: eigen, cholesky, Symmetric |
| 153 | + |
| 154 | +abstract type AbstractRiemannianMetric <: AbstractMetric end |
| 155 | + |
| 156 | +abstract type AbstractHessianMap end |
| 157 | + |
| 158 | +struct IdentityMap <: AbstractHessianMap end |
| 159 | + |
| 160 | +(::IdentityMap)(x) = x |
| 161 | + |
| 162 | +struct SoftAbsMap{T} <: AbstractHessianMap |
| 163 | + α::T |
| 164 | +end |
| 165 | + |
| 166 | +# TODO Register softabs with ReverseDiff |
| 167 | +#! The definition of SoftAbs from Page 3 of Betancourt (2012) |
| 168 | +function softabs(X, α=20.0) |
| 169 | + F = eigen(X) # ReverseDiff cannot diff through `eigen` |
| 170 | + Q = hcat(F.vectors) |
| 171 | + λ = F.values |
| 172 | + softabsλ = λ .* coth.(α * λ) |
| 173 | + return Q * diagm(softabsλ) * Q', Q, λ, softabsλ |
| 174 | +end |
| 175 | + |
| 176 | +(map::SoftAbsMap)(x) = softabs(x, map.α)[1] |
| 177 | + |
| 178 | +struct DenseRiemannianMetric{ |
| 179 | + T, |
| 180 | + TM<:AbstractHessianMap, |
| 181 | + A<:Union{Tuple{Int},Tuple{Int,Int}}, |
| 182 | + AV<:AbstractVecOrMat{T}, |
| 183 | + TG, |
| 184 | + T∂G∂θ, |
| 185 | +} <: AbstractRiemannianMetric |
| 186 | + size::A |
| 187 | + G::TG # TODO store G⁻¹ here instead |
| 188 | + ∂G∂θ::T∂G∂θ |
| 189 | + map::TM |
| 190 | + _temp::AV |
| 191 | +end |
| 192 | + |
| 193 | +# TODO Make dense mass matrix support matrix-mode parallel |
| 194 | +function DenseRiemannianMetric(size, G, ∂G∂θ, map=IdentityMap()) where {T<:AbstractFloat} |
| 195 | + _temp = Vector{Float64}(undef, size[1]) |
| 196 | + return DenseRiemannianMetric(size, G, ∂G∂θ, map, _temp) |
| 197 | +end |
| 198 | +# DenseEuclideanMetric(::Type{T}, D::Int) where {T} = DenseEuclideanMetric(Matrix{T}(I, D, D)) |
| 199 | +# DenseEuclideanMetric(D::Int) = DenseEuclideanMetric(Float64, D) |
| 200 | +# DenseEuclideanMetric(::Type{T}, sz::Tuple{Int}) where {T} = DenseEuclideanMetric(Matrix{T}(I, first(sz), first(sz))) |
| 201 | +# DenseEuclideanMetric(sz::Tuple{Int}) = DenseEuclideanMetric(Float64, sz) |
| 202 | + |
| 203 | +# renew(ue::DenseEuclideanMetric, M⁻¹) = DenseEuclideanMetric(M⁻¹) |
| 204 | + |
| 205 | +Base.size(e::DenseRiemannianMetric) = e.size |
| 206 | +Base.size(e::DenseRiemannianMetric, dim::Int) = e.size[dim] |
| 207 | +Base.show(io::IO, dem::DenseRiemannianMetric) = print(io, "DenseRiemannianMetric(...)") |
| 208 | + |
| 209 | +function rand_momentum( |
| 210 | + rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}}, |
| 211 | + metric::DenseRiemannianMetric{T}, |
| 212 | + kinetic, |
| 213 | + θ::AbstractVecOrMat, |
| 214 | +) where {T} |
| 215 | + r = _randn(rng, T, size(metric)...) |
| 216 | + G⁻¹ = inv(metric.map(metric.G(θ))) |
| 217 | + chol = cholesky(Symmetric(G⁻¹)) |
| 218 | + ldiv!(chol.U, r) |
| 219 | + return r |
| 220 | +end |
| 221 | + |
| 222 | +### hamiltonian.jl |
| 223 | + |
| 224 | +import AdvancedHMC: phasepoint, neg_energy, ∂H∂θ, ∂H∂r |
| 225 | +using LinearAlgebra: logabsdet, tr |
| 226 | + |
| 227 | +# QUES Do we want to change everything to position dependent by default? |
| 228 | +# Add θ to ∂H∂r for DenseRiemannianMetric |
| 229 | +function phasepoint( |
| 230 | + h::Hamiltonian{<:DenseRiemannianMetric}, |
| 231 | + θ::T, |
| 232 | + r::T; |
| 233 | + ℓπ=∂H∂θ(h, θ), |
| 234 | + ℓκ=DualValue(neg_energy(h, r, θ), ∂H∂r(h, θ, r)), |
| 235 | +) where {T<:AbstractVecOrMat} |
| 236 | + return PhasePoint(θ, r, ℓπ, ℓκ) |
| 237 | +end |
| 238 | + |
| 239 | +# Negative kinetic energy |
| 240 | +#! Eq (13) of Girolami & Calderhead (2011) |
| 241 | +function neg_energy( |
| 242 | + h::Hamiltonian{<:DenseRiemannianMetric}, r::T, θ::T |
| 243 | +) where {T<:AbstractVecOrMat} |
| 244 | + G = h.metric.map(h.metric.G(θ)) |
| 245 | + D = size(G, 1) |
| 246 | + # Need to consider the normalizing term as it is no longer same for different θs |
| 247 | + logZ = 1 / 2 * (D * log(2π) + logdet(G)) # it will be user's responsibility to make sure G is SPD and logdet(G) is defined |
| 248 | + mul!(h.metric._temp, inv(G), r) |
| 249 | + return -logZ - dot(r, h.metric._temp) / 2 |
| 250 | +end |
| 251 | + |
| 252 | +# QUES L31 of hamiltonian.jl now reads a bit weird (semantically) |
| 253 | +function ∂H∂θ( |
| 254 | + h::Hamiltonian{<:DenseRiemannianMetric{T,<:IdentityMap}}, |
| 255 | + θ::AbstractVecOrMat{T}, |
| 256 | + r::AbstractVecOrMat{T}, |
| 257 | +) where {T} |
| 258 | + ℓπ, ∂ℓπ∂θ = h.∂ℓπ∂θ(θ) |
| 259 | + G = h.metric.map(h.metric.G(θ)) |
| 260 | + invG = inv(G) |
| 261 | + ∂G∂θ = h.metric.∂G∂θ(θ) |
| 262 | + d = length(∂ℓπ∂θ) |
| 263 | + return DualValue( |
| 264 | + ℓπ, |
| 265 | + #! Eq (15) of Girolami & Calderhead (2011) |
| 266 | + -mapreduce(vcat, 1:d) do i |
| 267 | + ∂G∂θᵢ = ∂G∂θ[:, :, i] |
| 268 | + ∂ℓπ∂θ[i] - 1 / 2 * tr(invG * ∂G∂θᵢ) + 1 / 2 * r' * invG * ∂G∂θᵢ * invG * r |
| 269 | + # Gr = G \ r |
| 270 | + # ∂ℓπ∂θ[i] - 1 / 2 * tr(G \ ∂G∂θᵢ) + 1 / 2 * Gr' * ∂G∂θᵢ * Gr |
| 271 | + # 1 / 2 * tr(invG * ∂G∂θᵢ) |
| 272 | + # 1 / 2 * r' * invG * ∂G∂θᵢ * invG * r |
| 273 | + end, |
| 274 | + ) |
| 275 | +end |
| 276 | + |
| 277 | +# Ref: https://www.wolframalpha.com/input?i=derivative+of+x+*+coth%28a+*+x%29 |
| 278 | +#! Based on middle of the right column of Page 3 of Betancourt (2012) "Note that whenλi=λj, such as for the diagonal elementsor degenerate eigenvalues, this becomes the derivative" |
| 279 | +dsoftabsdλ(α, λ) = coth(α * λ) + λ * α * -csch(λ * α)^2 |
| 280 | + |
| 281 | +#! J as defined in middle of the right column of Page 3 of Betancourt (2012) |
| 282 | +function make_J(λ::AbstractVector{T}, α::T) where {T<:AbstractFloat} |
| 283 | + d = length(λ) |
| 284 | + J = Matrix{T}(undef, d, d) |
| 285 | + for i in 1:d, j in 1:d |
| 286 | + J[i, j] = if (λ[i] == λ[j]) |
| 287 | + dsoftabsdλ(α, λ[i]) |
| 288 | + else |
| 289 | + ((λ[i] * coth(α * λ[i]) - λ[j] * coth(α * λ[j])) / (λ[i] - λ[j])) |
| 290 | + end |
| 291 | + end |
| 292 | + return J |
| 293 | +end |
| 294 | + |
| 295 | +function ∂H∂θ( |
| 296 | + h::Hamiltonian{<:DenseRiemannianMetric{T,<:SoftAbsMap}}, |
| 297 | + θ::AbstractVecOrMat{T}, |
| 298 | + r::AbstractVecOrMat{T}, |
| 299 | +) where {T} |
| 300 | + return ∂H∂θ_cache(h, θ, r) |
| 301 | +end |
| 302 | +function ∂H∂θ_cache( |
| 303 | + h::Hamiltonian{<:DenseRiemannianMetric{T,<:SoftAbsMap}}, |
| 304 | + θ::AbstractVecOrMat{T}, |
| 305 | + r::AbstractVecOrMat{T}; |
| 306 | + return_cache=false, |
| 307 | + cache=nothing, |
| 308 | +) where {T} |
| 309 | + # Terms that only dependent on θ can be cached in θ-unchanged loops |
| 310 | + if isnothing(cache) |
| 311 | + ℓπ, ∂ℓπ∂θ = h.∂ℓπ∂θ(θ) |
| 312 | + H = h.metric.G(θ) |
| 313 | + ∂H∂θ = h.metric.∂G∂θ(θ) |
| 314 | + |
| 315 | + G, Q, λ, softabsλ = softabs(H, h.metric.map.α) |
| 316 | + |
| 317 | + R = diagm(1 ./ softabsλ) |
| 318 | + |
| 319 | + # softabsΛ = diagm(softabsλ) |
| 320 | + # M = inv(softabsΛ) * Q' * r |
| 321 | + # M = R * Q' * r # equiv to above but avoid inv |
| 322 | + |
| 323 | + J = make_J(λ, h.metric.map.α) |
| 324 | + |
| 325 | + #! Based on the two equations from the right column of Page 3 of Betancourt (2012) |
| 326 | + term_1_cached = Q * (R .* J) * Q' |
| 327 | + else |
| 328 | + ℓπ, ∂ℓπ∂θ, ∂H∂θ, Q, softabsλ, J, term_1_cached = cache |
| 329 | + end |
| 330 | + d = length(∂ℓπ∂θ) |
| 331 | + D = diagm((Q' * r) ./ softabsλ) |
| 332 | + term_2_cached = Q * D * J * D * Q' |
| 333 | + g = |
| 334 | + -mapreduce(vcat, 1:d) do i |
| 335 | + ∂H∂θᵢ = ∂H∂θ[:, :, i] |
| 336 | + # ∂ℓπ∂θ[i] - 1 / 2 * tr(term_1_cached * ∂H∂θᵢ) + 1 / 2 * M' * (J .* (Q' * ∂H∂θᵢ * Q)) * M # (v1) |
| 337 | + # NOTE Some further optimization can be done here: cache the 1st product all together |
| 338 | + ∂ℓπ∂θ[i] - 1 / 2 * tr(term_1_cached * ∂H∂θᵢ) + 1 / 2 * tr(term_2_cached * ∂H∂θᵢ) # (v2) cache friendly |
| 339 | + end |
| 340 | + |
| 341 | + dv = DualValue(ℓπ, g) |
| 342 | + return return_cache ? (dv, (; ℓπ, ∂ℓπ∂θ, ∂H∂θ, Q, softabsλ, J, term_1_cached)) : dv |
| 343 | +end |
| 344 | + |
| 345 | +#! Eq (14) of Girolami & Calderhead (2011) |
| 346 | +function ∂H∂r( |
| 347 | + h::Hamiltonian{<:DenseRiemannianMetric}, θ::AbstractVecOrMat, r::AbstractVecOrMat |
| 348 | +) |
| 349 | + H = h.metric.G(θ) |
| 350 | + # if !all(isfinite, H) |
| 351 | + # println("θ: ", θ) |
| 352 | + # println("H: ", H) |
| 353 | + # end |
| 354 | + G = h.metric.map(H) |
| 355 | + # return inv(G) * r |
| 356 | + # println("G \ r: ", G \ r) |
| 357 | + return G \ r # NOTE it's actually pretty weird that ∂H∂θ returns DualValue but ∂H∂r doesn't |
| 358 | +end |
0 commit comments