Skip to content

Commit b4d1765

Browse files
committed
Enzyme WIP
1 parent e73cb3e commit b4d1765

File tree

1 file changed

+158
-9
lines changed

1 file changed

+158
-9
lines changed

src/enzyme.jl

Lines changed: 158 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,170 @@
1-
struct EnzymeADGradient <: ADNLPModels.ADBackend end
1+
struct EnzymeReverseADJacobian <: ADBackend end
2+
struct EnzymeReverseADHessian <: ADBackend end
23

3-
function EnzymeADGradient(
4+
struct EnzymeReverseADGradient <: ADNLPModels.ADBackend end
5+
6+
function EnzymeReverseADGradient(
47
nvar::Integer,
58
f,
69
ncon::Integer = 0,
710
c::Function = (args...) -> [];
811
x0::AbstractVector = rand(nvar),
912
kwargs...,
1013
)
11-
return EnzymeADGradient()
14+
return EnzymeReverseADGradient()
15+
end
16+
17+
function ADNLPModels.gradient!(::EnzymeReverseADGradient, g, f, x)
18+
Enzyme.autodiff(Enzyme.Reverse, f, Enzyme.Duplicated(x, g)) # gradient!(Reverse, g, f, x)
19+
return g
20+
end
21+
22+
function EnzymeReverseADJacobian(
23+
nvar::Integer,
24+
f,
25+
ncon::Integer = 0,
26+
c::Function = (args...) -> [];
27+
kwargs...,
28+
)
29+
return EnzymeReverseADJacobian()
30+
end
31+
32+
jacobian(::EnzymeReverseADJacobian, f, x) = Enzyme.jacobian(Enzyme.Reverse, f, x)
33+
34+
function EnzymeReverseADHessian(
35+
nvar::Integer,
36+
37+
f,
38+
ncon::Integer = 0,
39+
c::Function = (args...) -> [];
40+
kwargs...,
41+
)
42+
@assert nvar > 0
43+
nnzh = nvar * (nvar + 1) / 2
44+
return EnzymeReverseADHessian()
1245
end
1346

14-
@init begin
15-
@require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" begin
16-
function ADNLPModels.gradient!(::EnzymeADGradient, g, f, x)
17-
Enzyme.autodiff(Enzyme.Reverse, f, Enzyme.Duplicated(x, g)) # gradient!(Reverse, g, f, x)
18-
return g
19-
end
47+
function hessian(::EnzymeReverseADHessian, f, x)
48+
seed = similar(x)
49+
hess = zeros(eltype(x), length(x), length(x))
50+
fill!(seed, zero(x))
51+
for i in 1:length(x)
52+
seed[i] = one(x)
53+
Enzyme.hvp!(view(hess, i, :), f, x, seed)
54+
seed[i] = zero(x)
2055
end
56+
return hess
57+
end
58+
59+
struct EnzymeReverseADJprod <: InPlaceADBackend
60+
x::Vector{Float64}
61+
end
62+
63+
function EnzymeReverseADJprod(
64+
nvar::Integer,
65+
f,
66+
ncon::Integer = 0,
67+
c::Function = (args...) -> [];
68+
kwargs...,
69+
)
70+
x = zeros(nvar)
71+
return EnzymeReverseADJprod(x)
72+
end
73+
74+
function Jprod!(b::EnzymeReverseADJprod, Jv, c!, x, v, ::Val)
75+
Enzyme.autodiff(Enzyme.Forward, c!, Duplicated(b.x, Jv), Enzyme.Duplicated(x, v))
76+
return Jv
77+
end
78+
79+
struct EnzymeReverseADJtprod <: InPlaceADBackend
80+
x::Vector{Float64}
81+
end
82+
83+
function EnzymeReverseADJtprod(
84+
nvar::Integer,
85+
f,
86+
ncon::Integer = 0,
87+
c::Function = (args...) -> [];
88+
kwargs...,
89+
)
90+
x = zeros(nvar)
91+
return EnzymeReverseADJtprod(x)
92+
end
93+
94+
function Jtvprod!(b::EnzymeReverseADJtprod, Jtv, c!, x, v, ::Val)
95+
Enzyme.autodiff(Enzyme.Reverse, c!, Duplicated(b.x, Jtv), Enzyme.Duplicated(x, v))
96+
return Jtv
97+
end
98+
99+
struct EnzymeReverseADHprod <: InPlaceADBackend
100+
grad::Vector{Float64}
101+
end
102+
103+
function EnzymeReverseADHvprod(
104+
nvar::Integer,
105+
f,
106+
ncon::Integer = 0,
107+
c!::Function = (args...) -> [];
108+
x0::AbstractVector{T} = rand(nvar),
109+
kwargs...,
110+
) where {T}
111+
grad = zeros(nvar)
112+
return EnzymeReverseADHprod(grad)
113+
end
114+
115+
function Hvprod!(b::EnzymeReverseADHvprod, Hv, x, v, f, args...)
116+
# What to do with args?
117+
Enzyme.autodiff(
118+
Forward,
119+
gradient!,
120+
Const(Reverse),
121+
DuplicatedNoNeed(b.grad, Hv),
122+
Const(f),
123+
Duplicated(x, v),
124+
)
125+
return Hv
126+
end
127+
128+
function Hvprod!(
129+
b::EnzymeReverseADHvprod,
130+
Hv,
131+
x::AbstractVector{T},
132+
v,
133+
ℓ,
134+
::Val{:lag},
135+
y,
136+
obj_weight::Real = one(T),
137+
)
138+
Enzyme.autodiff(
139+
Forward,
140+
gradient!,
141+
Const(Reverse),
142+
DuplicatedNoNeed(b.grad, Hv),
143+
Const(ℓ),
144+
Duplicated(x, v),
145+
Const(y),
146+
)
147+
148+
return Hv
149+
end
150+
151+
function Hvprod!(
152+
b::EnzymeReverseADHvprod{T, S, Tagf},
153+
Hv,
154+
x,
155+
v,
156+
f,
157+
::Val{:obj},
158+
obj_weight::Real = one(T),
159+
)
160+
Enzyme.autodiff(
161+
Forward,
162+
gradient!,
163+
Const(Reverse),
164+
DuplicatedNoNeed(b.grad, Hv),
165+
Const(f),
166+
Duplicated(x, v),
167+
Const(y),
168+
)
169+
return Hv
21170
end

0 commit comments

Comments
 (0)