Skip to content

Commit 8a31df4

Browse files
committed
Dont hard dep on MCMCLogDensityProblems
1 parent 75c2380 commit 8a31df4

File tree

4 files changed

+50
-66
lines changed

4 files changed

+50
-66
lines changed

Project.toml

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ version = "0.7.1"
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
77
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
88
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
9-
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
109
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1110
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
1211
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
@@ -16,11 +15,6 @@ Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
1615
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1716
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1817
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
19-
VecTargets = "8a639fad-7908-4fe4-8003-906e9297f002"
20-
21-
[sources]
22-
VecTargets = {url = "https://github.com/chalk-lab/VecTargets.jl", rev = "main"}
23-
2418

2519
[weakdeps]
2620
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -43,7 +37,6 @@ ArgCheck = "1, 2"
4337
ComponentArrays = "0.15"
4438
CUDA = "3, 4, 5"
4539
DocStringExtensions = "0.8, 0.9"
46-
ForwardDiff = "0.10.38"
4740
LinearAlgebra = "<0.1, 1"
4841
LogDensityProblems = "2"
4942
LogDensityProblemsAD = "1"

src/AdvancedHMC.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,8 @@ using LogDensityProblemsAD: LogDensityProblemsAD
3030

3131
using AbstractMCMC: AbstractMCMC, LogDensityModel
3232

33-
using VecTargets: VecTargets
34-
3533
import StatsBase: sample
3634

37-
using ForwardDiff: ForwardDiff
38-
3935
const DEFAULT_FLOAT_TYPE = typeof(float(0))
4036

4137
include("utilities.jl")

src/riemannian/metric.jl

Lines changed: 0 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -43,57 +43,6 @@ function DenseRiemannianMetric(size, G, ∂G∂θ, map=IdentityMap())
4343
return DenseRiemannianMetric(size, G, ∂G∂θ, map, _temp)
4444
end
4545

46-
# Convenient constructor
47-
function DenseRiemannianMetric(size, ℓπ, initial_θ, λ, map=IdentityMap())
48-
_Hfunc = VecTargets.gen_hess(x -> -ℓπ(x), initial_θ) # x -> (value, gradient, hessian)
49-
Hfunc = x -> copy.(_Hfunc(x)) # _Hfunc do in-place computation, copy to avoid bug
50-
51-
fstabilize = H -> H + λ * I
52-
Gfunc = x -> begin
53-
H = fstabilize(Hfunc(x)[3])
54-
all(isfinite, H) ? H : diagm(ones(length(x)))
55-
end
56-
_∂G∂θfunc = gen_∂G∂θ_fwd(x -> -ℓπ(x), initial_θ; f=fstabilize)
57-
∂G∂θfunc = x -> reshape_∂G∂θ(_∂G∂θfunc(x))
58-
59-
_temp = Vector{Float64}(undef, first(size))
60-
61-
return DenseRiemannianMetric(size, Gfunc, ∂G∂θfunc, map, _temp)
62-
end
63-
64-
function gen_hess_fwd(func, x::AbstractVector)
65-
function hess(x::AbstractVector)
66-
return nothing, nothing, ForwardDiff.hessian(func, x)
67-
end
68-
return hess
69-
end
70-
71-
#= possible integrate DI for AD-independent fisher information metric
72-
function gen_∂G∂θ_rev(Vfunc, x; f=identity)
73-
_Hfunc = VecTargets.gen_hess(Vfunc, ReverseDiff.track.(x))
74-
Hfunc = x -> _Hfunc(x)[3]
75-
# QUES What's the best output format of this function?
76-
return x -> ReverseDiff.jacobian(x -> f(Hfunc(x)), x) # default output shape [∂H∂x₁; ∂H∂x₂; ...]
77-
end
78-
=#
79-
80-
# Fisher information metric
81-
function gen_∂G∂θ_fwd(Vfunc, x; f=identity)
82-
_Hfunc = gen_hess_fwd(Vfunc, x)
83-
Hfunc = x -> _Hfunc(x)[3]
84-
# QUES What's the best output format of this function?
85-
cfg = ForwardDiff.JacobianConfig(Hfunc, x)
86-
d = length(x)
87-
out = zeros(eltype(x), d^2, d)
88-
return x -> ForwardDiff.jacobian!(out, Hfunc, x, cfg)
89-
return out # default output shape [∂H∂x₁; ∂H∂x₂; ...]
90-
end
91-
92-
function reshape_∂G∂θ(H)
93-
d = size(H, 2)
94-
return cat((H[((i - 1) * d + 1):(i * d), :] for i in 1:d)...; dims=3)
95-
end
96-
9746
Base.size(e::DenseRiemannianMetric) = e.size
9847
Base.size(e::DenseRiemannianMetric, dim::Int) = e.size[dim]
9948
function Base.show(io::IO, drm::DenseRiemannianMetric)

