Skip to content

Commit c1ccddf

Browse files
fix
1 parent 5c8c889 commit c1ccddf

File tree

2 files changed

+10
-1
lines changed

2 files changed

+10
-1
lines changed

ext/OptimizationReverseDiffExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,
9090

9191
if f.hess === nothing
9292
hess = function (res, θ, args...)
93-
res .= ReverseDiff.gradient(x -> _f(x, args...), θ)
93+
ReverseDiff.hessian!(res, x -> _f(x, args...), θ)
9494
end
9595
else
9696
hess = (H, θ, args...) -> f.hess(H, θ, cache.p, args...)

test/ADtests.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,9 @@ H3 = [Array{Float64}(undef, 2, 2)]
8484
optprob.cons_h(H3, x0)
8585
@test H3 == [[2.0 0.0; 0.0 2.0]]
8686

87+
G2 = Array{Float64}(undef, 2)
88+
H2 = Array{Float64}(undef, 2, 2)
89+
8790
optf = OptimizationFunction(rosenbrock, Optimization.AutoEnzyme(), cons = con2_c)
8891
optprob = Optimization.instantiate_function(optf, x0, Optimization.AutoEnzyme(),
8992
nothing, 2)
@@ -101,6 +104,9 @@ H3 = [Array{Float64}(undef, 2, 2), Array{Float64}(undef, 2, 2)]
101104
optprob.cons_h(H3, x0)
102105
H3 == [[2.0 0.0; 0.0 2.0], [-0.0 1.0; 1.0 0.0]]
103106

107+
G2 = Array{Float64}(undef, 2)
108+
H2 = Array{Float64}(undef, 2, 2)
109+
104110
optf = OptimizationFunction(rosenbrock, Optimization.AutoReverseDiff(), cons = con2_c)
105111
optprob = Optimization.instantiate_function(optf, x0, Optimization.AutoReverseDiff(),
106112
nothing, 2)
@@ -118,6 +124,9 @@ H3 = [Array{Float64}(undef, 2, 2), Array{Float64}(undef, 2, 2)]
118124
optprob.cons_h(H3, x0)
119125
H3 == [[2.0 0.0; 0.0 2.0], [-0.0 1.0; 1.0 0.0]]
120126

127+
G2 = Array{Float64}(undef, 2)
128+
H2 = Array{Float64}(undef, 2, 2)
129+
121130
optf = OptimizationFunction(rosenbrock, Optimization.AutoZygote(), cons = con2_c)
122131
optprob = Optimization.instantiate_function(optf, x0, Optimization.AutoZygote(),
123132
nothing, 2)

0 commit comments

Comments
 (0)