@@ -6,31 +6,38 @@ import Optimization.LinearAlgebra: I
6
6
import Optimization. ADTypes: AutoEnzyme
7
7
isdefined (Base, :get_extension ) ? (using Enzyme) : (using .. Enzyme)
8
8
9
+ @inline function firstapply (f, θ, p, args... )
10
+ first (f (θ, p, args... ))
11
+ end
12
+
9
13
function Optimization. instantiate_function (f:: OptimizationFunction{true} , x,
10
14
adtype:: AutoEnzyme , p,
11
15
num_cons = 0 )
12
- _f = (f, θ, args... ) -> first (f (θ, p, args... ))
13
16
14
17
if f. grad === nothing
15
18
function grad (res, θ, args... )
16
19
res .= zero (eltype (res))
17
20
Enzyme. autodiff (Enzyme. Reverse,
18
- Const (_f),
21
+ Const (firstapply),
22
+ Active,
19
23
Const (f. f),
20
24
Enzyme. Duplicated (θ, res),
25
+ Const (p),
21
26
args... )
22
27
end
23
28
else
24
29
grad = (G, θ, args... ) -> f. grad (G, θ, p, args... )
25
30
end
26
31
27
32
if f. hess === nothing
28
- function g (θ, bθ, _f, f , args... )
33
+ function g (θ, bθ, f, p , args... )
29
34
Enzyme. autodiff_deferred (Enzyme. Reverse,
30
- Const (_f),
35
+ Const (firstapply),
36
+ Active,
31
37
Const (f),
32
38
Enzyme. Duplicated (θ, bθ),
33
- args... )
39
+ Const (p),
40
+ args... ),
34
41
return nothing
35
42
end
36
43
function hess (res, θ, args... )
@@ -43,8 +50,8 @@ function Optimization.instantiate_function(f::OptimizationFunction{true}, x,
43
50
g,
44
51
Enzyme. BatchDuplicated (θ, vdθ),
45
52
Enzyme. BatchDuplicated (bθ, vdbθ),
46
- Const (_f),
47
53
Const (f. f),
54
+ Const (p),
48
55
args... )
49
56
50
57
for i in eachindex (θ)
@@ -56,17 +63,20 @@ function Optimization.instantiate_function(f::OptimizationFunction{true}, x,
56
63
end
57
64
58
65
if f. hv === nothing
59
- function f2 (x, _f, f , args... )
66
+ function f2 (x, f, p , args... )
60
67
dx = zeros (length (x))
61
- Enzyme. autodiff_deferred (Enzyme. Reverse, _f,
68
+ Enzyme. autodiff_deferred (Enzyme. Reverse,
69
+ firstapply,
70
+ Active,
62
71
f,
63
72
Enzyme. Duplicated (x, dx),
73
+ Const (p),
64
74
args... )
65
75
return dx
66
76
end
67
77
hv = function (H, θ, v, args... )
68
78
H .= Enzyme. autodiff (Enzyme. Forward, f2, DuplicatedNoNeed, Duplicated (θ, v),
69
- Const (_f), Const (f. f),
79
+ Const (_f), Const (f. f), Const (p),
70
80
args... )[1 ]
71
81
end
72
82
else
@@ -141,25 +151,28 @@ function Optimization.instantiate_function(f::OptimizationFunction{true},
141
151
cache:: Optimization.ReInitCache ,
142
152
adtype:: AutoEnzyme ,
143
153
num_cons = 0 )
144
- _f = (f, θ, args ... ) -> first ( f (θ, cache. p, args ... ))
154
+ p = cache. p
145
155
146
156
if f. grad === nothing
147
157
function grad (res, θ, args... )
148
158
res .= zero (eltype (res))
149
159
Enzyme. autodiff (Enzyme. Reverse,
150
- Const (_f),
160
+ Const (firstapply),
161
+ Active,
151
162
Const (f. f),
152
163
Enzyme. Duplicated (θ, res),
164
+ Const (p),
153
165
args... )
154
166
end
155
167
else
156
168
grad = (G, θ, args... ) -> f. grad (G, θ, p, args... )
157
169
end
158
170
159
171
if f. hess === nothing
160
- function g (θ, bθ, _f, f , args... )
161
- Enzyme. autodiff_deferred (Enzyme. Reverse, Const (_f) , Const (f),
172
+ function g (θ, bθ, f, p , args... )
173
+ Enzyme. autodiff_deferred (Enzyme. Reverse, Const (firstapply), Active , Const (f),
162
174
Enzyme. Duplicated (θ, bθ),
175
+ Const (p),
163
176
args... )
164
177
return nothing
165
178
end
@@ -173,8 +186,8 @@ function Optimization.instantiate_function(f::OptimizationFunction{true},
173
186
g,
174
187
Enzyme. BatchDuplicated (θ, vdθ),
175
188
Enzyme. BatchDuplicated (bθ, vdbθ),
176
- Const (_f),
177
189
Const (f. f),
190
+ Const (p),
178
191
args... )
179
192
180
193
for i in eachindex (θ)
@@ -186,17 +199,18 @@ function Optimization.instantiate_function(f::OptimizationFunction{true},
186
199
end
187
200
188
201
if f. hv === nothing
189
- function f2 (x, _f, f , args... )
202
+ function f2 (x, f, p , args... )
190
203
dx = zeros (length (x))
191
- Enzyme. autodiff_deferred (Enzyme. Reverse, _f ,
204
+ Enzyme. autodiff_deferred (Enzyme. Reverse, firstapply, Active ,
192
205
f,
193
206
Enzyme. Duplicated (x, dx),
207
+ Const (p),
194
208
args... )
195
209
return dx
196
210
end
197
211
hv = function (H, θ, v, args... )
198
212
H .= Enzyme. autodiff (Enzyme. Forward, f2, DuplicatedNoNeed, Duplicated (θ, v),
199
- Const (_f ), Const (f . f ),
213
+ Const (f . f ), Const (p ),
200
214
args... )[1 ]
201
215
end
202
216
else
@@ -206,7 +220,7 @@ function Optimization.instantiate_function(f::OptimizationFunction{true},
206
220
if f. cons === nothing
207
221
cons = nothing
208
222
else
209
- cons = (res, θ) -> (f. cons (res, θ, cache . p); return nothing )
223
+ cons = (res, θ) -> (f. cons (res, θ, p); return nothing )
210
224
cons_oop = (x) -> (_res = zeros (eltype (x), num_cons); cons (_res, x); _res)
211
225
end
212
226
@@ -219,14 +233,14 @@ function Optimization.instantiate_function(f::OptimizationFunction{true},
219
233
end
220
234
end
221
235
else
222
- cons_j = (J, θ) -> f. cons_j (J, θ, cache . p)
236
+ cons_j = (J, θ) -> f. cons_j (J, θ, p)
223
237
end
224
238
225
239
if cons != = nothing && f. cons_h === nothing
226
240
fncs = map (1 : num_cons) do i
227
241
function (x)
228
242
res = zeros (eltype (x), num_cons)
229
- f. cons (res, x, cache . p)
243
+ f. cons (res, x, p)
230
244
return res[i]
231
245
end
232
246
end
@@ -241,7 +255,7 @@ function Optimization.instantiate_function(f::OptimizationFunction{true},
241
255
end
242
256
end
243
257
else
244
- cons_h = (res, θ) -> f. cons_h (res, θ, cache . p)
258
+ cons_h = (res, θ) -> f. cons_h (res, θ, p)
245
259
end
246
260
247
261
return OptimizationFunction {true} (f. f, adtype; grad = grad, hess = hess, hv = hv,
0 commit comments