Skip to content

Commit c9af69b

Browse files
committed
Move functions into struct
1 parent d1220d0 commit c9af69b

File tree

1 file changed

+77
-51
lines changed

1 file changed

+77
-51
lines changed

src/enzyme.jl

Lines changed: 77 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,55 @@
1+
function _gradient!(dx, f, x)
2+
Enzyme.make_zero!(dx)
3+
res = Enzyme.autodiff(
4+
Enzyme.set_runtime_activity(Enzyme.Reverse),
5+
f,
6+
Enzyme.Active,
7+
Enzyme.Duplicated(x, dx),
8+
)
9+
return nothing
10+
end
11+
12+
function _hvp!(res, f, x, v)
13+
Enzyme.autodiff(
14+
Enzyme.set_runtime_activity(Enzyme.Forward),
15+
_gradient!,
16+
res,
17+
Enzyme.Const(f),
18+
Enzyme.Duplicated(x, v),
19+
)
20+
return nothing
21+
end
22+
23+
function _gradient!(dx, ℓ, x, y, obj_weight, cx)
24+
Enzyme.make_zero!(dx)
25+
dcx = Enzyme.make_zero(cx)
26+
res = Enzyme.autodiff(
27+
Enzyme.set_runtime_activity(Enzyme.Reverse),
28+
ℓ,
29+
Enzyme.Active,
30+
Enzyme.Duplicated(x, dx),
31+
Enzyme.Const(y),
32+
Enzyme.Const(obj_weight),
33+
Enzyme.Duplicated(cx, dcx),
34+
)
35+
return nothing
36+
end
37+
38+
function _hvp!(res, ℓ, x, v, y, obj_weight, cx)
39+
dcx = Enzyme.make_zero(cx)
40+
Enzyme.autodiff(
41+
Enzyme.set_runtime_activity(Enzyme.Forward),
42+
_gradient!,
43+
res,
44+
Enzyme.Const(ℓ),
45+
Enzyme.Duplicated(x, v),
46+
Enzyme.Const(y),
47+
Enzyme.Const(obj_weight),
48+
Enzyme.Duplicated(cx, dcx),
49+
)
50+
return nothing
51+
end
52+
153
struct EnzymeReverseADGradient <: InPlaceADbackend end
254

