@@ -7,32 +7,60 @@ import Optimization.ADTypes: AutoReverseDiff
 isdefined(Base, :get_extension) ? (using ReverseDiff, ReverseDiff.ForwardDiff) :
     (using ..ReverseDiff, ..ReverseDiff.ForwardDiff)
 
+struct OptimizationReverseDiffTag end
+
 function Optimization.instantiate_function(f, x, adtype::AutoReverseDiff,
         p = SciMLBase.NullParameters(),
         num_cons = 0)
     _f = (θ, args...) -> first(f.f(θ, p, args...))
 
     if f.grad === nothing
-        cfg = ReverseDiff.GradientConfig(x)
-        grad = (res, θ, args...) -> ReverseDiff.gradient!(res, x -> _f(x, args...), θ, cfg)
+        if adtype.compile
+            _tape = ReverseDiff.GradientTape(_f, x)
+            tape = ReverseDiff.compile(_tape)
+            grad = function (res, θ, args...)
+                ReverseDiff.gradient!(res, tape, θ)
+            end
+        else
+            cfg = ReverseDiff.GradientConfig(x)
+            grad = (res, θ, args...) -> ReverseDiff.gradient!(res, x -> _f(x, args...), θ, cfg)
+        end
     else
         grad = (G, θ, args...) -> f.grad(G, θ, p, args...)
     end
 
     if f.hess === nothing
-        hess = function (res, θ, args...)
-            ReverseDiff.hessian!(res, x -> _f(x, args...), θ)
+        if adtype.compile
+            T = ForwardDiff.Tag(OptimizationReverseDiffTag(), eltype(x))
+            xdual = ForwardDiff.Dual{typeof(T),eltype(x),length(x)}.(x, Ref(ForwardDiff.Partials((ones(eltype(x), length(x))...,))))
+            h_tape = ReverseDiff.GradientTape(_f, xdual)
+            htape = ReverseDiff.compile(h_tape)
+            function g(θ)
+                res1 = zeros(eltype(θ), length(θ))
+                ReverseDiff.gradient!(res1, htape, θ)
+            end
+            jaccfg = ForwardDiff.JacobianConfig(g, x, ForwardDiff.Chunk(x), T)
+            hess = function (res, θ, args...)
+                ForwardDiff.jacobian!(res, g, θ, jaccfg, Val{false}())
+            end
+        else
+            hess = function (res, θ, args...)
+                ReverseDiff.hessian!(res, x -> _f(x, args...), θ)
+            end
         end
     else
         hess = (H, θ, args...) -> f.hess(H, θ, p, args...)
     end
 
     if f.hv === nothing
         hv = function (H, θ, v, args...)
-            _θ = ForwardDiff.Dual.(θ, v)
-            res = similar(_θ)
-            grad(res, _θ, args...)
-            H .= getindex.(ForwardDiff.partials.(res), 1)
+            # _θ = ForwardDiff.Dual.(θ, v)
+            # res = similar(_θ)
+            # grad(res, _θ, args...)
+            # H .= getindex.(ForwardDiff.partials.(res), 1)
+            res = zeros(length(θ), length(θ))
+            hess(res, θ, args...)
+            H .= res * v
         end
     else
         hv = f.hv
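Not part of the diff: a minimal standalone sketch of the pattern this hunk introduces, namely recording and compiling a ReverseDiff gradient tape once and then obtaining the Hessian by ForwardDiff-differentiating the compiled gradient (forward-over-reverse). The objective `rosenbrock`, the point `x0`, and the tag struct `HessTag` are made-up placeholders, not names from the package.

using ReverseDiff, ForwardDiff

struct HessTag end                           # placeholder tag, mirrors OptimizationReverseDiffTag
rosenbrock(x) = (1.0 - x[1])^2 + 100.0 * (x[2] - x[1]^2)^2   # hypothetical objective
x0 = [0.5, 0.5]

# Gradient: record the tape once, compile it, reuse it for every call.
grad_tape = ReverseDiff.compile(ReverseDiff.GradientTape(rosenbrock, x0))
g!(res, θ) = ReverseDiff.gradient!(res, grad_tape, θ)

# Hessian: record a second tape on Dual inputs so ForwardDiff can push
# directional derivatives through the compiled reverse pass.
T = ForwardDiff.Tag(HessTag(), eltype(x0))
xdual = ForwardDiff.Dual{typeof(T),eltype(x0),length(x0)}.(x0,
    Ref(ForwardDiff.Partials((ones(eltype(x0), length(x0))...,))))
hess_tape = ReverseDiff.compile(ReverseDiff.GradientTape(rosenbrock, xdual))
inner_grad(θ) = ReverseDiff.gradient!(zeros(eltype(θ), length(θ)), hess_tape, θ)
jaccfg = ForwardDiff.JacobianConfig(inner_grad, x0, ForwardDiff.Chunk(x0), T)
h!(H, θ) = ForwardDiff.jacobian!(H, inner_grad, θ, jaccfg, Val{false}())

G, H = zeros(2), zeros(2, 2)
g!(G, x0)
h!(H, x0)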
@@ -46,19 +74,43 @@ function Optimization.instantiate_function(f, x, adtype::AutoReverseDiff,
     end
 
     if cons !== nothing && f.cons_j === nothing
-        cjconfig = ReverseDiff.JacobianConfig(x)
-        cons_j = function (J, θ)
-            ReverseDiff.jacobian!(J, cons_oop, θ, cjconfig)
+        if adtype.compile
+            _jac_tape = ReverseDiff.JacobianTape(cons_oop, x)
+            jac_tape = ReverseDiff.compile(_jac_tape)
+            cons_j = function (J, θ)
+                ReverseDiff.jacobian!(J, jac_tape, θ)
+            end
+        else
+            cjconfig = ReverseDiff.JacobianConfig(x)
+            cons_j = function (J, θ)
+                ReverseDiff.jacobian!(J, cons_oop, θ, cjconfig)
+            end
         end
     else
         cons_j = (J, θ) -> f.cons_j(J, θ, p)
     end
 
     if cons !== nothing && f.cons_h === nothing
         fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons]
-        cons_h = function (res, θ)
-            for i in 1:num_cons
-                ReverseDiff.hessian!(res[i], fncs[i], θ)
+        if adtype.compile
+            consh_tapes = ReverseDiff.GradientTape.(fncs, Ref(xdual))
+            conshtapes = ReverseDiff.compile.(consh_tapes)
+            function grad_cons(θ, htape)
+                res1 = zeros(eltype(θ), length(θ))
+                ReverseDiff.gradient!(res1, htape, θ)
+            end
+            gs = [x -> grad_cons(x, conshtapes[i]) for i in 1:num_cons]
+            jaccfgs = [ForwardDiff.JacobianConfig(gs[i], x, ForwardDiff.Chunk(x), T) for i in 1:num_cons]
+            cons_h = function (res, θ)
+                for i in 1:num_cons
+                    ForwardDiff.jacobian!(res[i], gs[i], θ, jaccfgs[i], Val{false}())
+                end
+            end
+        else
+            cons_h = function (res, θ)
+                for i in 1:num_cons
+                    ReverseDiff.hessian!(res[i], fncs[i], θ)
+                end
             end
         end
     else
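Not part of the diff: a small hedged sketch of the compiled constraint-Jacobian path added above, using a made-up two-constraint function `cons` and point `x0` (the real code records the tape against the user's `cons_oop` closure instead).

using ReverseDiff

cons(x) = [x[1]^2 + x[2]^2, x[1] * x[2]]   # hypothetical constraint vector
x0 = [0.5, 0.5]

# Record and compile the Jacobian tape once, then reuse it for every evaluation.
jac_tape = ReverseDiff.compile(ReverseDiff.JacobianTape(cons, x0))
J = zeros(2, 2)
ReverseDiff.jacobian!(J, jac_tape, x0)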
@@ -83,25 +135,52 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,
     _f = (θ, args...) -> first(f.f(θ, cache.p, args...))
 
     if f.grad === nothing
-        grad = (res, θ, args...) -> ReverseDiff.gradient!(res, x -> _f(x, args...), θ)
+        if adtype.compile
+            _tape = ReverseDiff.GradientTape(_f, cache.u0)
+            tape = ReverseDiff.compile(_tape)
+            grad = function (res, θ, args...)
+                ReverseDiff.gradient!(res, tape, θ)
+            end
+        else
+            cfg = ReverseDiff.GradientConfig(cache.u0)
+            grad = (res, θ, args...) -> ReverseDiff.gradient!(res, x -> _f(x, args...), θ, cfg)
+        end
     else
         grad = (G, θ, args...) -> f.grad(G, θ, cache.p, args...)
     end
 
     if f.hess === nothing
-        hess = function (res, θ, args...)
-            ReverseDiff.hessian!(res, x -> _f(x, args...), θ)
+        if adtype.compile
+            T = ForwardDiff.Tag(OptimizationReverseDiffTag(), eltype(cache.u0))
+            xdual = ForwardDiff.Dual{typeof(T),eltype(cache.u0),length(cache.u0)}.(cache.u0, Ref(ForwardDiff.Partials((ones(eltype(cache.u0), length(cache.u0))...,))))
+            h_tape = ReverseDiff.GradientTape(_f, xdual)
+            htape = ReverseDiff.compile(h_tape)
+            function g(θ)
+                res1 = zeros(eltype(θ), length(θ))
+                ReverseDiff.gradient!(res1, htape, θ)
+            end
+            jaccfg = ForwardDiff.JacobianConfig(g, cache.u0, ForwardDiff.Chunk(cache.u0), T)
+            hess = function (res, θ, args...)
+                ForwardDiff.jacobian!(res, g, θ, jaccfg, Val{false}())
+            end
+        else
+            hess = function (res, θ, args...)
+                ReverseDiff.hessian!(res, x -> _f(x, args...), θ)
+            end
         end
     else
         hess = (H, θ, args...) -> f.hess(H, θ, cache.p, args...)
     end
 
     if f.hv === nothing
         hv = function (H, θ, v, args...)
-            _θ = ForwardDiff.Dual.(θ, v)
-            res = similar(_θ)
-            grad(res, _θ, args...)
-            H .= getindex.(ForwardDiff.partials.(res), 1)
+            # _θ = ForwardDiff.Dual.(θ, v)
+            # res = similar(_θ)
+            # grad(res, θ, args...)
+            # H .= getindex.(ForwardDiff.partials.(res), 1)
+            res = zeros(length(θ), length(θ))
+            hess(res, θ, args...)
+            H .= res * v
         end
     else
         hv = f.hv
@@ -115,19 +194,43 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,
     end
 
     if cons !== nothing && f.cons_j === nothing
-        cjconfig = ReverseDiff.JacobianConfig(cache.u0)
-        cons_j = function (J, θ)
-            ReverseDiff.jacobian!(J, cons_oop, θ, cjconfig)
+        if adtype.compile
+            _jac_tape = ReverseDiff.JacobianTape(cons_oop, cache.u0)
+            jac_tape = ReverseDiff.compile(_jac_tape)
+            cons_j = function (J, θ)
+                ReverseDiff.jacobian!(J, jac_tape, θ)
+            end
+        else
+            cjconfig = ReverseDiff.JacobianConfig(cache.u0)
+            cons_j = function (J, θ)
+                ReverseDiff.jacobian!(J, cons_oop, θ, cjconfig)
+            end
         end
     else
         cons_j = (J, θ) -> f.cons_j(J, θ, cache.p)
     end
 
     if cons !== nothing && f.cons_h === nothing
         fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons]
-        cons_h = function (res, θ)
-            for i in 1:num_cons
-                ReverseDiff.hessian!(res[i], fncs[i], θ)
+        if adtype.compile
+            consh_tapes = ReverseDiff.GradientTape.(fncs, Ref(xdual))
+            conshtapes = ReverseDiff.compile.(consh_tapes)
+            function grad_cons(θ, htape)
+                res1 = zeros(eltype(θ), length(θ))
+                ReverseDiff.gradient!(res1, htape, θ)
+            end
+            gs = [x -> grad_cons(x, conshtapes[i]) for i in 1:num_cons]
+            jaccfgs = [ForwardDiff.JacobianConfig(gs[i], cache.u0, ForwardDiff.Chunk(cache.u0), T) for i in 1:num_cons]
+            cons_h = function (res, θ)
+                for i in 1:num_cons
+                    ForwardDiff.jacobian!(res[i], gs[i], θ, jaccfgs[i], Val{false}())
+                end
+            end
+        else
+            cons_h = function (res, θ)
+                for i in 1:num_cons
+                    ReverseDiff.hessian!(res[i], fncs[i], θ)
+                end
             end
         end
     else
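Not part of the diff: a hedged end-to-end sketch of how the compiled-tape branches would be selected from user code, assuming the usual Optimization.jl entry points (OptimizationFunction, OptimizationProblem, solve) and the Optim.jl Newton solver via OptimizationOptimJL; the objective and data are placeholders.

using Optimization, OptimizationOptimJL, ReverseDiff

rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
x0 = zeros(2)
p = [1.0, 100.0]

# compile = true routes instantiate_function through the tape-compiling branches above.
optf = OptimizationFunction(rosenbrock, Optimization.AutoReverseDiff(compile = true))
prob = OptimizationProblem(optf, x0, p)
sol = solve(prob, Newton())   # Newton() exercises the ForwardDiff-over-tape Hessian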