Skip to content

Commit 36646c4

Browse files
committed
Fix type instability in Enzyme extension
1 parent 0db6844 commit 36646c4

File tree

1 file changed

+31
-21
lines changed

1 file changed

+31
-21
lines changed

ext/OptimizationEnzymeExt.jl

Lines changed: 31 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,31 +6,36 @@ import Optimization.LinearAlgebra: I
66
import Optimization.ADTypes: AutoEnzyme
77
isdefined(Base, :get_extension) ? (using Enzyme) : (using ..Enzyme)
88

9+
@inline function firstapply(f, θ, p, args...)
10+
first(f(θ, p, args...))
11+
end
12+
913
function Optimization.instantiate_function(f::OptimizationFunction{true}, x,
1014
adtype::AutoEnzyme, p,
1115
num_cons = 0)
12-
_f = (f, θ, args...) -> first(f(θ, p, args...))
1316

1417
if f.grad === nothing
1518
function grad(res, θ, args...)
1619
res .= zero(eltype(res))
1720
Enzyme.autodiff(Enzyme.Reverse,
18-
Const(_f),
21+
Const(firstapply),
1922
Const(f.f),
2023
Enzyme.Duplicated(θ, res),
24+
Const(p),
2125
args...)
2226
end
2327
else
2428
grad = (G, θ, args...) -> f.grad(G, θ, p, args...)
2529
end
2630

2731
if f.hess === nothing
28-
function g(θ, bθ, _f, f, args...)
32+
function g(θ, bθ, f, p, args...)
2933
Enzyme.autodiff_deferred(Enzyme.Reverse,
30-
Const(_f),
34+
Const(firstapply),
3135
Const(f),
3236
Enzyme.Duplicated(θ, bθ),
33-
args...)
37+
Const(p),
38+
args...),
3439
return nothing
3540
end
3641
function hess(res, θ, args...)
@@ -43,8 +48,8 @@ function Optimization.instantiate_function(f::OptimizationFunction{true}, x,
4348
g,
4449
Enzyme.BatchDuplicated(θ, vdθ),
4550
Enzyme.BatchDuplicated(bθ, vdbθ),
46-
Const(_f),
4751
Const(f.f),
52+
Const(p),
4853
args...)
4954

5055
for i in eachindex(θ)
@@ -56,17 +61,19 @@ function Optimization.instantiate_function(f::OptimizationFunction{true}, x,
5661
end
5762

5863
if f.hv === nothing
59-
function f2(x, _f, f, args...)
64+
function f2(x, f, p, args...)
6065
dx = zeros(length(x))
61-
Enzyme.autodiff_deferred(Enzyme.Reverse, _f,
66+
Enzyme.autodiff_deferred(Enzyme.Reverse,
67+
firstapply,
6268
f,
6369
Enzyme.Duplicated(x, dx),
70+
Const(p),
6471
args...)
6572
return dx
6673
end
6774
hv = function (H, θ, v, args...)
6875
H .= Enzyme.autodiff(Enzyme.Forward, f2, DuplicatedNoNeed, Duplicated(θ, v),
69-
Const(_f), Const(f.f),
76+
Const(_f), Const(f.f), Const(p),
7077
args...)[1]
7178
end
7279
else
@@ -141,25 +148,27 @@ function Optimization.instantiate_function(f::OptimizationFunction{true},
141148
cache::Optimization.ReInitCache,
142149
adtype::AutoEnzyme,
143150
num_cons = 0)
144-
_f = (f, θ, args...) -> first(f(θ, cache.p, args...))
151+
p = cache.p
145152

146153
if f.grad === nothing
147154
function grad(res, θ, args...)
148155
res .= zero(eltype(res))
149156
Enzyme.autodiff(Enzyme.Reverse,
150-
Const(_f),
157+
Const(firstapply),
151158
Const(f.f),
152159
Enzyme.Duplicated(θ, res),
160+
Const(p),
153161
args...)
154162
end
155163
else
156164
grad = (G, θ, args...) -> f.grad(G, θ, p, args...)
157165
end
158166

159167
if f.hess === nothing
160-
function g(θ, bθ, _f, f, args...)
161-
Enzyme.autodiff_deferred(Enzyme.Reverse, Const(_f), Const(f),
168+
function g(θ, bθ, f, p, args...)
169+
Enzyme.autodiff_deferred(Enzyme.Reverse, Const(firstapply), Const(f),
162170
Enzyme.Duplicated(θ, bθ),
171+
Const(p),
163172
args...)
164173
return nothing
165174
end
@@ -173,8 +182,8 @@ function Optimization.instantiate_function(f::OptimizationFunction{true},
173182
g,
174183
Enzyme.BatchDuplicated(θ, vdθ),
175184
Enzyme.BatchDuplicated(bθ, vdbθ),
176-
Const(_f),
177185
Const(f.f),
186+
Const(p),
178187
args...)
179188

180189
for i in eachindex(θ)
@@ -186,17 +195,18 @@ function Optimization.instantiate_function(f::OptimizationFunction{true},
186195
end
187196

188197
if f.hv === nothing
189-
function f2(x, _f, f, args...)
198+
function f2(x, f, p, args...)
190199
dx = zeros(length(x))
191-
Enzyme.autodiff_deferred(Enzyme.Reverse, _f,
200+
Enzyme.autodiff_deferred(Enzyme.Reverse, firstapply,
192201
f,
193202
Enzyme.Duplicated(x, dx),
203+
Const(p),
194204
args...)
195205
return dx
196206
end
197207
hv = function (H, θ, v, args...)
198208
H .= Enzyme.autodiff(Enzyme.Forward, f2, DuplicatedNoNeed, Duplicated(θ, v),
199-
Const(_f), Const(f.f),
209+
Const(f.f), Const(p),
200210
args...)[1]
201211
end
202212
else
@@ -206,7 +216,7 @@ function Optimization.instantiate_function(f::OptimizationFunction{true},
206216
if f.cons === nothing
207217
cons = nothing
208218
else
209-
cons = (res, θ) -> (f.cons(res, θ, cache.p); return nothing)
219+
cons = (res, θ) -> (f.cons(res, θ, p); return nothing)
210220
cons_oop = (x) -> (_res = zeros(eltype(x), num_cons); cons(_res, x); _res)
211221
end
212222

@@ -219,14 +229,14 @@ function Optimization.instantiate_function(f::OptimizationFunction{true},
219229
end
220230
end
221231
else
222-
cons_j = (J, θ) -> f.cons_j(J, θ, cache.p)
232+
cons_j = (J, θ) -> f.cons_j(J, θ, p)
223233
end
224234

225235
if cons !== nothing && f.cons_h === nothing
226236
fncs = map(1:num_cons) do i
227237
function (x)
228238
res = zeros(eltype(x), num_cons)
229-
f.cons(res, x, cache.p)
239+
f.cons(res, x, p)
230240
return res[i]
231241
end
232242
end
@@ -241,7 +251,7 @@ function Optimization.instantiate_function(f::OptimizationFunction{true},
241251
end
242252
end
243253
else
244-
cons_h = (res, θ) -> f.cons_h(res, θ, cache.p)
254+
cons_h = (res, θ) -> f.cons_h(res, θ, p)
245255
end
246256

247257
return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv,

0 commit comments

Comments
 (0)