Skip to content

Commit 58eb3e7

Browse files
michel2323amontoison
authored andcommitted
Enzyme WIP
1 parent c4622d8 commit 58eb3e7

File tree

4 files changed

+69
-116
lines changed

4 files changed

+69
-116
lines changed

src/enzyme.jl

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
struct EnzymeReverseADJacobian <: ADBackend end
22
struct EnzymeReverseADHessian <: ADBackend end
33

4-
struct EnzymeReverseADGradient <: ADNLPModels.ADBackend end
4+
struct EnzymeReverseADGradient <: InPlaceADbackend end
55

66
function EnzymeReverseADGradient(
77
nvar::Integer,
@@ -16,12 +16,13 @@ end
1616

1717
function ADNLPModels.gradient(::EnzymeReverseADGradient, f, x)
1818
g = similar(x)
19-
Enzyme.autodiff(Enzyme.Reverse, f, Enzyme.Duplicated(x, g)) # gradient!(Reverse, g, f, x)
19+
# Enzyme.autodiff(Enzyme.Reverse, Const(f), Active, Enzyme.Duplicated(x, g)) # gradient!(Reverse, g, f, x)
20+
Enzyme.gradient!(Reverse, g, Const(f), x)
2021
return g
2122
end
2223

2324
function ADNLPModels.gradient!(::EnzymeReverseADGradient, g, f, x)
24-
Enzyme.autodiff(Enzyme.Reverse, f, Enzyme.Duplicated(x, g)) # gradient!(Reverse, g, f, x)
25+
Enzyme.autodiff(Enzyme.Reverse, Const(f), Active, Enzyme.Duplicated(x, g)) # gradient!(Reverse, g, f, x)
2526
return g
2627
end
2728

@@ -57,7 +58,7 @@ function hessian(::EnzymeReverseADHessian, f, x)
5758
tmp = similar(x)
5859
for i in 1:length(x)
5960
seed[i] = one(eltype(seed))
60-
Enzyme.hvp!(tmp, f, x, seed)
61+
Enzyme.hvp!(tmp, Const(f), x, seed)
6162
hess[:, i] .= tmp
6263
seed[i] = zero(eltype(seed))
6364
end
@@ -80,9 +81,7 @@ function EnzymeReverseADJprod(
8081
end
8182

8283
function Jprod!(b::EnzymeReverseADJprod, Jv, c!, x, v, ::Val)
83-
@show c!(x)
84-
@show Enzyme.autodiff(Enzyme.Forward, Const(c!), Duplicated(x, v))
85-
error("This is BAD")
84+
Enzyme.autodiff(Enzyme.Forward, Const(c!), Duplicated(b.x,Jv), Duplicated(x, v))
8685
return Jv
8786
end
8887

@@ -102,7 +101,7 @@ function EnzymeReverseADJtprod(
102101
end
103102

104103
function Jtprod!(b::EnzymeReverseADJtprod, Jtv, c!, x, v, ::Val)
105-
Enzyme.autodiff(Enzyme.Reverse, c!, Duplicated(b.x, Jtv), Enzyme.Duplicated(x, v))
104+
Enzyme.autodiff(Enzyme.Reverse, Const(c!), Duplicated(b.x, Jtv), Enzyme.Duplicated(x, v))
106105
return Jtv
107106
end
108107

@@ -126,7 +125,7 @@ function Hvprod!(b::EnzymeReverseADHvprod, Hv, x, v, f, args...)
126125
# What to do with args?
127126
Enzyme.autodiff(
128127
Forward,
129-
gradient!,
128+
Const(Enzyme.gradient!),
130129
Const(Reverse),
131130
DuplicatedNoNeed(b.grad, Hv),
132131
Const(f),
@@ -147,7 +146,7 @@ function Hvprod!(
147146
)
148147
Enzyme.autodiff(
149148
Forward,
150-
gradient!,
149+
Const(Enzyme.gradient!),
151150
Const(Reverse),
152151
DuplicatedNoNeed(b.grad, Hv),
153152
Const(ℓ),
@@ -169,7 +168,7 @@ function Hvprod!(
169168
)
170169
Enzyme.autodiff(
171170
Forward,
172-
gradient!,
171+
Const(Enzyme.gradient!),
173172
Const(Reverse),
174173
DuplicatedNoNeed(b.grad, Hv),
175174
Const(f),

test/enzyme.jl

Lines changed: 37 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,10 @@ EnzymeReverseAD() = ADNLPModels.ADModelBackend(
2424
ADNLPModels.EmptyADbackend(),
2525
ADNLPModels.EmptyADbackend(),
2626
)
27-
27+
function mysum!(y, x)
28+
sum!(y, x)
29+
return nothing
30+
end
2831
function test_autodiff_backend_error()
2932
@testset "Error without loading package - $backend" for backend in [:EnzymeReverseAD]
3033
adbackend = eval(backend)()
@@ -50,100 +53,50 @@ function test_autodiff_backend_error()
5053
# )
5154
gradient(adbackend.gradient_backend, sum, [1.0])
5255
gradient!(adbackend.gradient_backend, [1.0], sum, [1.0])
53-
jacobian(adbackend.jacobian_backend, identity, [1.0])
56+
jacobian(adbackend.jacobian_backend, sum, [1.0])
5457
hessian(adbackend.hessian_backend, sum, [1.0])
5558
Jprod!(
5659
adbackend.jprod_backend,
5760
[1.0],
58-
identity,
61+
sum!,
62+
[1.0],
63+
[1.0],
64+
Val(:c),
65+
)
66+
Jtprod!(
67+
adbackend.jtprod_backend,
68+
[1.0],
69+
mysum!,
5970
[1.0],
6071
[1.0],
6172
Val(:c),
6273
)
63-
# Jtprod!(
64-
# adbackend.jtprod_backend,
65-
# [1.0],
66-
# identity,
67-
# [1.0],
68-
# [1.0],
69-
# Val(:c),
70-
# )
7174
end
7275
end
7376

7477
test_autodiff_backend_error()
75-
#=
76-
# ADNLPModels.EmptyADbackend(args...; kwargs...) = ADNLPModels.EmptyADbackend()
7778

78-
names = OptimizationProblems.meta[!, :name]
79-
list_excluded_enzyme = [
80-
"brybnd",
81-
"clplatea",
82-
"clplateb",
83-
"clplatec",
84-
"curly",
85-
"curly10",
86-
"curly20",
87-
"curly30",
88-
"elec",
89-
"fminsrf2",
90-
"hs101",
91-
"hs117",
92-
"hs119",
93-
"hs86",
94-
"integreq",
95-
"ncb20",
96-
"ncb20b",
97-
"palmer1c",
98-
"palmer1d",
99-
"palmer2c",
100-
"palmer3c",
101-
"palmer4c",
102-
"palmer5c",
103-
"palmer5d",
104-
"palmer6c",
105-
"palmer7c",
106-
"palmer8c",
107-
"sbrybnd",
108-
"tetra",
109-
"tetra_duct12",
110-
"tetra_duct15",
111-
"tetra_duct20",
112-
"tetra_foam5",
113-
"tetra_gear",
114-
"tetra_hook",
115-
"threepk",
116-
"triangle",
117-
"triangle_deer",
118-
"triangle_pacman",
119-
"triangle_turtle",
120-
"watson",
121-
]
122-
for pb in names
123-
@info pb
124-
(pb in list_excluded_enzyme) && continue
125-
nlp = eval(Meta.parse(pb))(
126-
gradient_backend = ADNLPModels.EnzymeADGradient,
127-
jacobian_backend = ADNLPModels.EmptyADbackend,
128-
hessian_backend = ADNLPModels.EmptyADbackend,
129-
)
130-
grad(nlp, get_x0(nlp))
131-
end
132-
=#
79+
push!(
80+
ADNLPModels.predefined_backend,
81+
:enzyme_backend => Dict(
82+
:gradient_backend => ADNLPModels.EnzymeReverseADGradient,
83+
:jprod_backend => ADNLPModels.EnzymeReverseADJprod,
84+
:jtprod_backend => ADNLPModels.EnzymeReverseADJtprod,
85+
:hprod_backend => ADNLPModels.EnzymeReverseADHvprod,
86+
:jacobian_backend => ADNLPModels.EnzymeReverseADJacobian,
87+
:hessian_backend => ADNLPModels.EnzymeReverseADHessian,
88+
:ghjvprod_backend => ADNLPModels.ForwardDiffADGHjvprod,
89+
:jprod_residual_backend => ADNLPModels.EnzymeReverseADJprod,
90+
:jtprod_residual_backend => ADNLPModels.EnzymeReverseADJtprod,
91+
:hprod_residual_backend => ADNLPModels.EnzymeReverseADHvprod,
92+
:jacobian_residual_backend => ADNLPModels.EnzymeReverseADJacobian,
93+
:hessian_residual_backend => ADNLPModels.EnzymeReverseADHessian,
94+
),
95+
)
96+
97+
include("utils.jl")
98+
include("nlp/basic.jl")
99+
include("nls/basic.jl")
100+
include("nlp/nlpmodelstest.jl")
101+
include("nls/nlpmodelstest.jl")
133102

134-
#=
135-
ERROR: Duplicated Returns not yet handled
136-
Stacktrace:
137-
[1] autodiff
138-
@.julia\packages\Enzyme\DIkTv\src\Enzyme.jl:209 [inlined]
139-
[2] autodiff(mode::EnzymeCore.ReverseMode, f::OptimizationProblems.ADNLPProblems.var"#f#254"{OptimizationProblems.ADNLPProblems.var"#f#250#255"}, args::Duplicated{Vector{Float64}})
140-
@ Enzyme.julia\packages\Enzyme\DIkTv\src\Enzyme.jl:248
141-
[3] gradient!(#unused#::ADNLPModels.EnzymeADGradient, g::Vector{Float64}, f::Function, x::Vector{Float64})
142-
@ ADNLPModelsDocuments\cvs\ADNLPModels.jl\src\enzyme.jl:17
143-
[4] grad!(nlp::ADNLPModel{Float64, Vector{Float64}, Vector{Int64}}, x::Vector{Float64}, g::Vector{Float64})
144-
@ ADNLPModelsDocuments\cvs\ADNLPModels.jl\src\nlp.jl:542
145-
[5] grad(nlp::ADNLPModel{Float64, Vector{Float64}, Vector{Int64}}, x::Vector{Float64})
146-
@ NLPModels.julia\packages\NLPModels\XBcWL\src\nlp\api.jl:31
147-
[6] top-level scope
148-
@ .\REPL[7]:5
149-
=#

test/nlp/basic.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,13 @@ function test_autodiff_model(name; kwargs...)
1717
nlp = ADNLPModel(f, x0, c, [0.0], [0.0]; kwargs...)
1818
@test obj(nlp, x0) == f(x0)
1919

20-
x = range(-1, stop = 1, length = 100)
21-
y = 2x .+ 3 + randn(100) * 0.1
22-
regr = LinearRegression(x, y)
23-
nlp = ADNLPModel(regr, ones(2); kwargs...)
24-
β = [ones(100) x] \ y
25-
@test abs(obj(nlp, β) - norm(y .- β[1] - β[2] * x)^2 / 2) < 1e-12
26-
@test norm(grad(nlp, β)) < 1e-12
20+
# x = range(-1, stop = 1, length = 100)
21+
# y = 2x .+ 3 + randn(100) * 0.1
22+
# regr = LinearRegression(x, y)
23+
# nlp = ADNLPModel(regr, ones(2); kwargs...)
24+
# β = [ones(100) x] \ y
25+
# @test abs(obj(nlp, β) - norm(y .- β[1] - β[2] * x)^2 / 2) < 1e-12
26+
# @test norm(grad(nlp, β)) < 1e-12
2727

2828
test_getter_setter(nlp)
2929

test/nlp/nlpmodelstest.jl

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
@testset "Checking NLPModelsTest (NLP) tests with $backend" for backend in
2-
keys(ADNLPModels.predefined_backend)
1+
# @testset "Checking NLPModelsTest (NLP) tests with $backend" for backend in
2+
# keys(ADNLPModels.predefined_backend)
3+
backend = :enzyme_backend
34
@testset "Checking NLPModelsTest tests on problem $problem" for problem in
45
NLPModelsTest.nlp_problems
56
nlp_from_T = eval(Meta.parse(lowercase(problem) * "_autodiff"))
@@ -12,17 +13,17 @@
1213
@testset "Check Consistency" begin
1314
consistent_nlps(nlps, exclude = [], linear_api = true, reimplemented = ["jtprod"])
1415
end
15-
@testset "Check dimensions" begin
16-
check_nlp_dimensions(nlp_ad, exclude = [], linear_api = true)
17-
end
18-
@testset "Check multiple precision" begin
19-
multiple_precision_nlp(nlp_from_T, exclude = [], linear_api = true)
20-
end
21-
@testset "Check view subarray" begin
22-
view_subarray_nlp(nlp_ad, exclude = [])
23-
end
24-
@testset "Check coordinate memory" begin
25-
coord_memory_nlp(nlp_ad, exclude = [], linear_api = true)
26-
end
16+
# @testset "Check dimensions" begin
17+
# check_nlp_dimensions(nlp_ad, exclude = [], linear_api = true)
18+
# end
19+
# @testset "Check multiple precision" begin
20+
# multiple_precision_nlp(nlp_from_T, exclude = [], linear_api = true)
21+
# end
22+
# @testset "Check view subarray" begin
23+
# view_subarray_nlp(nlp_ad, exclude = [])
24+
# end
25+
# @testset "Check coordinate memory" begin
26+
# coord_memory_nlp(nlp_ad, exclude = [], linear_api = true)
27+
# end
2728
end
2829
end

0 commit comments

Comments
 (0)