Skip to content

Commit 280ca15

Browse files
committed
start minimal refactor for merging into main
1 parent 26266c7 commit 280ca15

File tree

2 files changed

+83
-67
lines changed

2 files changed

+83
-67
lines changed

docs/src/api.md

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,17 @@ This modularity means that different HMC variants can be easily constructed by c
88
- Unit metric: `UnitEuclideanMetric(dim)`
99
- Diagonal metric: `DiagEuclideanMetric(dim)`
1010
- Dense metric: `DenseEuclideanMetric(dim)`
11-
- Dense Riemannian metric: `DenseRiemannianMetric(size, G, ∂G∂θ)`
1211

13-
where `dim` is the dimensionality of the sampling space.
12+
where `dim` is the dimension of the sampling space.
13+
14+
Furthermore, there is now an experimental dense Riemannian metric implementation, specifiable as `DenseRiemannianMetric(dim, premetric, premetric_sensitivities, metric_map=IdentityMap())`, with
15+
16+
- `dim`: again the dimension of the sampling space,
17+
- `premetric`: a function which, for a given posterior position `pos`, computes either
18+
a) a symmetric, **positive definite** matrix acting as the position dependent Riemannian metric (if `metric_map = IdentityMap()`), or
19+
b) a symmetric, **not necessarily positive definite** matrix acting as the position dependent Riemannian metric after being passed through the `metric_map` argument, which will have to ensure that its return value *is* positive definite (like `metric_map = SoftAbsMap(alpha)`),
20+
- `premetric_sensitivities`: a function which, again for a given posterior position `pos`, computes the sensitivities with respect to this position of the **`premetric`** function,
21+
- `metric_map=IdentityMap()`: a function which takes in `premetric(pos)` and returns a symmetric positive definite matrix. Provided options are `IdentityMap()` or `SoftAbsMap(alpha)`, with the `SoftAbsMap` type allowing to work directly with the `premetric` returning the Hessian of the log density function, which generally is not guaranteed to be positive definite.
1422

1523
### [Integrator (`integrator`)](@id integrator)
1624

src/riemannian/hamiltonian.jl

Lines changed: 73 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,104 +1,112 @@
11
#! Eq (14) of Girolami & Calderhead (2011)
"The gradient of the Hamiltonian with respect to the momentum."
function ∂H∂r(
    h::Hamiltonian{<:DenseRiemannianMetric,<:GaussianKinetic},
    θ::AbstractVector,
    r::AbstractVector,
)
    # Evaluate the position-dependent premetric and pass it through the metric
    # map to obtain the (positive definite) Riemannian metric G(θ).
    premetric = h.metric.G(θ)
    metric = h.metric.map(premetric)
    # Solve G(θ) * x = r rather than forming inv(G) explicitly.
    return metric \ r
end
1112

13+
"""
    tr_product(A::AbstractMatrix, B::AbstractMatrix)

Compute `tr(A * B)` for square `n x n` matrices `A` and `B` in `O(n^2)` time,
without materializing the `O(n^3)`-cost product `A * B`.

Uses `transpose` (not adjoint) so the identity `tr(A * B) == sum(transpose(A) .* B)`
also holds for complex eltypes.

Note: the dimensions of `A` and `B` are not validated.
"""
tr_product(A::AbstractMatrix, B::AbstractMatrix) = sum(Base.broadcasted(*, transpose(A), B))

"""
    tr_product(A::AbstractMatrix, v::AbstractVector)

Compute `tr(A * v * v')`, i.e. the quadratic form `dot(v, A, v)`, without
allocating intermediates.
"""
tr_product(A::AbstractMatrix, v::AbstractVector) = dot(v, A, v)
21+
22+
1223
"""
Gradient of the Hamiltonian with respect to the position `θ`, for Riemannian
metrics with a Gaussian kinetic energy.

Delegates to `∂H∂θ_cache` and discards the returned cache.
"""
function ∂H∂θ(
    h::Hamiltonian{<:AbstractRiemannianMetric,<:GaussianKinetic},
    θ::AbstractVector,
    r::AbstractVector,
)
    dv, _ = ∂H∂θ_cache(h, θ, r)
    return dv
end
30+
"""
    ∂H∂θ_cache(h, θ, r; cache=nothing) -> (dv::DualValue, cache)

Gradient of the Hamiltonian with respect to the position `θ` for a
`DenseRiemannianMetric` with `IdentityMap`, following Eq (15) of
Girolami & Calderhead (2011).

All quantities that depend only on `θ` are collected in the returned `cache`;
pass it back in while `θ` is unchanged to avoid recomputing them.
"""
@views function ∂H∂θ_cache(
    h::Hamiltonian{<:DenseRiemannianMetric{T,<:IdentityMap},<:GaussianKinetic},
    θ::AbstractVector{T},
    r::AbstractVector{T};
    cache=nothing,
) where {T}
    cache = @something cache begin
        log_density, log_density_gradient = h.∂ℓπ∂θ(θ)
        # h.metric.map is the IdentityMap, so the premetric *is* the metric.
        metric = h.metric.G(θ)
        # The metric is inverted to be able to compute `tr_product(inv_metric, ...)`
        # efficiently - but this may still be a bad idea!
        inv_metric = inv(metric)
        metric_sensitivities = h.metric.∂G∂θ(θ)
        # θ-only part of Eq (15): -∂ℓπ∂θᵢ + ½ tr(G⁻¹ ∂G∂θᵢ)
        rv1 = map(eachindex(log_density_gradient)) do i
            -log_density_gradient[i] + 0.5 * tr_product(inv_metric, metric_sensitivities[:, :, i])
        end
        (; log_density, inv_metric, metric_sensitivities, rv1)
    end
    inv_metric_r = cache.inv_metric * r
    return DualValue(
        cache.log_density,
        #! Eq (15) of Girolami & Calderhead (2011)
        # Momentum-dependent part: subtract ½ rᵀG⁻¹ ∂G∂θᵢ G⁻¹r, built lazily and
        # fused with `rv1` in a single broadcast.
        cache.rv1 .- Base.broadcasted(eachindex(cache.rv1)) do i
            0.5 * tr_product(cache.metric_sensitivities[:, :, i], inv_metric_r)
        end,
    ), cache
