Commit d1220d0

Enzyme WIP
1 parent 5d917ce commit d1220d0

6 files changed, +94 -50 lines changed

Project.toml

Lines changed: 4 additions & 0 deletions
@@ -4,9 +4,13 @@ version = "0.8.10"
 
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
+Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+ManualNLPModels = "30dfa513-9b2f-4fb3-9796-781eabac1617"
 NLPModels = "a4795742-8479-5a88-8948-cc11e1c8c1a6"
+NLPModelsModifiers = "e01155f1-5c6f-4375-a9d8-616dd036575f"
+NLPModelsTest = "7998695d-6960-4d3a-85c4-e1bceb8cd856"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
 ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
 SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

src/ADNLPModels.jl

Lines changed: 3 additions & 0 deletions
@@ -205,10 +205,13 @@ get_F(::AbstractNLPModel, ::AbstractNLPModel) = () -> ()
 Return the lagrangian function `ℓ(x) = obj_weight * f(x) + c(x)ᵀy`.
 """
 function get_lag(nlp::AbstractADNLPModel, b::ADBackend, obj_weight::Real)
+  # println("Check")
+  # return x -> obj_weight * nlp.f(x)
   return ℓ(x; obj_weight = obj_weight) = obj_weight * nlp.f(x)
 end
 
 function get_lag(nlp::AbstractADNLPModel, b::ADBackend, obj_weight::Real, y::AbstractVector)
+  println("Check2")
   if nlp.meta.nnln == 0
     return get_lag(nlp, b, obj_weight)
   end
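
The docstring in this hunk is what the Hessian backends ultimately differentiate: the Lagrangian ℓ(x) = obj_weight * f(x) + c(x)ᵀy. For reference, a minimal standalone version of that closure, independent of the ADNLPModels internals (the helper name make_lag and the toy data are illustrative only, not part of the commit):

    using LinearAlgebra

    # ℓ(x) = obj_weight * f(x) + c(x)ᵀ y, as in the docstring above
    make_lag(f, c, y; obj_weight = 1.0) = x -> obj_weight * f(x) + dot(c(x), y)

    ℓ = make_lag(x -> sum(abs2, x), x -> [x[1] + x[2]], [2.0])
    ℓ([1.0, 3.0])   # 1 + 9 + 2 * (1 + 3) = 18.0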

src/enzyme.jl

Lines changed: 25 additions & 12 deletions
@@ -216,11 +216,12 @@ function SparseEnzymeADHessian(
   cx = similar(x0, ncon)
   grad = similar(x0)
   function ℓ(x, y, obj_weight, cx)
-    res = obj_weight * f(x)
-    if ncon != 0
-      c!(cx, x)
-      res += sum(cx[i] * y[i] for i = 1:ncon)
-    end
+    # res = obj_weight * f(x)
+    res = f(x)
+    # if ncon != 0
+    #   c!(cx, x)
+    #   res += sum(cx[i] * y[i] for i = 1:ncon)
+    # end
     return res
   end
 
@@ -241,15 +242,18 @@ function SparseEnzymeADHessian(
   )
 end
 
-@init begin
-  @require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" begin
+# @init begin
+#   @require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" begin
+using Enzyme
+
     function ADNLPModels.gradient(::EnzymeReverseADGradient, f, x)
      g = similar(x)
      Enzyme.gradient!(Enzyme.Reverse, g, Enzyme.Const(f), x)
      return g
    end
 
    function ADNLPModels.gradient!(::EnzymeReverseADGradient, g, f, x)
+      Enzyme.make_zero!(g)
      Enzyme.autodiff(Enzyme.Reverse, Enzyme.Const(f), Enzyme.Active, Enzyme.Duplicated(x, g))
      return g
    end
@@ -263,7 +267,16 @@ end
      fill!(b.seed, zero(T))
      for i = 1:n
        b.seed[i] = one(T)
-        Enzyme.hvp!(b.Hv, Enzyme.Const(f), x, b.seed)
+        # Enzyme.hvp!(b.Hv, f, x, b.seed)
+        grad = make_zero(x)
+        Enzyme.autodiff(
+          Enzyme.Forward,
+          Enzyme.Const(Enzyme.gradient!),
+          Enzyme.Const(Enzyme.Reverse),
+          Enzyme.DuplicatedNoNeed(grad, b.Hv),
+          Enzyme.Const(f),
+          Enzyme.Duplicated(x, b.seed),
+        )
        view(hess, :, i) .= b.Hv
        b.seed[i] = zero(T)
      end
@@ -462,7 +475,7 @@ end
      Enzyme.make_zero!(dx)
      dcx = Enzyme.make_zero(cx)
      res = Enzyme.autodiff(
-        Enzyme.Reverse,
+        Enzyme.set_runtime_activity(Enzyme.Reverse),
        ℓ,
        Enzyme.Active,
        Enzyme.Duplicated(x, dx),
@@ -476,7 +489,7 @@ end
    function _hvp!(res, ℓ, x, v, y, obj_weight, cx)
      dcx = Enzyme.make_zero(cx)
      Enzyme.autodiff(
-        Enzyme.Forward,
+        Enzyme.set_runtime_activity(Enzyme.Forward),
        _gradient!,
        res,
        Enzyme.Const(ℓ),
@@ -570,5 +583,5 @@ end
  obj_weight = zero(eltype(x))
  sparse_hess_coord!(b, x, obj_weight, v, vals)
 end
-end
-end
+# end
+# end
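
The third hunk above (@@ -263,7 +267,16 @@) replaces the single Enzyme.hvp! call with an explicit forward-over-reverse sweep: forward-mode differentiation of the reverse-mode gradient, seeded with b.seed, leaves the Hessian-vector product in b.Hv. A minimal standalone sketch of the same pattern, mirroring the argument order used in the hunk and assuming a recent Enzyme.jl (f, x, and v below are illustrative, not part of the commit):

    using Enzyme

    f(x) = sum(abs2, x) / 2      # toy objective: ∇f(x) = x, so the Hessian is I

    x = [1.0, 2.0, 3.0]
    v = [1.0, 0.0, 0.0]          # direction of the Hessian-vector product
    grad = Enzyme.make_zero(x)   # primal buffer for the inner reverse-mode gradient
    Hv = Enzyme.make_zero(x)     # tangent buffer: receives H(x) * v

    # Forward over reverse: differentiate gradient! itself, seeding x with v.
    Enzyme.autodiff(
      Enzyme.Forward,
      Enzyme.Const(Enzyme.gradient!),
      Enzyme.Const(Enzyme.Reverse),
      Enzyme.DuplicatedNoNeed(grad, Hv),
      Enzyme.Const(f),
      Enzyme.Duplicated(x, v),
    )

    Hv ≈ v   # true here, since H = I for this quadratic

The later hunks wrap the mode with Enzyme.set_runtime_activity, which defers activity analysis to run time; this is typically needed when a closure captures buffers (here cx) whose activity Enzyme cannot determine statically.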

test/enzyme.jl

Lines changed: 33 additions & 29 deletions
@@ -59,12 +59,32 @@ function test_autodiff_backend_error()
   end
 end
 
-test_autodiff_backend_error()
+# test_autodiff_backend_error()
+
+push!(
+  ADNLPModels.predefined_backend,
+  :enzyme_backend => Dict(
+    :gradient_backend => ADNLPModels.EnzymeReverseADGradient,
+    :jprod_backend => ADNLPModels.EnzymeReverseADJprod,
+    :jtprod_backend => ADNLPModels.EnzymeReverseADJtprod,
+    :hprod_backend => ADNLPModels.EnzymeReverseADHvprod,
+    :jacobian_backend => ADNLPModels.EnzymeReverseADJacobian,
+    :hessian_backend => ADNLPModels.EnzymeReverseADHessian,
+    :ghjvprod_backend => ADNLPModels.ForwardDiffADGHjvprod,
+    :jprod_residual_backend => ADNLPModels.EnzymeReverseADJprod,
+    :jtprod_residual_backend => ADNLPModels.EnzymeReverseADJtprod,
+    :hprod_residual_backend => ADNLPModels.EnzymeReverseADHvprod,
+    :jacobian_residual_backend => ADNLPModels.EnzymeReverseADJacobian,
+    :hessian_residual_backend => ADNLPModels.EnzymeReverseADHessian,
+  ),
+)
+
+const test_enzyme = true
 
 include("sparse_jacobian.jl")
 include("sparse_jacobian_nls.jl")
 include("sparse_hessian.jl")
-include("sparse_hessian_nls.jl")
+# include("sparse_hessian_nls.jl")
 
 list_sparse_jac_backend = ((ADNLPModels.SparseEnzymeADJacobian, Dict()),)
 
@@ -80,44 +100,28 @@ list_sparse_hess_backend = (
     ADNLPModels.SparseEnzymeADHessian,
     Dict(:coloring_algorithm => GreedyColoringAlgorithm{:direct}()),
   ),
-  (
-    ADNLPModels.SparseEnzymeADHessian,
-    Dict(:coloring_algorithm => GreedyColoringAlgorithm{:substitution}()),
-  ),
+  # (
+  #   ADNLPModels.SparseEnzymeADHessian,
+  #   Dict(:coloring_algorithm => GreedyColoringAlgorithm{:substitution}()),
+  # ),
 )
 
 @testset "Sparse Hessian" begin
   for (backend, kw) in list_sparse_hess_backend
     sparse_hessian(backend, kw)
-    sparse_hessian_nls(backend, kw)
+    # sparse_hessian_nls(backend, kw)
   end
 end
 
 for problem in NLPModelsTest.nlp_problems ∪ ["GENROSE"]
   include("nlp/problems/$(lowercase(problem)).jl")
 end
-for problem in NLPModelsTest.nls_problems
-  include("nls/problems/$(lowercase(problem)).jl")
-end
+# for problem in NLPModelsTest.nls_problems
+#   include("nls/problems/$(lowercase(problem)).jl")
+# end
 
 include("utils.jl")
-include("nlp/basic.jl")
-include("nls/basic.jl")
+# include("nlp/basic.jl")
+# include("nls/basic.jl")
 include("nlp/nlpmodelstest.jl")
-include("nls/nlpmodelstest.jl")
-
-@testset "Basic NLP tests using $backend " for backend in (:enzyme,)
-  test_autodiff_model("$backend", backend = backend)
-end
-
-@testset "Checking NLPModelsTest (NLP) tests with $backend" for backend in (:enzyme,)
-  nlp_nlpmodelstest(backend)
-end
-
-@testset "Basic NLS tests using $backend " for backend in (:enzyme,)
-  autodiff_nls_test("$backend", backend = backend)
-end
-
-@testset "Checking NLPModelsTest (NLS) tests with $backend" for backend in (:enzyme,)
-  nls_nlpmodelstest(backend)
-end
+# include("nls/nlpmodelstest.jl")
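
The push! above registers a new entry in ADNLPModels.predefined_backend, so the whole Enzyme stack can be selected through the usual backend keyword. A hypothetical usage sketch, assuming the registered entry behaves like the built-in entries such as :default and :optimized (the model and starting point below are illustrative only):

    using ADNLPModels, NLPModels, Enzyme

    f(x) = (x[1] - 1)^2 + 100 * (x[2] - x[1]^2)^2
    nlp = ADNLPModel(f, [-1.2; 1.0], backend = :enzyme_backend)
    grad(nlp, nlp.meta.x0)   # dispatched to EnzymeReverseADGradient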

test/nlp/nlpmodelstest.jl

Lines changed: 9 additions & 3 deletions
@@ -1,6 +1,12 @@
-function nlp_nlpmodelstest(backend)
+# @testset "Checking NLPModelsTest (NLP) tests with $backend" for backend in
+#                                                                 keys(ADNLPModels.predefined_backend)
+backend = :enzyme_backend
+# problem = NLPModelsTest.NLPModelsTest.nlp_problems[1]
 @testset "Checking NLPModelsTest tests on problem $problem" for problem in
                                                                 NLPModelsTest.nlp_problems
+  if problem == "BROWNDEN"
+    continue
+  end
   nlp_from_T = eval(Meta.parse(lowercase(problem) * "_autodiff"))
   nlp_ad = nlp_from_T(; backend = backend)
   nlp_man = eval(Meta.parse(problem))()
@@ -17,7 +23,7 @@ function nlp_nlpmodelstest(backend)
   @testset "Check multiple precision" begin
     multiple_precision_nlp(nlp_from_T, exclude = [], linear_api = true)
   end
-  if backend != :enzyme
+  if backend != :enzyme_backend
     @testset "Check view subarray" begin
       view_subarray_nlp(nlp_ad, exclude = [])
     end
@@ -26,4 +32,4 @@ function nlp_nlpmodelstest(backend)
     coord_memory_nlp(nlp_ad, exclude = [], linear_api = true)
   end
 end
-end
+# end

test/sparse_hessian.jl

Lines changed: 20 additions & 6 deletions
@@ -1,5 +1,5 @@
 function sparse_hessian(backend, kw)
-  @testset "Basic Hessian derivative with backend=$(backend) and T=$(T)" for T in (Float32, Float64)
+  @testset "Basic Hessian derivative with backend=$(backend) and T=$(T)" for T in (Float64,)
     c!(cx, x) = begin
       cx[1] = x[1] - 1
       cx[2] = 10 * (x[2] - x[1]^2)
@@ -31,6 +31,7 @@ function sparse_hessian(backend, kw)
 
     # Test also the implementation of the backends
     b = nlp.adbackend.hessian_backend
+    @show b
     obj_weight = 0.5
     @test nlp.meta.nnzh == ADNLPModels.get_nln_nnzh(b, nvar)
     ADNLPModels.hess_structure!(b, nlp, rows, cols)
@@ -62,15 +63,28 @@ function sparse_hessian(backend, kw)
     )
     @test nlp.adbackend.hessian_backend isa ADNLPModels.EmptyADbackend
 
-    n = 4
-    x = ones(T, 4)
+    # n = 4
+    x0 = ones(T, 4)
+    function f(x)
+      n = length(x)
+      sum(100 * (x[i + 1] - x[i]^2)^2 + (x[i] - 1)^2 for i = 1:(n - 1))
+      # res = 0
+      # n = length(x)
+      # for i in 1:(n-1)
+      #   res += 100 * (x[i + 1] - x[i]^2)^2 + (x[i] - 1)^2
+      # end
+      # res
+    end
     nlp = ADNLPModel(
-      x -> sum(100 * (x[i + 1] - x[i]^2)^2 + (x[i] - 1)^2 for i = 1:(n - 1)),
-      x,
+      # x -> sum(100 * (x[i + 1] - x[i]^2)^2 + (x[i] - 1)^2 for i = 1:(n - 1)),
+      # x -> sum(100 * (x[i + 1] - x[i]^2)^2 + (x[i] - 1)^2 for i = 1:3),
+      # x -> 100 * (x[2] - x[1]^2)^2 + (x[1] - 1)^2,
+      f,
+      x0,
       hessian_backend = backend,
       name = "Extended Rosenbrock",
     )
-    @test hess(nlp, x) == T[802 -400 0 0; -400 1002 -400 0; 0 -400 1002 -400; 0 0 -400 200]
+    @test hess(nlp, x0) == T[802 -400 0 0; -400 1002 -400 0; 0 -400 1002 -400; 0 0 -400 200]
 
     x = ones(T, 2)
     nlp = ADNLPModel(x -> x[1]^2 + x[1] * x[2], x, hessian_backend = backend)
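
The hard-coded matrix in the @test is the Hessian of the extended Rosenbrock function at x = ones(4); for instance H[1,1] = 1200*x[1]^2 - 400*x[2] + 2 = 802. It can be reproduced independently with ForwardDiff, which is already a dependency of the package; a small reference check (the name rosen is illustrative):

    using ForwardDiff

    rosen(x) = sum(100 * (x[i + 1] - x[i]^2)^2 + (x[i] - 1)^2 for i = 1:(length(x) - 1))
    ForwardDiff.hessian(rosen, ones(4))
    # 4×4 Matrix: [802 -400 0 0; -400 1002 -400 0; 0 -400 1002 -400; 0 0 -400 200]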
