Skip to content

Commit 609ad53

Browse files
Some fixes
1 parent d98580b commit 609ad53

File tree

5 files changed

+6
-17
lines changed

5 files changed

+6
-17
lines changed

ext/OptimizationEnzymeExt.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,6 @@ function Optimization.instantiate_function(f::OptimizationFunction{true},
156156
if f.grad === nothing
157157
function grad(res, θ, args...)
158158
res .= zero(eltype(res))
159-
println("objgrad")
160159
Enzyme.autodiff(Enzyme.Reverse,
161160
Const(firstapply),
162161
Active,

ext/OptimizationForwardDiffExt.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,6 @@ function Optimization.instantiate_function(f::OptimizationFunction{true},
135135
ForwardDiff.Chunk{chunksize}())
136136
cons_j = function (J, θ)
137137
ForwardDiff.jacobian!(J, cons_oop, θ, cjconfig)
138-
println(J)
139138
end
140139
else
141140
cons_j = (J, θ) -> f.cons_j(J, θ, cache.p)
@@ -150,7 +149,6 @@ function Optimization.instantiate_function(f::OptimizationFunction{true},
150149
for i in 1:num_cons
151150
ForwardDiff.hessian!(res[i], fncs[i], θ, hess_config_cache[i], Val{true}())
152151
end
153-
# println(res)
154152
end
155153
else
156154
cons_h = (res, θ) -> f.cons_h(res, θ, cache.p)

ext/OptimizationReverseDiffExt.jl

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,10 @@ function Optimization.instantiate_function(f, x, adtype::AutoReverseDiff,
5656
end
5757

5858
if cons !== nothing && f.cons_h === nothing
59-
59+
fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons]
6060
cons_h = function (res, θ)
6161
for i in 1:num_cons
62-
ReverseDiff.gradient(res[i], fncs[i], θ)
62+
ReverseDiff.hessian!(res[i], fncs[i], θ)
6363
end
6464
end
6565
else
@@ -90,12 +90,8 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,
9090
end
9191

9292
if f.hess === nothing
93-
hess_sparsity = Symbolics.hessian_sparsity(_f, cache.u0)
94-
hess_colors = SparseDiffTools.matrix_colors(tril(hess_sparsity))
9593
hess = function (res, θ, args...)
96-
res .= SparseDiffTools.forwarddiff_color_jacobian(θ, colorvec = hess_colors, sparsity = hess_sparsity) do θ
97-
ReverseDiff.gradient(x -> _f(x, args...), θ)
98-
end
94+
res .= ReverseDiff.gradient(x -> _f(x, args...), θ)
9995
end
10096
else
10197
hess = (H, θ, args...) -> f.hess(H, θ, cache.p, args...)
@@ -130,13 +126,9 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,
130126

131127
if cons !== nothing && f.cons_h === nothing
132128
fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons]
133-
conshess_sparsity = Symbolics.hessian_sparsity.(fncs, Ref(cache.u0))
134-
conshess_colors = SparseDiffTools.matrix_colors.(conshess_sparsity)
135129
cons_h = function (res, θ)
136130
for i in 1:num_cons
137-
res[i] .= SparseDiffTools.forwarddiff_color_jacobian(θ, colorvec = conshess_colors[i], sparsity = conshess_sparsity[i]) do θ
138-
ReverseDiff.gradient(fncs[i], θ)
139-
end
131+
ReverseDiff.hessian!(res[i], fncs[i], θ)
140132
end
141133
end
142134
else

ext/OptimizationSparseDiffExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module OptimizationSparseDiffExt
22

33
import Optimization, Optimization.ArrayInterface
44
import Optimization.SciMLBase: OptimizationFunction
5-
import Optimization.ADTypes: AutoSparseForwardDiff, AutoSparseFiniteDiff, AutoReverseDiff
5+
import Optimization.ADTypes: AutoSparseForwardDiff, AutoSparseFiniteDiff, AutoSparseReverseDiff
66
using Optimization.LinearAlgebra, ReverseDiff
77
isdefined(Base, :get_extension) ? (using SparseDiffTools, SparseDiffTools.ForwardDiff, SparseDiffTools.FiniteDiff, Symbolics) :
88
(using ..SparseDiffTools, ..SparseDiffTools.ForwardDiff, ..SparseDiffTools.FiniteDiff, ..Symbolics)

ext/OptimizationZygoteExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,
114114
if f.cons === nothing
115115
cons = nothing
116116
else
117-
cons = (res, θ) -> f.cons(res, θ, p)
117+
cons = (res, θ) -> f.cons(res, θ, cache.p)
118118
cons_oop = (x) -> (_res = zeros(eltype(x), num_cons); cons(_res, x); _res)
119119
end
120120

0 commit comments

Comments (0)