Skip to content

Commit ae65c66

Browse files
committed
Fix the tests with Zygote
1 parent 9c10a5c commit ae65c66

File tree

3 files changed

+27
-26
lines changed

3 files changed

+27
-26
lines changed

test/runtests.jl

Lines changed: 1 addition & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -51,32 +51,7 @@ ReverseDiffAD(nvar, f) = ADNLPModels.ADModelBackend(
5151
hessian_backend = ADNLPModels.ReverseDiffADHessian,
5252
)
5353

54-
function test_getter_setter(nlp)
55-
@test get_adbackend(nlp) == nlp.adbackend
56-
if typeof(nlp) <: ADNLPModel
57-
set_adbackend!(nlp, ReverseDiffAD(nlp.meta.nvar, nlp.f))
58-
elseif typeof(nlp) <: ADNLSModel
59-
function F(x; nequ = nlp.nls_meta.nequ)
60-
Fx = similar(x, nequ)
61-
nlp.F!(Fx, x)
62-
return Fx
63-
end
64-
set_adbackend!(nlp, ReverseDiffAD(nlp.meta.nvar, x -> sum(F(x) .^ 2)))
65-
end
66-
@test typeof(get_adbackend(nlp).gradient_backend) <: ADNLPModels.ReverseDiffADGradient
67-
@test typeof(get_adbackend(nlp).hprod_backend) <: ADNLPModels.ReverseDiffADHvprod
68-
@test typeof(get_adbackend(nlp).hessian_backend) <: ADNLPModels.ReverseDiffADHessian
69-
set_adbackend!(
70-
nlp,
71-
gradient_backend = ADNLPModels.ForwardDiffADGradient,
72-
jtprod_backend = ADNLPModels.GenericForwardDiffADJtprod(),
73-
)
74-
@test typeof(get_adbackend(nlp).gradient_backend) <: ADNLPModels.ForwardDiffADGradient
75-
@test typeof(get_adbackend(nlp).hprod_backend) <: ADNLPModels.ReverseDiffADHvprod
76-
@test typeof(get_adbackend(nlp).jtprod_backend) <: ADNLPModels.GenericForwardDiffADJtprod
77-
@test typeof(get_adbackend(nlp).hessian_backend) <: ADNLPModels.ReverseDiffADHessian
78-
end
79-
54+
include("utils.jl")
8055
include("nlp/basic.jl")
8156
include("nls/basic.jl")
8257
include("nlp/nlpmodelstest.jl")

test/utils.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
function test_getter_setter(nlp)
2+
@test get_adbackend(nlp) == nlp.adbackend
3+
if typeof(nlp) <: ADNLPModel
4+
set_adbackend!(nlp, ReverseDiffAD(nlp.meta.nvar, nlp.f))
5+
elseif typeof(nlp) <: ADNLSModel
6+
function F(x; nequ = nlp.nls_meta.nequ)
7+
Fx = similar(x, nequ)
8+
nlp.F!(Fx, x)
9+
return Fx
10+
end
11+
set_adbackend!(nlp, ReverseDiffAD(nlp.meta.nvar, x -> sum(F(x) .^ 2)))
12+
end
13+
@test typeof(get_adbackend(nlp).gradient_backend) <: ADNLPModels.ReverseDiffADGradient
14+
@test typeof(get_adbackend(nlp).hprod_backend) <: ADNLPModels.ReverseDiffADHvprod
15+
@test typeof(get_adbackend(nlp).hessian_backend) <: ADNLPModels.ReverseDiffADHessian
16+
set_adbackend!(
17+
nlp,
18+
gradient_backend = ADNLPModels.ForwardDiffADGradient,
19+
jtprod_backend = ADNLPModels.GenericForwardDiffADJtprod(),
20+
)
21+
@test typeof(get_adbackend(nlp).gradient_backend) <: ADNLPModels.ForwardDiffADGradient
22+
@test typeof(get_adbackend(nlp).hprod_backend) <: ADNLPModels.ReverseDiffADHvprod
23+
@test typeof(get_adbackend(nlp).jtprod_backend) <: ADNLPModels.GenericForwardDiffADJtprod
24+
@test typeof(get_adbackend(nlp).hessian_backend) <: ADNLPModels.ReverseDiffADHessian
25+
end

test/zygote.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ ADNLPModels.predefined_backend = Dict(
6868
# Automatically loads the code for Zygote with Requires
6969
import Zygote
7070

71+
include("utils.jl")
7172
include("nlp/basic.jl")
7273
include("nls/basic.jl")
7374
include("nlp/nlpmodelstest.jl")

0 commit comments

Comments
 (0)