test/riemannian.jl

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,60 @@ using ReTest, Random
22
using AdvancedHMC, ForwardDiff, AbstractMCMC
33
using LinearAlgebra
44

5+
using Pkg
6+
Pkg.develop(; url="https://github.com/chalk-lab/MCMCLogDensityProblems.jl")
7+
using MCMCLogDensityProblems
8+
9+
# Fisher information metric
10+
function gen_∂G∂θ_fwd(Vfunc, x; f=identity)
11+
_Hfunc = gen_hess_fwd(Vfunc, x)
12+
Hfunc = x -> _Hfunc(x)[3]
13+
# QUES What's the best output format of this function?
14+
cfg = ForwardDiff.JacobianConfig(Hfunc, x)
15+
d = length(x)
16+
out = zeros(eltype(x), d^2, d)
17+
return x -> ForwardDiff.jacobian!(out, Hfunc, x, cfg)
18+
return out # default output shape [∂H∂x₁; ∂H∂x₂; ...]
19+
end
20+
21+
function gen_hess_fwd(func, x::AbstractVector)
22+
function hess(x::AbstractVector)
23+
return nothing, nothing, ForwardDiff.hessian(func, x)
24+
end
25+
return hess
26+
end
27+
28+
function reshape_∂G∂θ(H)
29+
d = size(H, 2)
30+
return cat((H[((i - 1) * d + 1):(i * d), :] for i in 1:d)...; dims=3)
31+
end
32+
33+
function prepare_sample(ℓπ, initial_θ, λ)
34+
_Hfunc = MCMCLogDensityProblems.gen_hess(x -> -ℓπ(x), initial_θ) # x -> (value, gradient, hessian)
35+
Hfunc = x -> copy.(_Hfunc(x)) # _Hfunc do in-place computation, copy to avoid bug
36+
37+
fstabilize = H -> H + λ * I
38+
Gfunc = x -> begin
39+
H = fstabilize(Hfunc(x)[3])
40+
all(isfinite, H) ? H : diagm(ones(length(x)))
41+
end
42+
_∂G∂θfunc = gen_∂G∂θ_fwd(x -> -ℓπ(x), initial_θ; f=fstabilize)
43+
∂G∂θfunc = x -> reshape_∂G∂θ(_∂G∂θfunc(x))
44+
45+
return Gfunc, ∂G∂θfunc
46+
end
47+
548
@testset "Multi variate Normal with Riemannian HMC" begin
649
# Set the number of samples to draw and warmup iterations
750
n_samples = 2_000
851
rng = MersenneTwister(1110)
952
initial_θ = rand(rng, D)
1053
λ = 1e-2
54+
G, ∂G∂θ = prepare_sample(ℓπ, initial_θ, λ)
1155
# Define a Hamiltonian system
12-
metric = DenseRiemannianMetric((D,), ℓπ, initial_θ, λ)
56+
metric = DenseRiemannianMetric((D,), G, ∂G∂θ)
1357
kinetic = GaussianKinetic()
14-
hamiltonian = Hamiltonian(metric, kinetic, ℓπ, ∇ℓπ)
58+
hamiltonian = Hamiltonian(metric, kinetic, ℓπ, ∂ℓπ∂θ)
1559

1660
# Define a leapfrog solver, with the initial step size chosen heuristically
1761
initial_ϵ = 0.01
@@ -35,11 +79,13 @@ end
3579
n_samples = 2_000
3680
rng = MersenneTwister(1110)
3781
initial_θ = rand(rng, D)
82+
λ = 1e-2
83+
G, ∂G∂θ = prepare_sample(ℓπ, initial_θ, λ)
3884

3985
# Define a Hamiltonian system
40-
metric = DenseRiemannianMetric((D,), ℓπ, initial_θ, λSoftAbsMap(20.0))
86+
metric = DenseRiemannianMetric((D,), G, ∂G∂θ, λSoftAbsMap(20.0))
4187
kinetic = GaussianKinetic()
42-
hamiltonian = Hamiltonian(metric, kinetic, ℓπ, ∇ℓπ)
88+
hamiltonian = Hamiltonian(metric, kinetic, ℓπ, ∂ℓπ∂θ)
4389

4490
# Define a leapfrog solver, with the initial step size chosen heuristically
4591
initial_ϵ = 0.01

0 commit comments

Comments
 (0)