end
3562

36-
# Ref: https://www.wolframalpha.com/input?i=derivative+of+x+*+coth%28a+*+x%29
37-
#! Based on middle of the right column of Page 3 of Betancourt (2012) "Note that when λi = λj, such as for the diagonal elements or degenerate eigenvalues, this becomes the derivative"
38-
dsoftabsdλ(α, λ) = coth(α * λ) + λ * α * -csch(λ * α)^2
39-
4063
#! J as defined in middle of the right column of Page 3 of Betancourt (2012)
"""
    make_J(λ::AbstractVector{T}, α::T) where {T<:AbstractFloat}

Build the matrix `J` of divided differences of `x -> x * coth(α * x)` over the
eigenvalues `λ`, as defined in the middle of the right column of Page 3 of
Betancourt (2012).

`J[i, j]` is the divided difference between `λ[i]` and `λ[j]`; when
`λ[i] == λ[j]` (diagonal elements or degenerate eigenvalues) it is the
derivative of that map instead.
"""
function make_J(λ::AbstractVector{T}, α::T) where {T<:AbstractFloat}
    d = length(λ)
    J = Matrix{T}(undef, d, d)
    for i in 1:d, j in 1:d
        J[i, j] = if λ[i] == λ[j]
            # Ref: https://www.wolframalpha.com/input?i=derivative+of+x+*+coth%28a+*+x%29
            #! Based on middle of the right column of Page 3 of Betancourt (2012):
            # "Note that when λi = λj, such as for the diagonal elements or
            # degenerate eigenvalues, this becomes the derivative"
            coth(α * λ[i]) + λ[i] * α * -csch(λ[i] * α)^2
        else
            (λ[i] * coth(α * λ[i]) - λ[j] * coth(α * λ[j])) / (λ[i] - λ[j])
        end
    end
    return J
end
5378

54-
"""
    ∂H∂θ_cache(h, θ, r; cache=nothing) -> (dv::DualValue, cache)

Gradient of the Hamiltonian with respect to the position `θ` for a
`DenseRiemannianMetric` with `SoftAbsMap`, based on the two equations from the
right column of Page 3 of Betancourt (2012).

All quantities that depend only on `θ` are collected in the returned `cache`;
pass it back in while `θ` is unchanged to avoid recomputing them. The buffers
`cache.tmpv` / `cache.tmpm` are overwritten on every call with the
momentum-dependent terms.
"""
@views function ∂H∂θ_cache(
    h::Hamiltonian{<:DenseRiemannianMetric{T,<:SoftAbsMap},<:GaussianKinetic},
    θ::AbstractVector{T},
    r::AbstractVector{T};
    cache=nothing,
) where {T}
    # Terms that only depend on θ are computed once per position and cached.
    cache = @something cache begin
        log_density, log_density_gradient = h.∂ℓπ∂θ(θ)
        premetric = h.metric.G(θ)
        premetric_sensitivities = h.metric.∂G∂θ(θ)
        metric, Q, λ, softabsλ = softabs(premetric, h.metric.map.α)
        J = make_J(λ, h.metric.map.α)

        #! Based on the two equations from the right column of Page 3 of Betancourt (2012)
        # term 1: Q * Diagonal(diag(J) ./ softabsλ) * Q'
        tmpv = diag(J) ./ softabsλ
        tmpm = Q * Diagonal(tmpv) * Q'

        # θ-only part: -∂ℓπ∂θᵢ + ½ tr(term1 * ∂G∂θᵢ)
        rv1 = map(eachindex(log_density_gradient)) do i
            -log_density_gradient[i] + 0.5 * tr_product(tmpm, premetric_sensitivities[:, :, i])
        end
        # NOTE `J` and `premetric_sensitivities` must live in the cache: the
        # warm-cache path below reads them on every call.
        (; log_density, Q, softabsλ, J, premetric_sensitivities, tmpv, tmpm, rv1)
    end
    # Momentum-dependent part, reusing the cached buffers:
    # tmpv ← Diagonal(1 ./ softabsλ) * Q' * r, tmpm ← Q * (J .* tmpv * tmpv') * Q'
    cache.tmpv .= (cache.Q' * r) ./ cache.softabsλ
    cache.tmpm .= cache.Q * (cache.J .* cache.tmpv .* cache.tmpv') * cache.Q'

    return DualValue(
        cache.log_density,
        # Subtract ½ tr(term2 * ∂G∂θᵢ), fused with `rv1` in a single broadcast.
        cache.rv1 .- Base.broadcasted(eachindex(cache.rv1)) do i
            0.5 * tr_product(cache.tmpm, cache.premetric_sensitivities[:, :, i])
        end,
    ), cache
end
103111

104112
# QUES Do we want to change everything to position dependent by default?

0 commit comments

Comments
 (0)