Skip to content

Commit 2e28370

Browse files
committed
cont.
1 parent 2981213 commit 2e28370

File tree

4 files changed

+83
-8
lines changed

4 files changed

+83
-8
lines changed

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
77
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
88
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
99
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
10+
ManualNLPModels = "30dfa513-9b2f-4fb3-9796-781eabac1617"
1011
NLPModels = "a4795742-8479-5a88-8948-cc11e1c8c1a6"
12+
NLPModelsModifiers = "e01155f1-5c6f-4375-a9d8-616dd036575f"
13+
NLPModelsTest = "7998695d-6960-4d3a-85c4-e1bceb8cd856"
1114
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1215
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
1316
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

src/ADNLPModels.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ using LinearAlgebra, SparseArrays
77
using ADTypes: ADTypes, AbstractColoringAlgorithm, AbstractSparsityDetector
88
using SparseConnectivityTracer: TracerSparsityDetector
99
using SparseMatrixColorings
10-
using ForwardDiff, ReverseDiff
10+
using ForwardDiff, ReverseDiff, Enzyme
1111

1212
# JSO
1313
using NLPModels

src/enzyme.jl

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,12 @@ function EnzymeReverseADGradient(
1414
return EnzymeReverseADGradient()
1515
end
1616

17+
function ADNLPModels.gradient(::EnzymeReverseADGradient, f, x)
18+
g = similar(x)
19+
Enzyme.autodiff(Enzyme.Reverse, f, Enzyme.Duplicated(x, g)) # gradient!(Reverse, g, f, x)
20+
return g
21+
end
22+
1723
function ADNLPModels.gradient!(::EnzymeReverseADGradient, g, f, x)
1824
Enzyme.autodiff(Enzyme.Reverse, f, Enzyme.Duplicated(x, g)) # gradient!(Reverse, g, f, x)
1925
return g
@@ -47,11 +53,13 @@ end
4753
function hessian(::EnzymeReverseADHessian, f, x)
4854
seed = similar(x)
4955
hess = zeros(eltype(x), length(x), length(x))
50-
fill!(seed, zero(x))
56+
fill!(seed, zero(eltype(x)))
57+
tmp = similar(x)
5158
for i in 1:length(x)
52-
seed[i] = one(x)
53-
Enzyme.hvp!(view(hess, i, :), f, x, seed)
54-
seed[i] = zero(x)
59+
seed[i] = one(eltype(seed))
60+
Enzyme.hvp!(tmp, f, x, seed)
61+
hess[:, i] .= tmp
62+
seed[i] = zero(eltype(seed))
5563
end
5664
return hess
5765
end
@@ -72,7 +80,9 @@ function EnzymeReverseADJprod(
7280
end
7381

7482
function Jprod!(b::EnzymeReverseADJprod, Jv, c!, x, v, ::Val)
75-
Enzyme.autodiff(Enzyme.Forward, c!, Duplicated(b.x, Jv), Enzyme.Duplicated(x, v))
83+
@show c!(x)
84+
@show Enzyme.autodiff(Enzyme.Forward, Const(c!), Duplicated(x, v))
85+
error("This is BAD")
7686
return Jv
7787
end
7888

@@ -91,7 +101,7 @@ function EnzymeReverseADJtprod(
91101
return EnzymeReverseADJtprod(x)
92102
end
93103

94-
function Jtvprod!(b::EnzymeReverseADJtprod, Jtv, c!, x, v, ::Val)
104+
function Jtprod!(b::EnzymeReverseADJtprod, Jtv, c!, x, v, ::Val)
95105
Enzyme.autodiff(Enzyme.Reverse, c!, Duplicated(b.x, Jtv), Enzyme.Duplicated(x, v))
96106
return Jtv
97107
end

test/enzyme.jl

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,70 @@ for problem in NLPModelsTest.nls_problems
1010
include("nls/problems/$(lowercase(problem)).jl")
1111
end
1212

13+
EnzymeReverseAD() = ADNLPModels.ADModelBackend(
14+
ADNLPModels.EnzymeReverseADGradient(),
15+
ADNLPModels.EnzymeReverseADHvprod(zeros(1)),
16+
ADNLPModels.EnzymeReverseADJprod(zeros(1)),
17+
ADNLPModels.EnzymeReverseADJtprod(zeros(1)),
18+
ADNLPModels.EnzymeReverseADJacobian(),
19+
ADNLPModels.EnzymeReverseADHessian(),
20+
ADNLPModels.EnzymeReverseADHvprod(zeros(1)),
21+
ADNLPModels.EmptyADbackend(),
22+
ADNLPModels.EmptyADbackend(),
23+
ADNLPModels.EmptyADbackend(),
24+
ADNLPModels.EmptyADbackend(),
25+
ADNLPModels.EmptyADbackend(),
26+
)
27+
28+
function test_autodiff_backend_error()
29+
@testset "Error without loading package - $backend" for backend in [:EnzymeReverseAD]
30+
adbackend = eval(backend)()
31+
# @test_throws ArgumentError gradient(adbackend.gradient_backend, sum, [1.0])
32+
# @test_throws ArgumentError gradient!(adbackend.gradient_backend, [1.0], sum, [1.0])
33+
# @test_throws ArgumentError jacobian(adbackend.jacobian_backend, identity, [1.0])
34+
# @test_throws ArgumentError hessian(adbackend.hessian_backend, sum, [1.0])
35+
# @test_throws ArgumentError Jprod!(
36+
# adbackend.jprod_backend,
37+
# [1.0],
38+
# [1.0],
39+
# identity,
40+
# [1.0],
41+
# Val(:c),
42+
# )
43+
# @test_throws ArgumentError Jtprod!(
44+
# adbackend.jtprod_backend,
45+
# [1.0],
46+
# [1.0],
47+
# identity,
48+
# [1.0],
49+
# Val(:c),
50+
# )
51+
gradient(adbackend.gradient_backend, sum, [1.0])
52+
gradient!(adbackend.gradient_backend, [1.0], sum, [1.0])
53+
jacobian(adbackend.jacobian_backend, identity, [1.0])
54+
hessian(adbackend.hessian_backend, sum, [1.0])
55+
Jprod!(
56+
adbackend.jprod_backend,
57+
[1.0],
58+
identity,
59+
[1.0],
60+
[1.0],
61+
Val(:c),
62+
)
63+
# Jtprod!(
64+
# adbackend.jtprod_backend,
65+
# [1.0],
66+
# identity,
67+
# [1.0],
68+
# [1.0],
69+
# Val(:c),
70+
# )
71+
end
72+
end
73+
74+
test_autodiff_backend_error()
1375
#=
14-
ADNLPModels.EmptyADbackend(args...; kwargs...) = ADNLPModels.EmptyADbackend()
76+
# ADNLPModels.EmptyADbackend(args...; kwargs...) = ADNLPModels.EmptyADbackend()
1577
1678
names = OptimizationProblems.meta[!, :name]
1779
list_excluded_enzyme = [

0 commit comments

Comments
 (0)