|
1 | 1 | #! Eq (14) of Girolami & Calderhead (2011) |
| 2 | +"The gradient of the Hamiltonian with respect to the momentum." |
2 | 3 | function ∂H∂r( |
3 | 4 | h::Hamiltonian{<:DenseRiemannianMetric,<:GaussianKinetic}, |
4 | | - θ::AbstractVecOrMat, |
5 | | - r::AbstractVecOrMat, |
| 5 | + θ::AbstractVector, |
| 6 | + r::AbstractVector, |
6 | 7 | ) |
7 | 8 | H = h.metric.G(θ) |
8 | 9 | G = h.metric.map(H) |
9 | | - return G \ r # NOTE it's actually pretty weird that ∂H∂θ returns DualValue but ∂H∂r doesn't |
| 10 | + return G \ r |
10 | 11 | end |
11 | 12 |
|
"""
Compute `tr(A * B)` for square `n × n` matrices `A` and `B` in O(n^2), without
forming the O(n^3) product `A * B`.

Uses `dot(A', B) = Σᵢⱼ A[j, i] * B[i, j]`, which is exact for real inputs and —
unlike the elementwise-conjugating `sum(A' .* B)` — also correct for complex inputs.

Doesn't actually check that `A` and `B` are both `n × n` matrices.
"""
tr_product(A::AbstractMatrix, B::AbstractMatrix) = dot(A', B)
"Compute `tr(A * v * v')`, i.e. `dot(v, A, v) = v' * A * v`, in O(n^2) without forming `v * v'`."
tr_product(A::AbstractMatrix, v::AbstractVector) = dot(v, A, v)
| 21 | + |
| 22 | + |
"Gradient of the Hamiltonian w.r.t. position; discards the reusable cache returned by `∂H∂θ_cache`."
function ∂H∂θ(
    h::Hamiltonian{<:AbstractRiemannianMetric,<:GaussianKinetic},
    θ::AbstractVector,
    r::AbstractVector,
)
    dual_value, _ = ∂H∂θ_cache(h, θ, r)
    return dual_value
end
"""
    ∂H∂θ_cache(h, θ, r; cache=nothing) -> (dv::DualValue, cache)

Gradient of the Hamiltonian w.r.t. position for a dense Riemannian metric with the
identity map, i.e. Eq (15) of Girolami & Calderhead (2011).

Returns the `DualValue` (log density plus its θ-gradient contribution) together with
a named tuple of θ-only quantities. Pass that tuple back in via `cache` on later
calls with the same `θ` (but a different `r`) to skip recomputing them.
"""
@views function ∂H∂θ_cache(
    h::Hamiltonian{<:DenseRiemannianMetric{T,<:IdentityMap},<:GaussianKinetic},
    θ::AbstractVector{T},
    r::AbstractVector{T};
    cache=nothing
) where {T}
    # Everything in this block depends only on θ, so it is computed at most once
    # per θ and reused across momentum values via `cache`.
    cache = @something cache begin
        log_density, log_density_gradient = h.∂ℓπ∂θ(θ)
        # h.metric.map is the IdentityMap
        metric = h.metric.G(θ)
        # The metric is inverted to be able to compute `tr_product(inv_metric, ...)` efficiently -
        # but this may still be a bad idea!
        inv_metric = inv(metric)
        metric_sensitivities = h.metric.∂G∂θ(θ)
        # rv1[i] = -∂ℓπ/∂θᵢ + ½ tr(G⁻¹ ∂G/∂θᵢ): the r-independent part of Eq (15).
        rv1 = map(eachindex(log_density_gradient)) do i
            -log_density_gradient[i] + .5 * tr_product(inv_metric, metric_sensitivities[:, :, i])
        end
        (;log_density, inv_metric, metric_sensitivities, rv1)
    end
    # The only r-dependent intermediate: G⁻¹ r.
    inv_metric_r = cache.inv_metric * r
    return DualValue(
        cache.log_density,
        #! Eq (15) of Girolami & Calderhead (2011)
        # `Base.broadcasted` keeps the per-index term lazy so `.-` fuses it with
        # `rv1` into a single materializing pass.
        cache.rv1 .- Base.broadcasted(eachindex(cache.rv1)) do i
            .5 * tr_product(cache.metric_sensitivities[:, :, i], inv_metric_r)
        end
    ), cache
end
35 | 62 |
|
36 | | -# Ref: https://www.wolframalpha.com/input?i=derivative+of+x+*+coth%28a+*+x%29 |
37 | | -#! Based on middle of the right column of Page 3 of Betancourt (2012) "Note that whenλi=λj, such as for the diagonal elementsor degenerate eigenvalues, this becomes the derivative" |
38 | | -dsoftabsdλ(α, λ) = coth(α * λ) + λ * α * -csch(λ * α)^2 |
39 | | - |
#! J as defined in middle of the right column of Page 3 of Betancourt (2012)
function make_J(λ::AbstractVector{T}, α::T) where {T<:AbstractFloat}
    n = length(λ)
    # J[i, j] is the divided difference of x ↦ x·coth(αx) between λ[i] and λ[j];
    # when the eigenvalues coincide it degenerates to the derivative.
    return T[
        if λ[i] == λ[j]
            # Ref: https://www.wolframalpha.com/input?i=derivative+of+x+*+coth%28a+*+x%29
            #! Based on middle of the right column of Page 3 of Betancourt (2012) "Note that whenλi=λj, such as for the diagonal elementsor degenerate eigenvalues, this becomes the derivative"
            coth(α * λ[i]) - λ[i] * α * csch(λ[i] * α)^2
        else
            (λ[i] * coth(α * λ[i]) - λ[j] * coth(α * λ[j])) / (λ[i] - λ[j])
        end
        for i in 1:n, j in 1:n
    ]
end
53 | 78 |
|
"""
    ∂H∂θ_cache(h, θ, r; cache=nothing) -> (dv::DualValue, cache)

Gradient of the Hamiltonian w.r.t. position for a dense Riemannian metric under the
SoftAbs map, following the two equations in the right column of Page 3 of
Betancourt (2012).

Returns the `DualValue` together with a named tuple of θ-only quantities; pass it
back in via `cache` on later calls with the same `θ` (but a different `r`) to avoid
recomputation. `cache.tmpv` and `cache.tmpm` are scratch buffers overwritten on
every call (their cached contents are only needed while building `rv1`).
"""
@views function ∂H∂θ_cache(
    h::Hamiltonian{<:DenseRiemannianMetric{T,<:SoftAbsMap},<:GaussianKinetic},
    θ::AbstractVector{T},
    r::AbstractVector{T};
    cache=nothing,
) where {T}
    cache = @something cache begin
        log_density, log_density_gradient = h.∂ℓπ∂θ(θ)
        premetric = h.metric.G(θ)
        premetric_sensitivities = h.metric.∂G∂θ(θ)
        metric, Q, λ, softabsλ = softabs(premetric, h.metric.map.α)
        J = make_J(λ, h.metric.map.α)

        #! Based on the two equations from the right column of Page 3 of Betancourt (2012)
        # First (r-independent) term: tr-product against Q * Diagonal(diag(J) ./ |λ|ₛ) * Q'.
        tmpv = diag(J) ./ softabsλ
        tmpm = Q * Diagonal(tmpv) * Q'

        rv1 = map(eachindex(log_density_gradient)) do i
            -log_density_gradient[i] + .5 * tr_product(tmpm, premetric_sensitivities[:, :, i])
        end
        # NOTE `premetric_sensitivities` and `J` must be carried in the cache: they
        # are needed below on every call, including cache-hit calls.
        (;log_density, premetric_sensitivities, Q, J, softabsλ, tmpv, tmpm, rv1)
    end
    # Second (r-dependent) term, reusing the scratch buffers:
    # tmpv ← Diagonal(softabsλ)⁻¹ Q' r,  tmpm ← Q (J ∘ tmpv tmpv') Q'.
    cache.tmpv .= (cache.Q' * r) ./ cache.softabsλ
    cache.tmpm .= cache.Q * (cache.J .* cache.tmpv .* cache.tmpv') * cache.Q'

    return DualValue(
        cache.log_density,
        # `Base.broadcasted` keeps the per-index term lazy so `.-` fuses it with
        # `rv1` into a single materializing pass.
        cache.rv1 .- Base.broadcasted(eachindex(cache.rv1)) do i
            .5 * tr_product(cache.tmpm, cache.premetric_sensitivities[:, :, i])
        end
    ), cache
end
103 | 111 |
|
104 | 112 | # QUES Do we want to change everything to position dependent by default? |
|
0 commit comments