Skip to content

Commit 25d8db4

Browse files
committed
Update the sparse Hessian
1 parent 63adce3 commit 25d8db4

File tree

1 file changed

+55
-46
lines changed

1 file changed

+55
-46
lines changed

src/enzyme.jl

Lines changed: 55 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -23,23 +23,29 @@ function EnzymeReverseADJacobian(
2323
return EnzymeReverseADJacobian()
2424
end
2525

26-
struct EnzymeReverseADHessian <: ADBackend end
26+
"""
    EnzymeReverseADHessian{T}

Dense-Hessian backend based on Enzyme.

Holds two preallocated work vectors with element type `T`; `hessian` reads and
overwrites both of them on every call.
"""
struct EnzymeReverseADHessian{T} <: ADBackend
  # Direction vector handed to `Enzyme.hvp!`.
  seed::Vector{T}
  # Receives the Hessian-vector product written by `Enzyme.hvp!`.
  Hv::Vector{T}
end
2730

2831
"""
    EnzymeReverseADHessian(nvar, f, ncon = 0, c = (args...) -> []; x0 = rand(nvar), kwargs...)

Build the dense-Hessian Enzyme backend for a problem with `nvar` variables.

Two work vectors of length `nvar` (`seed` and `Hv`) are allocated with the
element type `T` of `x0`. The arguments `f`, `ncon`, `c` and any extra keyword
arguments are accepted but not used by this constructor.
"""
function EnzymeReverseADHessian(
  nvar::Integer,
  f,
  ncon::Integer = 0,
  c::Function = (args...) -> [];
  x0::AbstractVector{T} = rand(nvar),
  kwargs...,
) where {T}
  # `where {T}` is required: without it `T` in the signature and in the
  # `zeros(T, nvar)` calls below is an undefined name.
  @assert nvar > 0

  seed = zeros(T, nvar)
  Hv = zeros(T, nvar)
  return EnzymeReverseADHessian(seed, Hv)
end
4046

41-
struct EnzymeReverseADHvprod <: InPlaceADbackend
42-
grad::Vector{Float64}
47+
"""
    EnzymeReverseADHvprod{T}

In-place Hessian-vector-product backend based on Enzyme, carrying one
preallocated vector with element type `T`.
"""
struct EnzymeReverseADHvprod{T} <: InPlaceADbackend
  # Preallocated work vector (named for the gradient it is meant to hold).
  grad::Vector{T}
end
4450

4551
function EnzymeReverseADHvprod(
@@ -50,12 +56,12 @@ function EnzymeReverseADHvprod(
5056
x0::AbstractVector{T} = rand(nvar),
5157
kwargs...,
5258
) where {T}
53-
grad = zeros(nvar)
59+
grad = zeros(T, nvar)
5460
return EnzymeReverseADHvprod(grad)
5561
end
5662

57-
struct EnzymeReverseADJprod <: InPlaceADbackend
58-
x::Vector{Float64}
63+
"""
    EnzymeReverseADJprod{T}

In-place Jacobian-vector-product backend based on Enzyme, carrying one
preallocated vector with element type `T`.
"""
struct EnzymeReverseADJprod{T} <: InPlaceADbackend
  # Primal output buffer passed to `c!` by `Jprod!`.
  cx::Vector{T}
end
6066

6167
function EnzymeReverseADJprod(
@@ -65,12 +71,12 @@ function EnzymeReverseADJprod(
6571
c::Function = (args...) -> [];
6672
kwargs...,
6773
)
68-
x = zeros(nvar)
69-
return EnzymeReverseADJprod(x)
74+
cx = zeros(T, nvar)
75+
return EnzymeReverseADJprod(cx)
7076
end
7177

72-
struct EnzymeReverseADJtprod <: InPlaceADbackend
73-
x::Vector{Float64}
78+
"""
    EnzymeReverseADJtprod{T}

In-place transposed-Jacobian-vector-product backend based on Enzyme, carrying
one preallocated vector with element type `T`.
"""
struct EnzymeReverseADJtprod{T} <: InPlaceADbackend
  # Primal output buffer passed to `c!` by `Jtprod!`.
  cx::Vector{T}
end
7581

7682
function EnzymeReverseADJtprod(
@@ -80,8 +86,8 @@ function EnzymeReverseADJtprod(
8086
c::Function = (args...) -> [];
8187
kwargs...,
8288
)
83-
x = zeros(nvar)
84-
return EnzymeReverseADJtprod(x)
89+
cx = zeros(T, nvar)
90+
return EnzymeReverseADJtprod(cx)
8591
end
8692

8793
struct SparseEnzymeADJacobian{R, C, S} <: ADBackend
@@ -93,7 +99,7 @@ struct SparseEnzymeADJacobian{R, C, S} <: ADBackend
9399
result_coloring::C
94100
compressed_jacobian::S
95101
v::Vector{R}
96-
buffer::Vector{R}
102+
cx::Vector{R}
97103
end
98104

99105
function SparseEnzymeADJacobian(
@@ -130,7 +136,7 @@ function SparseEnzymeADJacobian(
130136
nzval = T.(J.nzval)
131137
compressed_jacobian = similar(x0, ncon)
132138
v = similar(x0)
133-
buffer = zeros(T, ncon)
139+
cx = zeros(T, ncon)
134140

135141
SparseEnzymeADJacobian(
136142
nvar,
@@ -141,7 +147,7 @@ function SparseEnzymeADJacobian(
141147
result_coloring,
142148
compressed_jacobian,
143149
v,
144-
buffer,
150+
cx,
145151
)
146152
end
147153

@@ -152,11 +158,12 @@ struct SparseEnzymeADHessian{R, C, S, L} <: ADNLPModels.ADBackend
152158
nzval::Vector{R}
153159
result_coloring::C
154160
coloring_mode::Symbol
161+
compressed_hessian_icol::Vector{R}
155162
compressed_hessian::S
156163
v::Vector{R}
157164
y::Vector{R}
158165
grad::Vector{R}
159-
buffer::Vector{R}
166+
cx::Vector{R}
160167
ℓ::L
161168
end
162169

@@ -193,18 +200,20 @@ function SparseEnzymeADHessian(
193200
nzval = T.(trilH.nzval)
194201
if coloring_algorithm isa GreedyColoringAlgorithm{:direct}
195202
coloring_mode = :direct
196-
compressed_hessian = similar(x0)
203+
compressed_hessian_icol = similar(x0)
204+
compressed_hessian = compressed_hessian_icol
197205
else
198206
coloring_mode = :substitution
199207
group = column_groups(result_coloring)
200208
ncolors = length(group)
209+
compressed_hessian_icol = similar(x0)
201210
compressed_hessian = similar(x0, (nvar, ncolors))
202211
end
203212
v = similar(x0)
204213
y = similar(x0, ncon)
205-
buffer = similar(x0, ncon)
214+
cx = similar(x0, ncon)
206215
grad = similar(x0)
207-
ℓ(x, y, obj_weight, buffer) = obj_weight * f(x) + dot(c!(buffer, x), y)
216+
ℓ(x, y, obj_weight, cx) = obj_weight * f(x) + dot(c!(cx, x), y)
208217

209218
return SparseEnzymeADHessian(
210219
nvar,
@@ -213,11 +222,12 @@ function SparseEnzymeADHessian(
213222
nzval,
214223
result_coloring,
215224
coloring_mode,
225+
compressed_hessian_icol,
216226
compressed_hessian,
217227
v,
218228
y,
219229
grad,
220-
buffer,
230+
cx,
221231
ℓ,
222232
)
223233
end
@@ -238,27 +248,27 @@ end
238248

239249
# Dense Jacobian of `f` at `x`, delegated to Enzyme's reverse-mode `jacobian`.
function jacobian(::EnzymeReverseADJacobian, f, x)
  return Enzyme.jacobian(Enzyme.Reverse, f, x)
end
240250

241-
function hessian(::EnzymeReverseADHessian, f, x)
242-
seed = similar(x)
243-
hess = zeros(eltype(x), length(x), length(x))
244-
fill!(seed, zero(eltype(x)))
245-
tmp = similar(x)
246-
for i in 1:length(x)
247-
seed[i] = one(eltype(seed))
248-
Enzyme.hvp!(tmp, Enzyme.Const(f), x, seed)
249-
hess[:, i] .= tmp
250-
seed[i] = zero(eltype(seed))
251+
"""
    hessian(b::EnzymeReverseADHessian, f, x)

Return the dense Hessian of `f` at `x` as a freshly allocated `n × n` matrix,
where `n = length(x)`.

The matrix is assembled one column at a time: column `i` is the result of
`Enzyme.hvp!` applied to the `i`-th unit vector. The work vectors `b.seed` and
`b.Hv` are overwritten; `b.seed` is all zeros again on return.

Throws a `DimensionMismatch` if the work vectors stored in `b` do not have
length `n`.
"""
function hessian(b::EnzymeReverseADHessian, f, x)
  T = eltype(x)
  n = length(x)
  if length(b.seed) != n || length(b.Hv) != n
    throw(DimensionMismatch("backend buffers have length $(length(b.seed)), expected $n"))
  end
  hess = zeros(T, n, n)
  fill!(b.seed, zero(T))
  for i in 1:n
    # The seed lives in the backend: a bare `seed` is not defined in this scope.
    b.seed[i] = one(T)
    Enzyme.hvp!(b.Hv, Enzyme.Const(f), x, b.seed)
    view(hess, :, i) .= b.Hv
    # Reset the entry so the seed is the next unit vector on the next pass.
    b.seed[i] = zero(T)
  end
  return hess
end
254264

255265
# Forward-mode differentiation of the in-place function `c!`: `x` is paired with
# the direction `v`, and the output buffer `b.cx` is paired with `Jv`, which is
# returned.
function Jprod!(b::EnzymeReverseADJprod, Jv, c!, x, v, ::Val)
  output = Enzyme.Duplicated(b.cx, Jv)
  input = Enzyme.Duplicated(x, v)
  Enzyme.autodiff(Enzyme.Forward, Enzyme.Const(c!), output, input)
  return Jv
end
259269

260270
# Reverse-mode differentiation of the in-place function `c!`: the output buffer
# `b.cx` is paired with `Jtv` and `x` is paired with `v`; `Jtv` is returned.
# NOTE(review): in reverse mode the output shadow usually carries the seed and
# the input shadow receives the result, so this pairing looks swapped — confirm
# against the Enzyme documentation and the callers.
function Jtprod!(b::EnzymeReverseADJtprod, Jtv, c!, x, v, ::Val)
  output = Enzyme.Duplicated(b.cx, Jtv)
  input = Enzyme.Duplicated(x, v)
  Enzyme.autodiff(Enzyme.Reverse, Enzyme.Const(c!), output, input)
  return Jtv
end
264274

@@ -343,7 +353,7 @@ end
343353

344354
# b.compressed_jacobian is just a vector Jv here
345355
# We don't use the vector mode
346-
Enzyme.autodiff(Enzyme.Forward, Enzyme.Const(c!), Enzyme.Duplicated(b.buffer, b.compressed_jacobian), Enzyme.Duplicated(x, b.v))
356+
Enzyme.autodiff(Enzyme.Forward, Enzyme.Const(c!), Enzyme.Duplicated(b.cx, b.compressed_jacobian), Enzyme.Duplicated(x, b.v))
347357

348358
# Update the columns of the Jacobian that have the color `icol`
349359
decompress_single_color!(A, b.compressed_jacobian, icol, b.result_coloring)
@@ -425,11 +435,7 @@ end
425435
b.v[col] = 1
426436
end
427437

428-
# column icol of the compressed hessian
429-
compressed_hessian_icol =
430-
(b.coloring_mode == :direct) ? b.compressed_hessian : view(b.compressed_hessian, :, icol)
431-
432-
function _gradient!(dx, f, x, y, obj_weight, buffer)
438+
function _gradient!(dx, f, x, y, obj_weight, cx)
433439
Enzyme.make_zero!(dx)
434440
res = Enzyme.autodiff(
435441
Enzyme.Reverse,
@@ -438,12 +444,12 @@ end
438444
Enzyme.Duplicated(x, dx),
439445
Enzyme.Const(y),
440446
Enzyme.Const(obj_weight),
441-
Enzyme.Const(buffer)
447+
Enzyme.Const(cx)
442448
)
443449
return nothing
444450
end
445451

446-
function _hvp!(res, f, x, v, y, obj_weight, buffer)
452+
function _hvp!(res, f, x, v, y, obj_weight, cx)
447453
# grad = Enzyme.make_zero(x)
448454
Enzyme.autodiff(
449455
Enzyme.Forward,
@@ -453,19 +459,22 @@ end
453459
Enzyme.Duplicated(x, v),
454460
Enzyme.Const(y),
455461
Enzyme.Const(obj_weight),
456-
Enzyme.Const(buffer),
462+
Enzyme.Const(cx),
457463
)
458464
return nothing
459465
end
460466

461467
_hvp!(
462-
Enzyme.DuplicatedNoNeed(b.grad, compressed_hessian_icol),
463-
b.ℓ, x, b.v, y, obj_weight, b.buffer
468+
Enzyme.DuplicatedNoNeed(b.grad, b.compressed_hessian_icol),
469+
b.ℓ, x, b.v, y, obj_weight, b.cx
464470
)
465471

466472
if b.coloring_mode == :direct
467473
# Update the coefficients of the lower triangular part of the Hessian that are related to the color `icol`
468-
decompress_single_color!(A, compressed_hessian_icol, icol, b.result_coloring, :L)
474+
decompress_single_color!(A, b.compressed_hessian_icol, icol, b.result_coloring, :L)
475+
end
476+
if b.coloring_mode == :substitution
477+
view(b.compressed_hessian, :, icol) .= b.compressed_hessian_icol
469478
end
470479
end
471480
if b.coloring_mode == :substitution

0 commit comments

Comments
 (0)