
Commit 11d6b83

Some experimental changes to AD implementations
1 parent 380344e commit 11d6b83

7 files changed: +626, -550 lines

Project.toml

Lines changed: 1 addition & 2 deletions
@@ -36,8 +36,7 @@ OptimizationFiniteDiffExt = "FiniteDiff"
 OptimizationForwardDiffExt = "ForwardDiff"
 OptimizationMTKExt = "ModelingToolkit"
 OptimizationReverseDiffExt = "ReverseDiff"
-OptimizationSparseFiniteDiffExt = ["SparseDiffTools", "FiniteDiff", "Symbolics"]
-OptimizationSparseForwardDiffExt = ["SparseDiffTools", "ForwardDiff", "Symbolics"]
+OptimizationSparseDiffExt = ["SparseDiffTools", "Symbolics", "ReverseDiff"]
 OptimizationTrackerExt = "Tracker"
 OptimizationZygoteExt = "Zygote"
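
For context, the two sparse AD extensions are folded into a single OptimizationSparseDiffExt whose triggers are SparseDiffTools, Symbolics, and ReverseDiff. A hedged sketch of the consequence for users (names taken from the diff; the check below is ordinary Julia 1.9+ package-extension machinery, not part of this commit): the extension only loads once all three trigger packages are loaded alongside Optimization.

using Optimization
using SparseDiffTools, Symbolics, ReverseDiff   # the trigger packages listed above

# Confirm the extension module actually loaded (Julia >= 1.9).
ext = Base.get_extension(Optimization, :OptimizationSparseDiffExt)
ext === nothing && @warn "OptimizationSparseDiffExt did not load"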

ext/OptimizationEnzymeExt.jl

Lines changed: 1 addition & 0 deletions
@@ -156,6 +156,7 @@ function Optimization.instantiate_function(f::OptimizationFunction{true},
     if f.grad === nothing
         function grad(res, θ, args...)
             res .= zero(eltype(res))
+            println("objgrad")
             Enzyme.autodiff(Enzyme.Reverse,
                 Const(firstapply),
                 Active,
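
The only change in this hunk is a debug println, but the surrounding pattern is worth a note: Enzyme reverse mode accumulates gradients into a shadow buffer, which is why `res` is zeroed before every `Enzyme.autodiff` call. A minimal standalone sketch of that pattern, with a toy objective standing in for the `firstapply` closure the extension differentiates:

using Enzyme

# Toy objective in place of the extension's wrapped objective.
rosen(x) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2

θ = [0.5, 0.5]
dθ = zero(θ)   # shadow buffer, zeroed first because Enzyme accumulates into it
Enzyme.autodiff(Enzyme.Reverse, rosen, Active, Enzyme.Duplicated(θ, dθ))
# dθ now holds ∇rosen(θ)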

ext/OptimizationReverseDiffExt.jl

Lines changed: 82 additions & 77 deletions
@@ -3,6 +3,7 @@ module OptimizationReverseDiffExt
 import Optimization
 import Optimization.SciMLBase: OptimizationFunction
 import Optimization.ADTypes: AutoReverseDiff
+# using SparseDiffTools, Symbolics
 isdefined(Base, :get_extension) ? (using ReverseDiff, ReverseDiff.ForwardDiff) :
     (using ..ReverseDiff, ..ReverseDiff.ForwardDiff)

@@ -20,7 +21,8 @@ function Optimization.instantiate_function(f, x, adtype::AutoReverseDiff,
 
     if f.hess === nothing
         hess = function (res, θ, args...)
-            res .= ForwardDiff.jacobian(θ) do θ
+
+            res .= SparseDiffTools.forwarddiff_color_jacobian(θ, colorvec = hess_colors, sparsity = hess_sparsity) do θ
                 ReverseDiff.gradient(x -> _f(x, args...), θ)
             end
         end
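
Here the dense `ForwardDiff.jacobian` of the ReverseDiff gradient is replaced by a colored, sparsity-aware Jacobian; `hess_colors` and `hess_sparsity` come from `SparseDiffTools.matrix_colors` and `Symbolics.hessian_sparsity`, as the commented-out cache method later in this diff shows. A self-contained, hedged sketch of that forward-over-reverse sparse Hessian recipe, using an illustrative objective rather than anything from the package:

using SparseDiffTools, Symbolics, ReverseDiff

# Chained objective whose Hessian is tridiagonal, so coloring actually pays off.
f(x) = sum((1 - x[i])^2 + 100 * (x[i + 1] - x[i]^2)^2 for i in 1:(length(x) - 1))

x0 = zeros(6)

# 1. Detect the Hessian sparsity pattern (the extension does this on _f and u0).
hess_sparsity = Symbolics.hessian_sparsity(f, x0)

# 2. Color the pattern so structurally independent columns share one dual perturbation.
hess_colors = SparseDiffTools.matrix_colors(hess_sparsity)

# 3. Forward-over-reverse: colored ForwardDiff Jacobian of the ReverseDiff gradient.
H = SparseDiffTools.forwarddiff_color_jacobian(x0,
    colorvec = hess_colors, sparsity = hess_sparsity) do θ
    ReverseDiff.gradient(f, θ)
end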
@@ -56,10 +58,10 @@ function Optimization.instantiate_function(f, x, adtype::AutoReverseDiff,
         end
 
     if cons !== nothing && f.cons_h === nothing
-        fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons]
+
         cons_h = function (res, θ)
             for i in 1:num_cons
-                res[i] .= ForwardDiff.jacobian(θ) do θ
+                res[i] .= SparseDiffTools.forwarddiff_color_jacobian(θ, ) do θ
                     ReverseDiff.gradient(fncs[i], θ)
                 end
             end
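
The constraint-Hessian loop follows the same recipe per constraint; the commented-out cache method below detects a separate sparsity pattern and coloring for each constraint. A brief hedged sketch of just that setup, with two illustrative constraint closures standing in for the `fncs` built from `cons_oop`:

using SparseDiffTools, Symbolics

x0 = ones(4)
# Illustrative scalar constraint functions, not the package's.
fncs = [x -> sum(abs2, x) - 1.0, x -> x[1] * x[4] - x[2] * x[3]]

# One sparsity pattern and one coloring per constraint, as in the commented-out code.
conshess_sparsity = Symbolics.hessian_sparsity.(fncs, Ref(x0))
conshess_colors = SparseDiffTools.matrix_colors.(conshess_sparsity)
# Each pair is then passed to forwarddiff_color_jacobian exactly as for the objective.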
@@ -81,79 +83,82 @@ function Optimization.instantiate_function(f, x, adtype::AutoReverseDiff,
         lag_h, f.lag_hess_prototype)
 end
 
-function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,
-    adtype::AutoReverseDiff, num_cons = 0)
-    _f = (θ, args...) -> first(f.f(θ, cache.p, args...))
-
-    if f.grad === nothing
-        cfg = ReverseDiff.GradientConfig(cache.u0)
-        grad = (res, θ, args...) -> ReverseDiff.gradient!(res, x -> _f(x, args...), θ)
-    else
-        grad = (G, θ, args...) -> f.grad(G, θ, cache.p, args...)
-    end
-
-    if f.hess === nothing
-        hess = function (res, θ, args...)
-            res .= ForwardDiff.jacobian(θ) do θ
-                ReverseDiff.gradient(x -> _f(x, args...), θ)
-            end
-        end
-    else
-        hess = (H, θ, args...) -> f.hess(H, θ, cache.p, args...)
-    end
-
-    if f.hv === nothing
-        hv = function (H, θ, v, args...)
-            _θ = ForwardDiff.Dual.(θ, v)
-            res = similar(_θ)
-            grad(res, _θ, args...)
-            H .= getindex.(ForwardDiff.partials.(res), 1)
-        end
-    else
-        hv = f.hv
-    end
-
-    if f.cons === nothing
-        cons = nothing
-    else
-        cons = (res, θ) -> f.cons(res, θ, cache.p)
-        cons_oop = (x) -> (_res = zeros(eltype(x), num_cons); cons(_res, x); _res)
-    end
-
-    if cons !== nothing && f.cons_j === nothing
-        cjconfig = ReverseDiff.JacobianConfig(cache.u0)
-        cons_j = function (J, θ)
-            ReverseDiff.jacobian!(J, cons_oop, θ, cjconfig)
-        end
-    else
-        cons_j = (J, θ) -> f.cons_j(J, θ, cache.p)
-    end
-
-    if cons !== nothing && f.cons_h === nothing
-        fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons]
-        cons_h = function (res, θ)
-            for i in 1:num_cons
-                res[i] .= ForwardDiff.jacobian(θ) do θ
-                    ReverseDiff.gradient(fncs[i], θ)
-                end
-            end
-        end
-    else
-        cons_h = (res, θ) -> f.cons_h(res, θ, cache.p)
-    end
-
-    if f.lag_h === nothing
-        lag_h = nothing # Consider implementing this
-    else
-        lag_h = (res, θ, σ, μ) -> f.lag_h(res, θ, σ, μ, cache.p)
-    end
-
-    return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv,
-        cons = cons, cons_j = cons_j, cons_h = cons_h,
-        hess_prototype = f.hess_prototype,
-        cons_jac_prototype = f.cons_jac_prototype,
-        cons_hess_prototype = f.cons_hess_prototype,
-        lag_h, f.lag_hess_prototype)
-end
+# function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,
+#     adtype::AutoReverseDiff, num_cons = 0)
+#     _f = (θ, args...) -> first(f.f(θ, cache.p, args...))
+
+#     if f.grad === nothing
+#         grad = (res, θ, args...) -> ReverseDiff.gradient!(res, x -> _f(x, args...), θ)
+#     else
+#         grad = (G, θ, args...) -> f.grad(G, θ, cache.p, args...)
+#     end
+
+#     if f.hess === nothing
+#         hess_sparsity = Symbolics.hessian_sparsity(_f, cache.u0)
+#         hess_colors = SparseDiffTools.matrix_colors(tril(hess_sparsity))
+#         hess = function (res, θ, args...)
+#             res .= SparseDiffTools.forwarddiff_color_jacobian(θ, colorvec = hess_colors, sparsity = hess_sparsity) do θ
+#                 ReverseDiff.gradient(x -> _f(x, args...), θ)
+#             end
+#         end
+#     else
+#         hess = (H, θ, args...) -> f.hess(H, θ, cache.p, args...)
+#     end
+
+#     if f.hv === nothing
+#         hv = function (H, θ, v, args...)
+#             _θ = ForwardDiff.Dual.(θ, v)
+#             res = similar(_θ)
+#             grad(res, _θ, args...)
+#             H .= getindex.(ForwardDiff.partials.(res), 1)
+#         end
+#     else
+#         hv = f.hv
+#     end
+
+#     if f.cons === nothing
+#         cons = nothing
+#     else
+#         cons = (res, θ) -> f.cons(res, θ, cache.p)
+#         cons_oop = (x) -> (_res = zeros(eltype(x), num_cons); cons(_res, x); _res)
+#     end
+
+#     if cons !== nothing && f.cons_j === nothing
+#         cjconfig = ReverseDiff.JacobianConfig(cache.u0)
+#         cons_j = function (J, θ)
+#             ReverseDiff.jacobian!(J, cons_oop, θ, cjconfig)
+#         end
+#     else
+#         cons_j = (J, θ) -> f.cons_j(J, θ, cache.p)
+#     end
+
+#     if cons !== nothing && f.cons_h === nothing
+#         fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons]
+#         conshess_sparsity = Symbolics.hessian_sparsity.(fncs, Ref(cache.u0))
+#         conshess_colors = SparseDiffTools.matrix_colors.(conshess_sparsity)
+#         cons_h = function (res, θ)
+#             for i in 1:num_cons
+#                 res[i] .= SparseDiffTools.forwarddiff_color_jacobian(θ, colorvec = conshess_colors[i], sparsity = conshess_sparsity[i]) do θ
+#                     ReverseDiff.gradient(fncs[i], θ)
+#                 end
+#             end
+#         end
+#     else
+#         cons_h = (res, θ) -> f.cons_h(res, θ, cache.p)
+#     end
+
+#     if f.lag_h === nothing
+#         lag_h = nothing # Consider implementing this
+#     else
+#         lag_h = (res, θ, σ, μ) -> f.lag_h(res, θ, σ, μ, cache.p)
+#     end
+
+#     return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv,
+#         cons = cons, cons_j = cons_j, cons_h = cons_h,
+#         hess_prototype = f.hess_prototype,
+#         cons_jac_prototype = f.cons_jac_prototype,
+#         cons_hess_prototype = f.cons_hess_prototype,
+#         lag_h, f.lag_hess_prototype)
+# end
 
 end
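
One detail of the removed (now commented-out) cache method worth unpacking is the Hessian-vector product: the input is seeded with `ForwardDiff.Dual` perturbations in the direction `v`, the ReverseDiff gradient is evaluated on those duals, and the dual parts of the result give `H*v` without forming `H`. A hedged standalone sketch of that trick (the toy objective and the `hv!` name are illustrative, not from the package):

using ForwardDiff, ReverseDiff

f(x) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2

# Hv without materializing the Hessian: differentiate the gradient along v.
function hv!(Hv, θ, v)
    _θ = ForwardDiff.Dual.(θ, v)                    # seed θ[i] with perturbation v[i]
    res = ReverseDiff.gradient(f, _θ)               # dual-valued gradient
    Hv .= getindex.(ForwardDiff.partials.(res), 1)  # dual parts = H * v
    return Hv
end

θ = [0.5, 0.5]
v = [1.0, 0.0]
Hv = similar(θ)
hv!(Hv, θ, v)   # first column of the Hessian of f at θ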
