Skip to content

Commit b0860b8

Browse files
SebastianM-Cclaude
andcommitted
fix Enzyme HVP returning zeros instead of correct Hessian-vector product
The Enzyme extension's Hessian-vector product (HVP) implementation was incorrectly using `Enzyme.make_zero(x)` which zeroed out the tangent direction vector `v`, causing the forward-mode AD to have no direction to differentiate in. This resulted in HVP always returning zeros. Fixed by using the correct forward-over-reverse AD approach with the existing `inner_grad` function, which properly computes H*v by taking the gradient ∇f(θ) in reverse mode and differentiating it in forward mode along direction v. Fixes both in-place (OptimizationFunction{true}) and out-of-place (OptimizationFunction{false}) versions. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]>
1 parent 6839663 commit b0860b8

File tree

1 file changed

+24
-12
lines changed

1 file changed

+24
-12
lines changed

lib/OptimizationBase/ext/OptimizationEnzymeExt.jl

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -199,12 +199,17 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
199199

200200
if hv == true && f.hv === nothing
201201
function hv!(H, θ, v, p = p)
202-
x = Duplicated(θ, v)
203-
dx = Enzyme.make_zero(x)
204-
H .= Enzyme.autodiff(
205-
fmode, hv_f2_alloc, Const(rmode), Duplicated(x,dx),
206-
Const(f.f), Const(p)
207-
)[1].dval
202+
= zero(θ)
203+
Enzyme.make_zero!(H)
204+
Enzyme.autodiff(
205+
fmode,
206+
inner_grad,
207+
Const(rmode),
208+
Duplicated(θ, v),
209+
Duplicated(dθ, H),
210+
Const(f.f),
211+
Const(p)
212+
)
208213
end
209214
elseif hv == true
210215
hv! = (H, θ, v, p = p) -> f.hv(H, θ, v, p)
@@ -553,13 +558,20 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
553558
end
554559

555560
if hv == true && f.hv === nothing
561+
H = zero(x)
556562
function hv!(θ, v, p = p)
557-
x = Duplicated(θ, v)
558-
dx = Enzyme.make_zero(x)
559-
return Enzyme.autodiff(
560-
fmode, hv_f2_alloc, DuplicatedNoNeed, Const(rmode), Duplicated(x, dx),
561-
Const(_f), Const(f.f), Const(p)
562-
)[1].dval
563+
= zero(θ)
564+
Enzyme.make_zero!(H)
565+
Enzyme.autodiff(
566+
fmode,
567+
inner_grad,
568+
Const(rmode),
569+
Duplicated(θ, v),
570+
Duplicated(dθ, H),
571+
Const(f.f),
572+
Const(p)
573+
)
574+
return H
563575
end
564576
elseif hv == true
565577
hv! = (θ, v, p = p) -> f.hv(θ, v, p)

0 commit comments

Comments
 (0)