355
function EnzymeReverseADGradient(
@@ -23,9 +75,10 @@ function EnzymeReverseADJacobian(
2375
return EnzymeReverseADJacobian()
2476
end
2577

26-
struct EnzymeReverseADHessian{T} <: ADBackend
78+
struct EnzymeReverseADHessian{T,F} <: ADBackend
2779
seed::Vector{T}
2880
Hv::Vector{T}
81+
f::F
2982
end
3083

3184
function EnzymeReverseADHessian(
@@ -41,11 +94,12 @@ function EnzymeReverseADHessian(
4194

4295
seed = zeros(T, nvar)
4396
Hv = zeros(T, nvar)
44-
return EnzymeReverseADHessian(seed, Hv)
97+
return EnzymeReverseADHessian(seed, Hv, f)
4598
end
4699

47-
struct EnzymeReverseADHvprod{T} <: InPlaceADbackend
100+
struct EnzymeReverseADHvprod{T,F} <: InPlaceADbackend
48101
grad::Vector{T}
102+
f::F
49103
end
50104

51105
function EnzymeReverseADHvprod(
@@ -57,7 +111,7 @@ function EnzymeReverseADHvprod(
57111
kwargs...,
58112
) where {T}
59113
grad = zeros(T, nvar)
60-
return EnzymeReverseADHvprod(grad)
114+
return EnzymeReverseADHvprod(grad,f)
61115
end
62116

63117
struct EnzymeReverseADJprod{T} <: InPlaceADbackend
@@ -153,7 +207,7 @@ function SparseEnzymeADJacobian(
153207
)
154208
end
155209

156-
struct SparseEnzymeADHessian{R, C, S, L} <: ADBackend
210+
struct SparseEnzymeADHessian{R, C, S, L, F} <: ADBackend
157211
nvar::Int
158212
rowval::Vector{Int}
159213
colptr::Vector{Int}
@@ -166,6 +220,7 @@ struct SparseEnzymeADHessian{R, C, S, L} <: ADBackend
166220
y::Vector{R}
167221
grad::Vector{R}
168222
cx::Vector{R}
223+
f::F
169224
::L
170225
end
171226

@@ -216,12 +271,11 @@ function SparseEnzymeADHessian(
216271
cx = similar(x0, ncon)
217272
grad = similar(x0)
218273
function ℓ(x, y, obj_weight, cx)
219-
# res = obj_weight * f(x)
220-
res = f(x)
221-
# if ncon != 0
222-
# c!(cx, x)
223-
# res += sum(cx[i] * y[i] for i = 1:ncon)
224-
# end
274+
res = obj_weight * f(x)
275+
if ncon != 0
276+
c!(cx, x)
277+
res += sum(cx[i] * y[i] for i = 1:ncon)
278+
end
225279
return res
226280
end
227281

@@ -238,6 +292,7 @@ function SparseEnzymeADHessian(
238292
y,
239293
grad,
240294
cx,
295+
f,
241296
ℓ,
242297
)
243298
end
@@ -248,7 +303,7 @@ end
248303

249304
function ADNLPModels.gradient(::EnzymeReverseADGradient, f, x)
250305
g = similar(x)
251-
Enzyme.gradient!(Enzyme.Reverse, g, Enzyme.Const(f), x)
306+
Enzyme.autodiff(set_runtime_activity(Enzyme.Reverse), Enzyme.Const(f), Enzyme.Active, Enzyme.Duplicated(x, g))
252307
return g
253308
end
254309

@@ -269,14 +324,15 @@ end
269324
b.seed[i] = one(T)
270325
# Enzyme.hvp!(b.Hv, f, x, b.seed)
271326
grad = make_zero(x)
272-
Enzyme.autodiff(
273-
Enzyme.Forward,
274-
Enzyme.Const(Enzyme.gradient!),
275-
Enzyme.Const(Enzyme.Reverse),
276-
Enzyme.DuplicatedNoNeed(grad, b.Hv),
277-
Enzyme.Const(f),
278-
Enzyme.Duplicated(x, b.seed),
279-
)
327+
_hvp!(DuplicatedNoNeed(grad, b.Hv), b.f, x, b.seed)
328+
# Enzyme.autodiff(
329+
# Enzyme.Forward,
330+
# Enzyme.Const(Enzyme.gradient!),
331+
# Enzyme.Const(Enzyme.Reverse),
332+
# Enzyme.DuplicatedNoNeed(grad, b.Hv),
333+
# Enzyme.Const(f),
334+
# Enzyme.Duplicated(x, b.seed),
335+
# )
280336
view(hess, :, i) .= b.Hv
281337
b.seed[i] = zero(T)
282338
end
@@ -339,7 +395,7 @@ end
339395
Enzyme.Const(Enzyme.gradient!),
340396
Enzyme.Const(Enzyme.Reverse),
341397
Enzyme.DuplicatedNoNeed(b.grad, Hv),
342-
Enzyme.Const(f),
398+
Enzyme.Const(b.f),
343399
Enzyme.Duplicated(x, v),
344400
)
345401
return Hv
@@ -471,36 +527,6 @@ end
471527
b.v[col] = 1
472528
end
473529

474-
function _gradient!(dx, ℓ, x, y, obj_weight, cx)
475-
Enzyme.make_zero!(dx)
476-
dcx = Enzyme.make_zero(cx)
477-
res = Enzyme.autodiff(
478-
Enzyme.set_runtime_activity(Enzyme.Reverse),
479-
ℓ,
480-
Enzyme.Active,
481-
Enzyme.Duplicated(x, dx),
482-
Enzyme.Const(y),
483-
Enzyme.Const(obj_weight),
484-
Enzyme.Duplicated(cx, dcx),
485-
)
486-
return nothing
487-
end
488-
489-
function _hvp!(res, ℓ, x, v, y, obj_weight, cx)
490-
dcx = Enzyme.make_zero(cx)
491-
Enzyme.autodiff(
492-
Enzyme.set_runtime_activity(Enzyme.Forward),
493-
_gradient!,
494-
res,
495-
Enzyme.Const(ℓ),
496-
Enzyme.Duplicated(x, v),
497-
Enzyme.Const(y),
498-
Enzyme.Const(obj_weight),
499-
Enzyme.Duplicated(cx, dcx),
500-
)
501-
return nothing
502-
end
503-
504530
_hvp!(
505531
Enzyme.DuplicatedNoNeed(b.grad, b.compressed_hessian_icol),
506532
b.ℓ,

0 commit comments

Comments
 (0)