Skip to content

Commit 5fdc20d

Browse files
committed
format
1 parent 74af669 commit 5fdc20d

File tree

4 files changed

+31
-19
lines changed

4 files changed

+31
-19
lines changed

src/AdvancedHMC.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,18 @@ module AdvancedHMC
22

33
using Statistics: mean, var, middle
44
using LinearAlgebra:
5-
Symmetric, UpperTriangular, mul!, ldiv!, dot, I, diag, diagm, cholesky, UniformScaling, logdet, tr
5+
Symmetric,
6+
UpperTriangular,
7+
mul!,
8+
ldiv!,
9+
dot,
10+
I,
11+
diag,
12+
diagm,
13+
cholesky,
14+
UniformScaling,
15+
logdet,
16+
tr
617
using StatsFuns: logaddexp, logsumexp, loghalf
718
using Random: Random, AbstractRNG
819
using ProgressMeter: ProgressMeter

src/riemannian/hamiltonian.jl

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#! Eq (14) of Girolami & Calderhead (2011)
22
function ∂H∂r(
3-
h::Hamiltonian{<:DenseRiemannianMetric,<:GaussianKinetic}, θ::AbstractVecOrMat, r::AbstractVecOrMat
3+
h::Hamiltonian{<:DenseRiemannianMetric,<:GaussianKinetic},
4+
θ::AbstractVecOrMat,
5+
r::AbstractVecOrMat,
46
)
57
H = h.metric.G(θ)
68
G = h.metric.map(H)
@@ -87,15 +89,16 @@ function ∂H∂θ_cache(
8789
d = length(∂ℓπ∂θ)
8890
D = diagm((Q' * r) ./ softabsλ)
8991
term_2_cached = Q * D * J * D * Q'
90-
g =
91-
isdiag ?
92-
-(∂ℓπ∂θ - 1 / 2 * diag(term_1_cached * ∂H∂θ) + 1 / 2 * diag(term_2 * ∂H∂θ)) :
92+
g = if isdiag
93+
-(∂ℓπ∂θ - 1 / 2 * diag(term_1_cached * ∂H∂θ) + 1 / 2 * diag(term_2 * ∂H∂θ))
94+
else
9395
-mapreduce(vcat, 1:d) do i
94-
∂H∂θᵢ = ∂H∂θ[:, :, i]
95-
# ∂ℓπ∂θ[i] - 1 / 2 * tr(term_1_cached * ∂H∂θᵢ) + 1 / 2 * M' * (J .* (Q' * ∂H∂θᵢ * Q)) * M # (v1)
96-
# NOTE Some further optimization can be done here: cache the 1st product all together
97-
∂ℓπ∂θ[i] - 1 / 2 * tr(term_1_cached * ∂H∂θᵢ) + 1 / 2 * tr(term_2_cached * ∂H∂θᵢ) # (v2) cache friendly
98-
end
96+
∂H∂θᵢ = ∂H∂θ[:, :, i]
97+
# ∂ℓπ∂θ[i] - 1 / 2 * tr(term_1_cached * ∂H∂θᵢ) + 1 / 2 * M' * (J .* (Q' * ∂H∂θᵢ * Q)) * M # (v1)
98+
# NOTE Some further optimization can be done here: cache the 1st product all together
99+
∂ℓπ∂θ[i] - 1 / 2 * tr(term_1_cached * ∂H∂θᵢ) + 1 / 2 * tr(term_2_cached * ∂H∂θᵢ) # (v2) cache friendly
100+
end
101+
end
99102

100103
dv = DualValue(ℓπ, g)
101104
return return_cache ? (dv, (; ℓπ, ∂ℓπ∂θ, ∂H∂θ, Q, softabsλ, J, term_1_cached)) : dv

src/riemannian/metric.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ function DenseRiemannianMetric(size, G, ∂G∂θ, map=IdentityMap())
4444
end
4545

4646
# Convenient constructor
47-
function DenseRiemannianMetric(size, ℓπ, initial_θ, λ, map = IdentityMap())
47+
function DenseRiemannianMetric(size, ℓπ, initial_θ, λ, map=IdentityMap())
4848
_Hfunc = VecTargets.gen_hess(x -> -ℓπ(x), initial_θ) # x -> (value, gradient, hessian)
4949
Hfunc = x -> copy.(_Hfunc(x)) # _Hfunc do in-place computation, copy to avoid bug
5050

@@ -96,7 +96,9 @@ end
9696

9797
Base.size(e::DenseRiemannianMetric) = e.size
9898
Base.size(e::DenseRiemannianMetric, dim::Int) = e.size[dim]
99-
Base.show(io::IO, drm::DenseRiemannianMetric) = print(io, "DenseRiemannianMetric$(drm.size) with $(drm.map) metric")
99+
function Base.show(io::IO, drm::DenseRiemannianMetric)
100+
return print(io, "DenseRiemannianMetric$(drm.size) with $(drm.map) metric")
101+
end
100102

101103
function rand_momentum(
102104
rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}},

test/riemannian.jl

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,7 @@ using LinearAlgebra
2525
# Run the sampler to draw samples from the specified Gaussian, where
2626
# - `samples` will store the samples
2727
# - `stats` will store diagnostic statistics for each sample
28-
samples, stats = sample(
29-
rng, hamiltonian, kernel, initial_θ, n_samples; progress=true
30-
)
28+
samples, stats = sample(rng, hamiltonian, kernel, initial_θ, n_samples; progress=true)
3129
@test length(samples) == n_samples
3230
@test length(stats) == n_samples
3331
end
@@ -55,9 +53,7 @@ end
5553
# Run the sampler to draw samples from the specified Gaussian, where
5654
# - `samples` will store the samples
5755
# - `stats` will store diagnostic statistics for each sample
58-
samples, stats = sample(
59-
rng, hamiltonian, kernel, initial_θ, n_samples; progress=true
60-
)
56+
samples, stats = sample(rng, hamiltonian, kernel, initial_θ, n_samples; progress=true)
6157
@test length(samples) == n_samples
6258
@test length(stats) == n_samples
63-
end
59+
end

0 commit comments

Comments
 (0)