Skip to content

Commit cb0668a

Browse files
Update runtests.jl
Tests match the latest ODE code with callbacks and maxiters.
1 parent a2d406e commit cb0668a

File tree

1 file changed

+26
-63
lines changed

1 file changed

+26
-63
lines changed
Lines changed: 26 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,84 +1,47 @@
11
using Test
2-
using Optimization
3-
using Optimization.SciMLBase
2+
using Optimization, Optimization.SciMLBase
43
using Optimization.ADTypes
54
using OptimizationODE
5+
using LinearAlgebra
66

7-
quad(u, p) = u[1]^2 + p[1]*u[2]^2
8-
function quad_grad!(g, u, p, data)
9-
g[1] = 2u[1]
10-
g[2] = 2p[1]*u[2]
11-
return g
12-
end
13-
14-
rosen(u, p) = (p[1] - u[1])^2 + p[2]*(u[2] - u[1]^2)^2
15-
function rosen_grad!(g, u, p, data)
16-
g[1] = -2*(p[1] - u[1]) - 4*p[2]*u[1]*(u[2] - u[1]^2)
17-
g[2] = 2*p[2]*(u[2] - u[1]^2)
18-
return g
19-
end
20-
21-
make_zeros(u) = zero.(u)
22-
make_ones(u) = fill(one(eltype(u)), length(u))
23-
24-
@testset "OptimizationODE: Steady‐State Solvers" begin
25-
26-
u0q = [2.0, -3.0]
27-
pq = [5.0]
28-
fq = OptimizationFunction(quad, SciMLBase.NoAD(); grad = quad_grad!)
29-
probQ_noad = OptimizationProblem(fq, u0q, pq)
30-
31-
u0r = [-1.2, 1.0]
32-
pr = [1.0, 100.0]
7+
@testset "OptimizationODE Tests" begin
338

34-
ADmodes = (
35-
(SciMLBase.NoAD(), "NoAD", nothing),
36-
(ADTypes.AutoForwardDiff(), "ForwardDiff", nothing)
37-
)
38-
39-
40-
@testset "ODEGradientDescent on Quadratic" begin
41-
sol = solve(probQ_noad, ODEGradientDescent; dt = 0.1, maxiters = 2000)
42-
@test isapprox(sol.u, make_zeros(sol.u); atol = 1e-2)
43-
@test sol.retcode == ReturnCode.Success
9+
function f(x, p, args...)
10+
return sum(abs2, x)
4411
end
4512

46-
@testset "ODEGradientDescent on Rosenbrock" begin
47-
for (ad, name, _) in ADmodes
48-
fr = OptimizationFunction(rosen, ad; grad = rosen_grad!)
49-
probR = OptimizationProblem(fr, u0r, pr)
50-
sol = solve(probR, ODEGradientDescent; dt = 0.001, maxiters = 3000)
51-
@test isapprox(sol.u, make_ones(sol.u); atol = 0.01)
52-
@test sol.retcode == ReturnCode.Success
53-
end
13+
function g!(g, x, p, args...)
14+
@. g = 2 * x
5415
end
5516

17+
x0 = [2.0, -3.0]
18+
p = [5.0]
5619

57-
@testset "RKChebyshevDescent on Quadratic" begin
58-
sol = solve(probQ_noad, RKChebyshevDescent; dt = 0.1, maxiters = 1000)
59-
@test isapprox(sol.u, make_zeros(sol.u); atol = 1e-2)
60-
@test sol.retcode == ReturnCode.Success
61-
end
20+
f_autodiff = OptimizationFunction(f, ADTypes.AutoForwardDiff())
21+
prob_auto = OptimizationProblem(f_autodiff, x0, p)
6222

63-
@testset "RKAccelerated on Quadratic" begin
64-
sol = solve(probQ_noad, RKAccelerated; dt = 0.1, maxiters = 1000)
65-
@test isapprox(sol.u, make_zeros(sol.u); atol = 1e-2)
23+
for opt in (ODEGradientDescent(), RKChebyshevDescent(), RKAccelerated(), PRKChebyshevDescent())
24+
sol = solve(prob_auto, opt; η=0.01,dt=0.01, tmax=1000, maxiters=50_000)
25+
@test sol.u [0.0, 0.0] atol=1e-2
26+
@test sol.objective 0.0 atol=1e-2
6627
@test sol.retcode == ReturnCode.Success
6728
end
6829

30+
f_manual = OptimizationFunction(f, SciMLBase.NoAD(); grad=g!)
31+
prob_manual = OptimizationProblem(f_manual, x0)
6932

70-
@testset "PRKChebyshevDescent on Quadratic" begin
71-
sol = solve(probQ_noad, PRKChebyshevDescent; dt = 0.1, maxiters = 1000)
72-
@test isapprox(sol.u, make_zeros(sol.u); atol = 1e-2)
33+
for opt in (ODEGradientDescent(), RKChebyshevDescent(), RKAccelerated(), PRKChebyshevDescent())
34+
sol = solve(prob_manual, opt; η=0.01,dt=0.01, tmax=1000, maxiters=50_000)
35+
@test sol.u [0.0, 0.0] atol=1e-2
36+
@test sol.objective 0.0 atol=1e-2
7337
@test sol.retcode == ReturnCode.Success
7438
end
7539

76-
@testset "PRKChebyshevDescent on Rosenbrock (NoAD)" begin
77-
fr = OptimizationFunction(rosen, SciMLBase.NoAD(); grad = rosen_grad!)
78-
probR = OptimizationProblem(fr, u0r, pr)
79-
sol = solve(probR, PRKChebyshevDescent; dt = 0.001, maxiters = 2000)
80-
@test isapprox(sol.u, make_ones(sol.u); atol = 1e-1)
81-
@test sol.retcode == ReturnCode.Success
40+
f_fail = OptimizationFunction(f, SciMLBase.NoAD())
41+
prob_fail = OptimizationProblem(f_fail, x0)
42+
43+
for opt in (ODEGradientDescent(), RKChebyshevDescent(), RKAccelerated(), PRKChebyshevDescent())
44+
@test_throws ErrorException solve(prob_fail, opt; η=0.01,dt=0.001, tmax=10_000.0, maxiters=20_000)
8245
end
8346

8447
end

0 commit comments

Comments
 (0)