Skip to content

Commit d0c6c15

Browse files
committed
fix tests again
1 parent f23edf1 commit d0c6c15

File tree

1 file changed

+27
-29
lines changed

1 file changed

+27
-29
lines changed

src/enzyme.jl

Lines changed: 27 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -148,20 +148,18 @@ end
148148
@init begin
149149
@require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" begin
150150

151-
import Enzyme: Const, Reverse, Forward, Duplicated, DuplicatedNoNeed
152-
153151
function ADNLPModels.gradient(::EnzymeReverseADGradient, f, x)
154152
g = similar(x)
155-
Enzyme.gradient!(Reverse, g, Const(f), x)
153+
Enzyme.gradient!(Enzyme.Reverse, g, Enzyme.Const(f), x)
156154
return g
157155
end
158156

159157
function ADNLPModels.gradient!(::EnzymeReverseADGradient, g, f, x)
160-
Enzyme.autodiff(Reverse, Const(f), Active, Duplicated(x, g))
158+
Enzyme.autodiff(Enzyme.Reverse, Enzyme.Const(f), Enzyme.Active, Enzyme.Duplicated(x, g))
161159
return g
162160
end
163161

164-
jacobian(::EnzymeReverseADJacobian, f, x) = Enzyme.jacobian(Reverse, f, x)
162+
jacobian(::EnzymeReverseADJacobian, f, x) = Enzyme.jacobian(Enzyme.Reverse, f, x)
165163

166164
function hessian(::EnzymeReverseADHessian, f, x)
167165
seed = similar(x)
@@ -170,32 +168,32 @@ end
170168
tmp = similar(x)
171169
for i in 1:length(x)
172170
seed[i] = one(eltype(seed))
173-
Enzyme.hvp!(tmp, Const(f), x, seed)
171+
Enzyme.hvp!(tmp, Enzyme.Const(f), x, seed)
174172
hess[:, i] .= tmp
175173
seed[i] = zero(eltype(seed))
176174
end
177175
return hess
178176
end
179177

180178
function Jprod!(b::EnzymeReverseADJprod, Jv, c!, x, v, ::Val)
181-
Enzyme.autodiff(Enzyme.Forward, Const(c!), Duplicated(b.x, Jv), Duplicated(x, v))
179+
Enzyme.autodiff(Enzyme.Forward, Enzyme.Const(c!), Enzyme.Duplicated(b.x, Jv), Enzyme.Duplicated(x, v))
182180
return Jv
183181
end
184182

185183
function Jtprod!(b::EnzymeReverseADJtprod, Jtv, c!, x, v, ::Val)
186-
Enzyme.autodiff(Reverse, Const(c!), Duplicated(b.x, Jtv), Duplicated(x, v))
184+
Enzyme.autodiff(Enzyme.Reverse, Enzyme.Const(c!), Enzyme.Duplicated(b.x, Jtv), Enzyme.Duplicated(x, v))
187185
return Jtv
188186
end
189187

190188
function Hvprod!(b::EnzymeReverseADHvprod, Hv, x, v, f, args...)
191189
# What to do with args?
192190
Enzyme.autodiff(
193-
Forward,
194-
Const(Enzyme.gradient!),
195-
Const(Reverse),
196-
DuplicatedNoNeed(b.grad, Hv),
197-
Const(f),
198-
Duplicated(x, v),
191+
Enzyme.Forward,
192+
Enzyme.Const(Enzyme.gradient!),
193+
Enzyme.Const(Enzyme.Reverse),
194+
Enzyme.DuplicatedNoNeed(b.grad, Hv),
195+
Enzyme.Const(f),
196+
Enzyme.Duplicated(x, v),
199197
)
200198
return Hv
201199
end
@@ -211,13 +209,13 @@ end
211209
obj_weight::Real = one(eltype(x)),
212210
)
213211
Enzyme.autodiff(
214-
Forward,
215-
Const(Enzyme.gradient!),
216-
Const(Reverse),
217-
DuplicatedNoNeed(b.grad, Hv),
218-
Const(ℓ),
219-
Duplicated(x, v),
220-
Const(y),
212+
Enzyme.Forward,
213+
Enzyme.Const(Enzyme.gradient!),
214+
Enzyme.Const(Enzyme.Reverse),
215+
Enzyme.DuplicatedNoNeed(b.grad, Hv),
216+
Enzyme.Const(ℓ),
217+
Enzyme.Duplicated(x, v),
218+
Enzyme.Const(y),
221219
)
222220

223221
return Hv
@@ -233,13 +231,13 @@ end
233231
obj_weight::Real = one(eltype(x)),
234232
)
235233
Enzyme.autodiff(
236-
Forward,
237-
Const(Enzyme.gradient!),
238-
Const(Reverse),
239-
DuplicatedNoNeed(b.grad, Hv),
240-
Const(f),
241-
Duplicated(x, v),
242-
Const(y),
234+
Enzyme.Forward,
235+
Enzyme.Const(Enzyme.gradient!),
236+
Enzyme.Const(Enzyme.Reverse),
237+
Enzyme.DuplicatedNoNeed(b.grad, Hv),
238+
Enzyme.Const(f),
239+
Enzyme.Duplicated(x, v),
240+
Enzyme.Const(y),
243241
)
244242
return Hv
245243
end
@@ -264,7 +262,7 @@ end
264262

265263
# b.compressed_jacobian is just a vector Jv here
266264
# We don't use the vector mode
267-
Enzyme.autodiff(Enzyme.Forward, Const(c!), Duplicated(b.buffer, b.compressed_jacobian), Duplicated(x, b.v))
265+
Enzyme.autodiff(Enzyme.Forward, Enzyme.Const(c!), Enzyme.Duplicated(b.buffer, b.compressed_jacobian), Enzyme.Duplicated(x, b.v))
268266

269267
# Update the columns of the Jacobian that have the color `icol`
270268
decompress_single_color!(A, b.compressed_jacobian, icol, b.result_coloring)

0 commit comments

Comments
 (0)