@@ -2,16 +2,60 @@ using ReTest, Random
2
2
using AdvancedHMC, ForwardDiff, AbstractMCMC
3
3
using LinearAlgebra
4
4
5
+ using Pkg
6
+ Pkg. develop (; url= " https://github.com/chalk-lab/MCMCLogDensityProblems.jl" )
7
+ using MCMCLogDensityProblems
8
+
9
+ # Fisher information metric
10
+ function gen_∂G∂θ_fwd (Vfunc, x; f= identity)
11
+ _Hfunc = gen_hess_fwd (Vfunc, x)
12
+ Hfunc = x -> _Hfunc (x)[3 ]
13
+ # QUES What's the best output format of this function?
14
+ cfg = ForwardDiff. JacobianConfig (Hfunc, x)
15
+ d = length (x)
16
+ out = zeros (eltype (x), d^ 2 , d)
17
+ return x -> ForwardDiff. jacobian! (out, Hfunc, x, cfg)
18
+ return out # default output shape [∂H∂x₁; ∂H∂x₂; ...]
19
+ end
20
+
21
+ function gen_hess_fwd (func, x:: AbstractVector )
22
+ function hess (x:: AbstractVector )
23
+ return nothing , nothing , ForwardDiff. hessian (func, x)
24
+ end
25
+ return hess
26
+ end
27
+
28
+ function reshape_∂G∂θ (H)
29
+ d = size (H, 2 )
30
+ return cat ((H[((i - 1 ) * d + 1 ): (i * d), :] for i in 1 : d). .. ; dims= 3 )
31
+ end
32
+
33
+ function prepare_sample (ℓπ, initial_θ, λ)
34
+ _Hfunc = MCMCLogDensityProblems. gen_hess (x -> - ℓπ (x), initial_θ) # x -> (value, gradient, hessian)
35
+ Hfunc = x -> copy .(_Hfunc (x)) # _Hfunc do in-place computation, copy to avoid bug
36
+
37
+ fstabilize = H -> H + λ * I
38
+ Gfunc = x -> begin
39
+ H = fstabilize (Hfunc (x)[3 ])
40
+ all (isfinite, H) ? H : diagm (ones (length (x)))
41
+ end
42
+ _∂G∂θfunc = gen_∂G∂θ_fwd (x -> - ℓπ (x), initial_θ; f= fstabilize)
43
+ ∂G∂θfunc = x -> reshape_∂G∂θ (_∂G∂θfunc (x))
44
+
45
+ return Gfunc, ∂G∂θfunc
46
+ end
47
+
5
48
@testset " Multi variate Normal with Riemannian HMC" begin
6
49
# Set the number of samples to draw and warmup iterations
7
50
n_samples = 2_000
8
51
rng = MersenneTwister (1110 )
9
52
initial_θ = rand (rng, D)
10
53
λ = 1e-2
54
+ G, ∂G∂θ = prepare_sample (ℓπ, initial_θ, λ)
11
55
# Define a Hamiltonian system
12
- metric = DenseRiemannianMetric ((D,), ℓπ, initial_θ, λ )
56
+ metric = DenseRiemannianMetric ((D,), G, ∂G∂θ )
13
57
kinetic = GaussianKinetic ()
14
- hamiltonian = Hamiltonian (metric, kinetic, ℓπ, ∇ℓπ )
58
+ hamiltonian = Hamiltonian (metric, kinetic, ℓπ, ∂ℓπ∂θ )
15
59
16
60
# Define a leapfrog solver, with the initial step size chosen heuristically
17
61
initial_ϵ = 0.01
35
79
n_samples = 2_000
36
80
rng = MersenneTwister (1110 )
37
81
initial_θ = rand (rng, D)
82
+ λ = 1e-2
83
+ G, ∂G∂θ = prepare_sample (ℓπ, initial_θ, λ)
38
84
39
85
# Define a Hamiltonian system
40
- metric = DenseRiemannianMetric ((D,), ℓπ, initial_θ , λSoftAbsMap (20.0 ))
86
+ metric = DenseRiemannianMetric ((D,), G, ∂G∂θ , λSoftAbsMap (20.0 ))
41
87
kinetic = GaussianKinetic ()
42
- hamiltonian = Hamiltonian (metric, kinetic, ℓπ, ∇ℓπ )
88
+ hamiltonian = Hamiltonian (metric, kinetic, ℓπ, ∂ℓπ∂θ )
43
89
44
90
# Define a leapfrog solver, with the initial step size chosen heuristically
45
91
initial_ϵ = 0.01
0 commit comments