@@ -24,19 +24,32 @@ using Statistics
2424# ###
2525
2626function gen_hess_fwd (func, x:: AbstractVector )
27+ cfg = ForwardDiff. HessianConfig (func, x)
28+ H = Matrix {eltype(x)} (undef, length (x), length (x))
29+
2730 function hess (x:: AbstractVector )
28- return nothing , nothing , ForwardDiff. hessian (func, x)
31+ ForwardDiff. hessian! (H, func, x, cfg)
32+ return H
2933 end
3034 return hess
3135end
3236
3337function gen_∂G∂θ_fwd (Vfunc, x; f= identity)
34- _Hfunc = gen_hess_fwd (Vfunc, x)
35- Hfunc = x -> _Hfunc (x)[3 ]
36- cfg = ForwardDiff. JacobianConfig (Hfunc, x)
38+ chunk = ForwardDiff. Chunk (x)
39+ tag = ForwardDiff. Tag (Vfunc, eltype (x))
40+ jac_cfg = ForwardDiff. JacobianConfig (Vfunc, x, chunk, tag)
41+ hess_cfg = ForwardDiff. HessianConfig (Vfunc, jac_cfg. duals, chunk, tag)
42+
3743 d = length (x)
3844 out = zeros (eltype (x), d^ 2 , d)
39- return x -> ForwardDiff. jacobian! (out, Hfunc, x, cfg)
45+
46+ function ∂G∂θ_fwd (y)
47+ hess = z -> ForwardDiff. hessian (Vfunc, z, hess_cfg, Val {false} ())
48+ ForwardDiff. jacobian! (out, hess, y, jac_cfg, Val {false} ())
49+ return out
50+ end
51+
52+ return ∂G∂θ_fwd
4053end
4154
4255function reshape_∂G∂θ (H)
4659
4760function prepare_sample (ℓπ, initial_θ, λ)
4861 Vfunc = x -> - ℓπ (x)
49- _Hfunc = MCMCLogDensityProblems . gen_hess (Vfunc, initial_θ)
62+ _Hfunc = gen_hess_fwd (Vfunc, initial_θ)
5063 Hfunc = x -> copy .(_Hfunc (x))
5164
5265 fstabilize = H -> H + λ * I
5366 Gfunc = x -> begin
54- H = fstabilize (Hfunc (x)[ 3 ] )
67+ H = fstabilize (Hfunc (x))
5568 all (isfinite, H) ? H : diagm (ones (length (x)))
5669 end
5770 _∂G∂θfunc = gen_∂G∂θ_fwd (x -> - ℓπ (x), initial_θ; f= fstabilize)
0 commit comments