@@ -3,6 +3,7 @@ module OptimizationReverseDiffExt
import Optimization
import Optimization.SciMLBase: OptimizationFunction
import Optimization.ADTypes: AutoReverseDiff
+# using SparseDiffTools, Symbolics
isdefined(Base, :get_extension) ? (using ReverseDiff, ReverseDiff.ForwardDiff) :
(using ..ReverseDiff, ..ReverseDiff.ForwardDiff)
@@ -20,7 +21,8 @@ function Optimization.instantiate_function(f, x, adtype::AutoReverseDiff,
    if f.hess === nothing
        hess = function (res, θ, args...)
-            res .= ForwardDiff.jacobian(θ) do θ
+
+            res .= SparseDiffTools.forwarddiff_color_jacobian(θ, colorvec = hess_colors, sparsity = hess_sparsity) do θ
                ReverseDiff.gradient(x -> _f(x, args...), θ)
            end
        end
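
# Note: the added call reads `hess_colors` and `hess_sparsity`, which this
# commit defines only in the commented-out ReInitCache method further down.
# A sketch of the missing setup, mirroring that commented code (here `x` is
# the initial point passed to instantiate_function):
#
#     hess_sparsity = Symbolics.hessian_sparsity(_f, x)
#     hess_colors = SparseDiffTools.matrix_colors(tril(hess_sparsity))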
@@ -56,10 +58,10 @@ function Optimization.instantiate_function(f, x, adtype::AutoReverseDiff,
    end

    if cons !== nothing && f.cons_h === nothing
-        fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons]
+
        cons_h = function (res, θ)
            for i in 1:num_cons
-                res[i] .= ForwardDiff.jacobian(θ) do θ
+                res[i] .= SparseDiffTools.forwarddiff_color_jacobian(θ, ) do θ
                    ReverseDiff.gradient(fncs[i], θ)
                end
            end
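
# Note: as committed, the call above is incomplete: `(θ, )` passes no
# colorvec/sparsity keywords, and it still uses `fncs`, whose definition this
# hunk deletes. The commented-out method below suggests the intended
# per-constraint setup (again with `x` as the initial point):
#
#     fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons]
#     conshess_sparsity = Symbolics.hessian_sparsity.(fncs, Ref(x))
#     conshess_colors = SparseDiffTools.matrix_colors.(conshess_sparsity)
#     # per constraint i: colorvec = conshess_colors[i], sparsity = conshess_sparsity[i]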
@@ -81,79 +83,82 @@ function Optimization.instantiate_function(f, x, adtype::AutoReverseDiff,
        lag_h, f.lag_hess_prototype)
end

-function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,
-    adtype::AutoReverseDiff, num_cons = 0)
-    _f = (θ, args...) -> first(f.f(θ, cache.p, args...))
-
-    if f.grad === nothing
-        cfg = ReverseDiff.GradientConfig(cache.u0)
-        grad = (res, θ, args...) -> ReverseDiff.gradient!(res, x -> _f(x, args...), θ)
-    else
-        grad = (G, θ, args...) -> f.grad(G, θ, cache.p, args...)
-    end
-
-    if f.hess === nothing
-        hess = function (res, θ, args...)
-            res .= ForwardDiff.jacobian(θ) do θ
-                ReverseDiff.gradient(x -> _f(x, args...), θ)
-            end
-        end
-    else
-        hess = (H, θ, args...) -> f.hess(H, θ, cache.p, args...)
-    end
-
-    if f.hv === nothing
-        hv = function (H, θ, v, args...)
-            _θ = ForwardDiff.Dual.(θ, v)
-            res = similar(_θ)
-            grad(res, _θ, args...)
-            H .= getindex.(ForwardDiff.partials.(res), 1)
-        end
-    else
-        hv = f.hv
-    end
-
-    if f.cons === nothing
-        cons = nothing
-    else
-        cons = (res, θ) -> f.cons(res, θ, cache.p)
-        cons_oop = (x) -> (_res = zeros(eltype(x), num_cons); cons(_res, x); _res)
-    end
-
-    if cons !== nothing && f.cons_j === nothing
-        cjconfig = ReverseDiff.JacobianConfig(cache.u0)
-        cons_j = function (J, θ)
-            ReverseDiff.jacobian!(J, cons_oop, θ, cjconfig)
-        end
-    else
-        cons_j = (J, θ) -> f.cons_j(J, θ, cache.p)
-    end
-
-    if cons !== nothing && f.cons_h === nothing
-        fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons]
-        cons_h = function (res, θ)
-            for i in 1:num_cons
-                res[i] .= ForwardDiff.jacobian(θ) do θ
-                    ReverseDiff.gradient(fncs[i], θ)
-                end
-            end
-        end
-    else
-        cons_h = (res, θ) -> f.cons_h(res, θ, cache.p)
-    end
-
-    if f.lag_h === nothing
-        lag_h = nothing # Consider implementing this
-    else
-        lag_h = (res, θ, σ, μ) -> f.lag_h(res, θ, σ, μ, cache.p)
-    end
-
-    return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv,
-        cons = cons, cons_j = cons_j, cons_h = cons_h,
-        hess_prototype = f.hess_prototype,
-        cons_jac_prototype = f.cons_jac_prototype,
-        cons_hess_prototype = f.cons_hess_prototype,
-        lag_h, f.lag_hess_prototype)
-end
+# function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,
+#     adtype::AutoReverseDiff, num_cons = 0)
+#     _f = (θ, args...) -> first(f.f(θ, cache.p, args...))
+
+#     if f.grad === nothing
+#         grad = (res, θ, args...) -> ReverseDiff.gradient!(res, x -> _f(x, args...), θ)
+#     else
+#         grad = (G, θ, args...) -> f.grad(G, θ, cache.p, args...)
+#     end
+
+#     if f.hess === nothing
+#         hess_sparsity = Symbolics.hessian_sparsity(_f, cache.u0)
+#         hess_colors = SparseDiffTools.matrix_colors(tril(hess_sparsity))
+#         hess = function (res, θ, args...)
+#             res .= SparseDiffTools.forwarddiff_color_jacobian(θ, colorvec = hess_colors, sparsity = hess_sparsity) do θ
+#                 ReverseDiff.gradient(x -> _f(x, args...), θ)
+#             end
+#         end
+#     else
+#         hess = (H, θ, args...) -> f.hess(H, θ, cache.p, args...)
+#     end
+
+#     if f.hv === nothing
+#         hv = function (H, θ, v, args...)
+#             _θ = ForwardDiff.Dual.(θ, v)
+#             res = similar(_θ)
+#             grad(res, _θ, args...)
+#             H .= getindex.(ForwardDiff.partials.(res), 1)
+#         end
+#     else
+#         hv = f.hv
+#     end
+
+#     if f.cons === nothing
+#         cons = nothing
+#     else
+#         cons = (res, θ) -> f.cons(res, θ, cache.p)
+#         cons_oop = (x) -> (_res = zeros(eltype(x), num_cons); cons(_res, x); _res)
+#     end
+
+#     if cons !== nothing && f.cons_j === nothing
+#         cjconfig = ReverseDiff.JacobianConfig(cache.u0)
+#         cons_j = function (J, θ)
+#             ReverseDiff.jacobian!(J, cons_oop, θ, cjconfig)
+#         end
+#     else
+#         cons_j = (J, θ) -> f.cons_j(J, θ, cache.p)
+#     end
+
+#     if cons !== nothing && f.cons_h === nothing
+#         fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons]
+#         conshess_sparsity = Symbolics.hessian_sparsity.(fncs, Ref(cache.u0))
+#         conshess_colors = SparseDiffTools.matrix_colors.(conshess_sparsity)
+#         cons_h = function (res, θ)
+#             for i in 1:num_cons
+#                 res[i] .= SparseDiffTools.forwarddiff_color_jacobian(θ, colorvec = conshess_colors[i], sparsity = conshess_sparsity[i]) do θ
+#                     ReverseDiff.gradient(fncs[i], θ)
+#                 end
+#             end
+#         end
+#     else
+#         cons_h = (res, θ) -> f.cons_h(res, θ, cache.p)
+#     end
+
+#     if f.lag_h === nothing
+#         lag_h = nothing # Consider implementing this
+#     else
+#         lag_h = (res, θ, σ, μ) -> f.lag_h(res, θ, σ, μ, cache.p)
+#     end
+
+#     return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv,
+#         cons = cons, cons_j = cons_j, cons_h = cons_h,
+#         hess_prototype = f.hess_prototype,
+#         cons_jac_prototype = f.cons_jac_prototype,
+#         cons_hess_prototype = f.cons_hess_prototype,
+#         lag_h, f.lag_hess_prototype)
+# end

end
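
# For reference, a minimal standalone sketch of the forward-over-reverse
# sparse-Hessian pattern this commit works toward. It assumes recent Symbolics
# and SparseDiffTools releases; the objective `rosenbrock` and point `x0` are
# illustrative stand-ins, not part of this commit.

using ReverseDiff, SparseDiffTools, Symbolics, LinearAlgebra

rosenbrock(x) = (1.0 - x[1])^2 + 100.0 * (x[2] - x[1]^2)^2
x0 = [0.5, 0.5]

# Detect the Hessian sparsity pattern symbolically, then color its lower
# triangle (as the commented-out method above does) so the forward-mode
# sweeps can be compressed.
hess_sparsity = Symbolics.hessian_sparsity(rosenbrock, x0)
hess_colors = SparseDiffTools.matrix_colors(tril(hess_sparsity))

# Forward-over-reverse: a colored ForwardDiff Jacobian of the ReverseDiff
# gradient assembles the Hessian.
H = SparseDiffTools.forwarddiff_color_jacobian(x0, colorvec = hess_colors,
    sparsity = hess_sparsity) do θ
    ReverseDiff.gradient(rosenbrock, θ)
end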