Skip to content

Commit 273b959

Browse files
Merge pull request #586 from wsmoses/master
Fix type instability in Enzyme extension
2 parents 0db6844 + 45bfe0c commit 273b959

File tree

1 file changed

+35
-21
lines changed

1 file changed

+35
-21
lines changed

ext/OptimizationEnzymeExt.jl

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,31 +6,38 @@ import Optimization.LinearAlgebra: I
66
import Optimization.ADTypes: AutoEnzyme
77
isdefined(Base, :get_extension) ? (using Enzyme) : (using ..Enzyme)
88

9+
@inline function firstapply(f, θ, p, args...)
10+
first(f(θ, p, args...))
11+
end
12+
913
function Optimization.instantiate_function(f::OptimizationFunction{true}, x,
1014
adtype::AutoEnzyme, p,
1115
num_cons = 0)
12-
_f = (f, θ, args...) -> first(f(θ, p, args...))
1316

1417
if f.grad === nothing
1518
function grad(res, θ, args...)
1619
res .= zero(eltype(res))
1720
Enzyme.autodiff(Enzyme.Reverse,
18-
Const(_f),
21+
Const(firstapply),
22+
Active,
1923
Const(f.f),
2024
Enzyme.Duplicated(θ, res),
25+
Const(p),
2126
args...)
2227
end
2328
else
2429
grad = (G, θ, args...) -> f.grad(G, θ, p, args...)
2530
end
2631

2732
if f.hess === nothing
28-
function g(θ, bθ, _f, f, args...)
33+
function g(θ, bθ, f, p, args...)
2934
Enzyme.autodiff_deferred(Enzyme.Reverse,
30-
Const(_f),
35+
Const(firstapply),
36+
Active,
3137
Const(f),
3238
Enzyme.Duplicated(θ, bθ),
33-
args...)
39+
Const(p),
40+
args...),
3441
return nothing
3542
end
3643
function hess(res, θ, args...)
@@ -43,8 +50,8 @@ function Optimization.instantiate_function(f::OptimizationFunction{true}, x,
4350
g,
4451
Enzyme.BatchDuplicated(θ, vdθ),
4552
Enzyme.BatchDuplicated(bθ, vdbθ),
46-
Const(_f),
4753
Const(f.f),
54+
Const(p),
4855
args...)
4956

5057
for i in eachindex(θ)
@@ -56,17 +63,20 @@ function Optimization.instantiate_function(f::OptimizationFunction{true}, x,
5663
end
5764

5865
if f.hv === nothing
59-
function f2(x, _f, f, args...)
66+
function f2(x, f, p, args...)
6067
dx = zeros(length(x))
61-
Enzyme.autodiff_deferred(Enzyme.Reverse, _f,
68+
Enzyme.autodiff_deferred(Enzyme.Reverse,
69+
firstapply,
70+
Active,
6271
f,
6372
Enzyme.Duplicated(x, dx),
73+
Const(p),
6474
args...)
6575
return dx
6676
end
6777
hv = function (H, θ, v, args...)
6878
H .= Enzyme.autodiff(Enzyme.Forward, f2, DuplicatedNoNeed, Duplicated(θ, v),
69-
Const(_f), Const(f.f),
79+
Const(_f), Const(f.f), Const(p),
7080
args...)[1]
7181
end
7282
else
@@ -141,25 +151,28 @@ function Optimization.instantiate_function(f::OptimizationFunction{true},
141151
cache::Optimization.ReInitCache,
142152
adtype::AutoEnzyme,
143153
num_cons = 0)
144-
_f = (f, θ, args...) -> first(f(θ, cache.p, args...))
154+
p = cache.p
145155

146156
if f.grad === nothing
147157
function grad(res, θ, args...)
148158
res .= zero(eltype(res))
149159
Enzyme.autodiff(Enzyme.Reverse,
150-
Const(_f),
160+
Const(firstapply),
161+
Active,
151162
Const(f.f),
152163
Enzyme.Duplicated(θ, res),
164+
Const(p),
153165
args...)
154166
end
155167
else
156168
grad = (G, θ, args...) -> f.grad(G, θ, p, args...)
157169
end
158170

159171
if f.hess === nothing
160-
function g(θ, bθ, _f, f, args...)
161-
Enzyme.autodiff_deferred(Enzyme.Reverse, Const(_f), Const(f),
172+
function g(θ, bθ, f, p, args...)
173+
Enzyme.autodiff_deferred(Enzyme.Reverse, Const(firstapply), Active, Const(f),
162174
Enzyme.Duplicated(θ, bθ),
175+
Const(p),
163176
args...)
164177
return nothing
165178
end
@@ -173,8 +186,8 @@ function Optimization.instantiate_function(f::OptimizationFunction{true},
173186
g,
174187
Enzyme.BatchDuplicated(θ, vdθ),
175188
Enzyme.BatchDuplicated(bθ, vdbθ),
176-
Const(_f),
177189
Const(f.f),
190+
Const(p),
178191
args...)
179192

180193
for i in eachindex(θ)
@@ -186,17 +199,18 @@ function Optimization.instantiate_function(f::OptimizationFunction{true},
186199
end
187200

188201
if f.hv === nothing
189-
function f2(x, _f, f, args...)
202+
function f2(x, f, p, args...)
190203
dx = zeros(length(x))
191-
Enzyme.autodiff_deferred(Enzyme.Reverse, _f,
204+
Enzyme.autodiff_deferred(Enzyme.Reverse, firstapply, Active,
192205
f,
193206
Enzyme.Duplicated(x, dx),
207+
Const(p),
194208
args...)
195209
return dx
196210
end
197211
hv = function (H, θ, v, args...)
198212
H .= Enzyme.autodiff(Enzyme.Forward, f2, DuplicatedNoNeed, Duplicated(θ, v),
199-
Const(_f), Const(f.f),
213+
Const(f.f), Const(p),
200214
args...)[1]
201215
end
202216
else
@@ -206,7 +220,7 @@ function Optimization.instantiate_function(f::OptimizationFunction{true},
206220
if f.cons === nothing
207221
cons = nothing
208222
else
209-
cons = (res, θ) -> (f.cons(res, θ, cache.p); return nothing)
223+
cons = (res, θ) -> (f.cons(res, θ, p); return nothing)
210224
cons_oop = (x) -> (_res = zeros(eltype(x), num_cons); cons(_res, x); _res)
211225
end
212226

@@ -219,14 +233,14 @@ function Optimization.instantiate_function(f::OptimizationFunction{true},
219233
end
220234
end
221235
else
222-
cons_j = (J, θ) -> f.cons_j(J, θ, cache.p)
236+
cons_j = (J, θ) -> f.cons_j(J, θ, p)
223237
end
224238

225239
if cons !== nothing && f.cons_h === nothing
226240
fncs = map(1:num_cons) do i
227241
function (x)
228242
res = zeros(eltype(x), num_cons)
229-
f.cons(res, x, cache.p)
243+
f.cons(res, x, p)
230244
return res[i]
231245
end
232246
end
@@ -241,7 +255,7 @@ function Optimization.instantiate_function(f::OptimizationFunction{true},
241255
end
242256
end
243257
else
244-
cons_h = (res, θ) -> f.cons_h(res, θ, cache.p)
258+
cons_h = (res, θ) -> f.cons_h(res, θ, p)
245259
end
246260

247261
return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv,

0 commit comments

Comments
 (0)