Skip to content

Commit a81ae86

Browse files
Create runtests.jl
Added basic tests for the gradient descent code in OptimizationODE.jl
1 parent 1d322a1 commit a81ae86

File tree

1 file changed

+45
-0
lines changed

1 file changed

+45
-0
lines changed
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
using Test
2+
using Optimization
3+
using Optimization.SciMLBase
4+
using OptimizationODE
5+
6+
@testset "ODEGradientDescent Tests" begin
7+
8+
# Define the Rosenbrock objective and its gradient
9+
function rosen(u, p)
10+
return (p[1] - u[1])^2 + p[2] * (u[2] - u[1]^2)^2
11+
end
12+
13+
function rosen_grad!(g, u, p, data)
14+
g[1] = -2 * (p[1] - u[1]) - 4 * p[2] * u[1] * (u[2] - u[1]^2)
15+
g[2] = 2 * p[2] * (u[2] - u[1]^2)
16+
return g
17+
end
18+
19+
# Set up the problem
20+
u0 = [0.0, 0.0]
21+
p = [1.0, 100.0]
22+
23+
# Wrap into an OptimizationFunction without AD, providing our gradient
24+
f = OptimizationFunction(
25+
rosen,
26+
Optimization.SciMLBase.NoAD();
27+
grad = rosen_grad!
28+
)
29+
30+
prob = OptimizationProblem(f, u0, p)
31+
32+
# Solve with ODEGradientDescent
33+
sol = solve(
34+
prob,
35+
ODEGradientDescent();
36+
η = 0.001,
37+
tmax = 1_000.0,
38+
dt = 0.01
39+
)
40+
41+
# Assertions
42+
@test isapprox(sol.u[1], 1.0; atol = 1e-2)
43+
@test isapprox(sol.u[2], 1.0; atol = 1e-2)
44+
45+
end

0 commit comments

Comments
 (0)