Skip to content

Commit af46f2e

Browse files
committed
Prevent test type instability
1 parent 7e91495 commit af46f2e

File tree

1 file changed

+20
-7
lines changed

1 file changed

+20
-7
lines changed

test/riemannian.jl

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,19 +24,32 @@ using Statistics
2424
####
2525

2626
function gen_hess_fwd(func, x::AbstractVector)
27+
cfg = ForwardDiff.HessianConfig(func, x)
28+
H = Matrix{eltype(x)}(undef, length(x), length(x))
29+
2730
function hess(x::AbstractVector)
28-
return nothing, nothing, ForwardDiff.hessian(func, x)
31+
ForwardDiff.hessian!(H, func, x, cfg)
32+
return H
2933
end
3034
return hess
3135
end
3236

3337
function gen_∂G∂θ_fwd(Vfunc, x; f=identity)
34-
_Hfunc = gen_hess_fwd(Vfunc, x)
35-
Hfunc = x -> _Hfunc(x)[3]
36-
cfg = ForwardDiff.JacobianConfig(Hfunc, x)
38+
chunk = ForwardDiff.Chunk(x)
39+
tag = ForwardDiff.Tag(Vfunc, eltype(x))
40+
jac_cfg = ForwardDiff.JacobianConfig(Vfunc, x, chunk, tag)
41+
hess_cfg = ForwardDiff.HessianConfig(Vfunc, jac_cfg.duals, chunk, tag)
42+
3743
d = length(x)
3844
out = zeros(eltype(x), d^2, d)
39-
return x -> ForwardDiff.jacobian!(out, Hfunc, x, cfg)
45+
46+
function ∂G∂θ_fwd(y)
47+
hess = z -> ForwardDiff.hessian(Vfunc, z, hess_cfg, Val{false}())
48+
ForwardDiff.jacobian!(out, hess, y, jac_cfg, Val{false}())
49+
return out
50+
end
51+
52+
return ∂G∂θ_fwd
4053
end
4154

4255
function reshape_∂G∂θ(H)
@@ -46,12 +59,12 @@ end
4659

4760
function prepare_sample(ℓπ, initial_θ, λ)
4861
Vfunc = x -> -ℓπ(x)
49-
_Hfunc = MCMCLogDensityProblems.gen_hess(Vfunc, initial_θ)
62+
_Hfunc = gen_hess_fwd(Vfunc, initial_θ)
5063
Hfunc = x -> copy.(_Hfunc(x))
5164

5265
fstabilize = H -> H + λ * I
5366
Gfunc = x -> begin
54-
H = fstabilize(Hfunc(x)[3])
67+
H = fstabilize(Hfunc(x))
5568
all(isfinite, H) ? H : diagm(ones(length(x)))
5669
end
5770
_∂G∂θfunc = gen_∂G∂θ_fwd(x -> -ℓπ(x), initial_θ; f=fstabilize)

0 commit comments

Comments
 (0)