@@ -6,31 +6,36 @@ import Optimization.LinearAlgebra: I
6
6
import Optimization. ADTypes: AutoEnzyme
7
7
isdefined (Base, :get_extension ) ? (using Enzyme) : (using .. Enzyme)
8
8
9
+ @inline function firstapply (f, θ, p, args... )
10
+ first (f (θ, p, args... ))
11
+ end
12
+
9
13
function Optimization. instantiate_function (f:: OptimizationFunction{true} , x,
10
14
adtype:: AutoEnzyme , p,
11
15
num_cons = 0 )
12
- _f = (f, θ, args... ) -> first (f (θ, p, args... ))
13
16
14
17
if f. grad === nothing
15
18
function grad (res, θ, args... )
16
19
res .= zero (eltype (res))
17
20
Enzyme. autodiff (Enzyme. Reverse,
18
- Const (_f ),
21
+ Const (firstapply ),
19
22
Const (f. f),
20
23
Enzyme. Duplicated (θ, res),
24
+ Const (p),
21
25
args... )
22
26
end
23
27
else
24
28
grad = (G, θ, args... ) -> f. grad (G, θ, p, args... )
25
29
end
26
30
27
31
if f. hess === nothing
28
- function g (θ, bθ, _f, f , args... )
32
+ function g (θ, bθ, f, p , args... )
29
33
Enzyme. autodiff_deferred (Enzyme. Reverse,
30
- Const (_f ),
34
+ Const (firstapply ),
31
35
Const (f),
32
36
Enzyme. Duplicated (θ, bθ),
33
- args... )
37
+ Const (p),
38
+ args... ),
34
39
return nothing
35
40
end
36
41
function hess (res, θ, args... )
@@ -43,8 +48,8 @@ function Optimization.instantiate_function(f::OptimizationFunction{true}, x,
43
48
g,
44
49
Enzyme. BatchDuplicated (θ, vdθ),
45
50
Enzyme. BatchDuplicated (bθ, vdbθ),
46
- Const (_f),
47
51
Const (f. f),
52
+ Const (p),
48
53
args... )
49
54
50
55
for i in eachindex (θ)
@@ -56,17 +61,19 @@ function Optimization.instantiate_function(f::OptimizationFunction{true}, x,
56
61
end
57
62
58
63
if f. hv === nothing
59
- function f2 (x, _f, f , args... )
64
+ function f2 (x, f, p , args... )
60
65
dx = zeros (length (x))
61
- Enzyme. autodiff_deferred (Enzyme. Reverse, _f,
66
+ Enzyme. autodiff_deferred (Enzyme. Reverse,
67
+ firstapply,
62
68
f,
63
69
Enzyme. Duplicated (x, dx),
70
+ Const (p),
64
71
args... )
65
72
return dx
66
73
end
67
74
hv = function (H, θ, v, args... )
68
75
H .= Enzyme. autodiff (Enzyme. Forward, f2, DuplicatedNoNeed, Duplicated (θ, v),
69
- Const (_f), Const (f. f),
76
+ Const (_f), Const (f. f), Const (p),
70
77
args... )[1 ]
71
78
end
72
79
else
@@ -141,25 +148,27 @@ function Optimization.instantiate_function(f::OptimizationFunction{true},
141
148
cache:: Optimization.ReInitCache ,
142
149
adtype:: AutoEnzyme ,
143
150
num_cons = 0 )
144
- _f = (f, θ, args ... ) -> first ( f (θ, cache. p, args ... ))
151
+ p = cache. p
145
152
146
153
if f. grad === nothing
147
154
function grad (res, θ, args... )
148
155
res .= zero (eltype (res))
149
156
Enzyme. autodiff (Enzyme. Reverse,
150
- Const (_f ),
157
+ Const (firstapply ),
151
158
Const (f. f),
152
159
Enzyme. Duplicated (θ, res),
160
+ Const (p),
153
161
args... )
154
162
end
155
163
else
156
164
grad = (G, θ, args... ) -> f. grad (G, θ, p, args... )
157
165
end
158
166
159
167
if f. hess === nothing
160
- function g (θ, bθ, _f, f , args... )
161
- Enzyme. autodiff_deferred (Enzyme. Reverse, Const (_f ), Const (f),
168
+ function g (θ, bθ, f, p , args... )
169
+ Enzyme. autodiff_deferred (Enzyme. Reverse, Const (firstapply ), Const (f),
162
170
Enzyme. Duplicated (θ, bθ),
171
+ Const (p),
163
172
args... )
164
173
return nothing
165
174
end
@@ -173,8 +182,8 @@ function Optimization.instantiate_function(f::OptimizationFunction{true},
173
182
g,
174
183
Enzyme. BatchDuplicated (θ, vdθ),
175
184
Enzyme. BatchDuplicated (bθ, vdbθ),
176
- Const (_f),
177
185
Const (f. f),
186
+ Const (p),
178
187
args... )
179
188
180
189
for i in eachindex (θ)
@@ -186,17 +195,18 @@ function Optimization.instantiate_function(f::OptimizationFunction{true},
186
195
end
187
196
188
197
if f. hv === nothing
189
- function f2 (x, _f, f , args... )
198
+ function f2 (x, f, p , args... )
190
199
dx = zeros (length (x))
191
- Enzyme. autodiff_deferred (Enzyme. Reverse, _f ,
200
+ Enzyme. autodiff_deferred (Enzyme. Reverse, firstapply ,
192
201
f,
193
202
Enzyme. Duplicated (x, dx),
203
+ Const (p),
194
204
args... )
195
205
return dx
196
206
end
197
207
hv = function (H, θ, v, args... )
198
208
H .= Enzyme. autodiff (Enzyme. Forward, f2, DuplicatedNoNeed, Duplicated (θ, v),
199
- Const (_f ), Const (f . f ),
209
+ Const (f . f ), Const (p ),
200
210
args... )[1 ]
201
211
end
202
212
else
@@ -206,7 +216,7 @@ function Optimization.instantiate_function(f::OptimizationFunction{true},
206
216
if f. cons === nothing
207
217
cons = nothing
208
218
else
209
- cons = (res, θ) -> (f. cons (res, θ, cache . p); return nothing )
219
+ cons = (res, θ) -> (f. cons (res, θ, p); return nothing )
210
220
cons_oop = (x) -> (_res = zeros (eltype (x), num_cons); cons (_res, x); _res)
211
221
end
212
222
@@ -219,14 +229,14 @@ function Optimization.instantiate_function(f::OptimizationFunction{true},
219
229
end
220
230
end
221
231
else
222
- cons_j = (J, θ) -> f. cons_j (J, θ, cache . p)
232
+ cons_j = (J, θ) -> f. cons_j (J, θ, p)
223
233
end
224
234
225
235
if cons != = nothing && f. cons_h === nothing
226
236
fncs = map (1 : num_cons) do i
227
237
function (x)
228
238
res = zeros (eltype (x), num_cons)
229
- f. cons (res, x, cache . p)
239
+ f. cons (res, x, p)
230
240
return res[i]
231
241
end
232
242
end
@@ -241,7 +251,7 @@ function Optimization.instantiate_function(f::OptimizationFunction{true},
241
251
end
242
252
end
243
253
else
244
- cons_h = (res, θ) -> f. cons_h (res, θ, cache . p)
254
+ cons_h = (res, θ) -> f. cons_h (res, θ, p)
245
255
end
246
256
247
257
return OptimizationFunction {true} (f. f, adtype; grad = grad, hess = hess, hv = hv,
0 commit comments