Commit d98580b

Fix sparse jacobians and add sparsereversediff backend
1 parent 11d6b83 commit d98580b

File tree

3 files changed (+184, -102 lines)


ext/OptimizationForwardDiffExt.jl

Lines changed: 4 additions & 2 deletions
@@ -107,8 +107,8 @@ function Optimization.instantiate_function(f::OptimizationFunction{true},
 
     if f.hess === nothing
         hesscfg = ForwardDiff.HessianConfig(_f, cache.u0, ForwardDiff.Chunk{chunksize}())
-        hess = (res, θ, args...) -> ForwardDiff.hessian!(res, x -> _f(x, args...), θ,
-                                                         hesscfg, Val{false}())
+        hess = (res, θ, args...) -> (ForwardDiff.hessian!(res, x -> _f(x, args...), θ,
+                                                          hesscfg, Val{false}()))
     else
         hess = (H, θ, args...) -> f.hess(H, θ, cache.p, args...)
     end
@@ -135,6 +135,7 @@ function Optimization.instantiate_function(f::OptimizationFunction{true},
                                               ForwardDiff.Chunk{chunksize}())
         cons_j = function (J, θ)
             ForwardDiff.jacobian!(J, cons_oop, θ, cjconfig)
+            println(J)
         end
     else
         cons_j = (J, θ) -> f.cons_j(J, θ, cache.p)
@@ -149,6 +150,7 @@ function Optimization.instantiate_function(f::OptimizationFunction{true},
             for i in 1:num_cons
                 ForwardDiff.hessian!(res[i], fncs[i], θ, hess_config_cache[i], Val{true}())
             end
+            # println(res)
         end
     else
         cons_h = (res, θ) -> f.cons_h(res, θ, cache.p)
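
For context, a minimal standalone sketch of the ForwardDiff.hessian! pattern that the first hunk wraps in an extra set of parentheses. The toy objective, input, and chunk size below are assumptions for illustration, not part of the commit; only the call shape (preallocated result, explicit HessianConfig, tag checking disabled with Val{false}()) mirrors the diff.

using ForwardDiff

# Hypothetical stand-in for the package's `_f` objective closure.
rosen(x) = (1.0 - x[1])^2 + 100.0 * (x[2] - x[1]^2)^2

x0 = [0.5, 0.5]
H = zeros(2, 2)    # preallocated Hessian, playing the role of `res` in the closure

# Explicit config with a fixed chunk size, as in the diff; Val{false}() skips
# ForwardDiff's tag consistency check, which the package relies on because the
# config is built for `_f` while the differentiated closure is `x -> _f(x, args...)`.
hesscfg = ForwardDiff.HessianConfig(rosen, x0, ForwardDiff.Chunk{2}())
ForwardDiff.hessian!(H, rosen, x0, hesscfg, Val{false}())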

ext/OptimizationReverseDiffExt.jl

Lines changed: 80 additions & 84 deletions
@@ -20,11 +20,9 @@ function Optimization.instantiate_function(f, x, adtype::AutoReverseDiff,
     end
 
     if f.hess === nothing
+
         hess = function (res, θ, args...)
-
-            res .= SparseDiffTools.forwarddiff_color_jacobian(θ, colorvec = hess_colors, sparsity = hess_sparsity) do θ
-                ReverseDiff.gradient(x -> _f(x, args...), θ)
-            end
+            ReverseDiff.hessian!(res, x -> _f(x, args...), θ)
         end
     else
         hess = (H, θ, args...) -> f.hess(H, θ, p, args...)
@@ -61,9 +59,7 @@ function Optimization.instantiate_function(f, x, adtype::AutoReverseDiff,
 
         cons_h = function (res, θ)
             for i in 1:num_cons
-                res[i] .= SparseDiffTools.forwarddiff_color_jacobian(θ, ) do θ
-                    ReverseDiff.gradient(fncs[i], θ)
-                end
+                ReverseDiff.gradient(res[i], fncs[i], θ)
             end
         end
     else
@@ -83,82 +79,82 @@ function Optimization.instantiate_function(f, x, adtype::AutoReverseDiff,
                                       lag_h, f.lag_hess_prototype)
 end
 
-# function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,
-#                                            adtype::AutoReverseDiff, num_cons = 0)
-#     _f = (θ, args...) -> first(f.f(θ, cache.p, args...))
-
-#     if f.grad === nothing
-#         grad = (res, θ, args...) -> ReverseDiff.gradient!(res, x -> _f(x, args...), θ)
-#     else
-#         grad = (G, θ, args...) -> f.grad(G, θ, cache.p, args...)
-#     end
-
-#     if f.hess === nothing
-#         hess_sparsity = Symbolics.hessian_sparsity(_f, cache.u0)
-#         hess_colors = SparseDiffTools.matrix_colors(tril(hess_sparsity))
-#         hess = function (res, θ, args...)
-#             res .= SparseDiffTools.forwarddiff_color_jacobian(θ, colorvec = hess_colors, sparsity = hess_sparsity) do θ
-#                 ReverseDiff.gradient(x -> _f(x, args...), θ)
-#             end
-#         end
-#     else
-#         hess = (H, θ, args...) -> f.hess(H, θ, cache.p, args...)
-#     end
-
-#     if f.hv === nothing
-#         hv = function (H, θ, v, args...)
-#             _θ = ForwardDiff.Dual.(θ, v)
-#             res = similar(_θ)
-#             grad(res, _θ, args...)
-#             H .= getindex.(ForwardDiff.partials.(res), 1)
-#         end
-#     else
-#         hv = f.hv
-#     end
-
-#     if f.cons === nothing
-#         cons = nothing
-#     else
-#         cons = (res, θ) -> f.cons(res, θ, cache.p)
-#         cons_oop = (x) -> (_res = zeros(eltype(x), num_cons); cons(_res, x); _res)
-#     end
-
-#     if cons !== nothing && f.cons_j === nothing
-#         cjconfig = ReverseDiff.JacobianConfig(cache.u0)
-#         cons_j = function (J, θ)
-#             ReverseDiff.jacobian!(J, cons_oop, θ, cjconfig)
-#         end
-#     else
-#         cons_j = (J, θ) -> f.cons_j(J, θ, cache.p)
-#     end
-
-#     if cons !== nothing && f.cons_h === nothing
-#         fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons]
-#         conshess_sparsity = Symbolics.hessian_sparsity.(fncs, Ref(cache.u0))
-#         conshess_colors = SparseDiffTools.matrix_colors.(conshess_sparsity)
-#         cons_h = function (res, θ)
-#             for i in 1:num_cons
-#                 res[i] .= SparseDiffTools.forwarddiff_color_jacobian(θ, colorvec = conshess_colors[i], sparsity = conshess_sparsity[i]) do θ
-#                     ReverseDiff.gradient(fncs[i], θ)
-#                 end
-#             end
-#         end
-#     else
-#         cons_h = (res, θ) -> f.cons_h(res, θ, cache.p)
-#     end
-
-#     if f.lag_h === nothing
-#         lag_h = nothing # Consider implementing this
-#     else
-#         lag_h = (res, θ, σ, μ) -> f.lag_h(res, θ, σ, μ, cache.p)
-#     end
-
-#     return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv,
-#                                       cons = cons, cons_j = cons_j, cons_h = cons_h,
-#                                       hess_prototype = f.hess_prototype,
-#                                       cons_jac_prototype = f.cons_jac_prototype,
-#                                       cons_hess_prototype = f.cons_hess_prototype,
-#                                       lag_h, f.lag_hess_prototype)
-# end
+function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,
+                                           adtype::AutoReverseDiff, num_cons = 0)
+    _f = (θ, args...) -> first(f.f(θ, cache.p, args...))
+
+    if f.grad === nothing
+        grad = (res, θ, args...) -> ReverseDiff.gradient!(res, x -> _f(x, args...), θ)
+    else
+        grad = (G, θ, args...) -> f.grad(G, θ, cache.p, args...)
+    end
+
+    if f.hess === nothing
+        hess_sparsity = Symbolics.hessian_sparsity(_f, cache.u0)
+        hess_colors = SparseDiffTools.matrix_colors(tril(hess_sparsity))
+        hess = function (res, θ, args...)
+            res .= SparseDiffTools.forwarddiff_color_jacobian(θ, colorvec = hess_colors, sparsity = hess_sparsity) do θ
+                ReverseDiff.gradient(x -> _f(x, args...), θ)
+            end
+        end
+    else
+        hess = (H, θ, args...) -> f.hess(H, θ, cache.p, args...)
+    end
+
+    if f.hv === nothing
+        hv = function (H, θ, v, args...)
+            _θ = ForwardDiff.Dual.(θ, v)
+            res = similar(_θ)
+            grad(res, _θ, args...)
+            H .= getindex.(ForwardDiff.partials.(res), 1)
+        end
+    else
+        hv = f.hv
+    end
+
+    if f.cons === nothing
+        cons = nothing
+    else
+        cons = (res, θ) -> f.cons(res, θ, cache.p)
+        cons_oop = (x) -> (_res = zeros(eltype(x), num_cons); cons(_res, x); _res)
+    end
+
+    if cons !== nothing && f.cons_j === nothing
+        cjconfig = ReverseDiff.JacobianConfig(cache.u0)
+        cons_j = function (J, θ)
+            ReverseDiff.jacobian!(J, cons_oop, θ, cjconfig)
+        end
+    else
+        cons_j = (J, θ) -> f.cons_j(J, θ, cache.p)
+    end
+
+    if cons !== nothing && f.cons_h === nothing
+        fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons]
+        conshess_sparsity = Symbolics.hessian_sparsity.(fncs, Ref(cache.u0))
+        conshess_colors = SparseDiffTools.matrix_colors.(conshess_sparsity)
+        cons_h = function (res, θ)
+            for i in 1:num_cons
+                res[i] .= SparseDiffTools.forwarddiff_color_jacobian(θ, colorvec = conshess_colors[i], sparsity = conshess_sparsity[i]) do θ
+                    ReverseDiff.gradient(fncs[i], θ)
+                end
+            end
+        end
+    else
+        cons_h = (res, θ) -> f.cons_h(res, θ, cache.p)
+    end
+
+    if f.lag_h === nothing
+        lag_h = nothing # Consider implementing this
+    else
+        lag_h = (res, θ, σ, μ) -> f.lag_h(res, θ, σ, μ, cache.p)
+    end
+
+    return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv,
+                                      cons = cons, cons_j = cons_j, cons_h = cons_h,
+                                      hess_prototype = f.hess_prototype,
+                                      cons_jac_prototype = f.cons_jac_prototype,
+                                      cons_hess_prototype = f.cons_hess_prototype,
+                                      lag_h, f.lag_hess_prototype)
+end
 
 end
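
The newly uncommented ReInitCache method assembles its Hessians by forward-over-reverse differentiation with sparsity coloring: Symbolics.hessian_sparsity detects the nonzero pattern, SparseDiffTools.matrix_colors compresses it, and a colored forward-mode Jacobian of the ReverseDiff gradient yields the Hessian. A minimal standalone sketch of that pattern follows; the toy objective and input are assumptions for illustration only, and the full pattern is colored here (as in the constraint-Hessian branch) rather than tril(hess_sparsity) as the objective-Hessian branch does.

using ReverseDiff, SparseDiffTools, Symbolics

# Hypothetical objective whose Hessian is sparse: the only off-diagonal
# nonzeros come from the x[2] * x[3] cross term.
obj(x) = x[1]^2 + sin(x[2]) + x[2] * x[3]

x0 = rand(3)

hess_sparsity = Symbolics.hessian_sparsity(obj, x0)          # Bool nonzero pattern
hess_colors = SparseDiffTools.matrix_colors(hess_sparsity)   # column coloring

# Forward-mode colored Jacobian of the reverse-mode gradient is the Hessian,
# computed with only maximum(hess_colors) dual-number sweeps.
H = SparseDiffTools.forwarddiff_color_jacobian(x0,
        colorvec = hess_colors, sparsity = hess_sparsity) do θ
    ReverseDiff.gradient(obj, θ)
end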
