Skip to content

Commit 4c78a07

Browse files
Updates
1 parent cf7a50e commit 4c78a07

File tree

1 file changed

+48
-46
lines changed

1 file changed

+48
-46
lines changed

lib/OptimizationLBFGS/test/runtests.jl

Lines changed: 48 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -7,49 +7,51 @@ using MLUtils
77
using LBFGSB
88
using Test
99

10-
x0 = zeros(2)
11-
rosenbrock(x, p = nothing) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2
12-
l1 = rosenbrock(x0)
13-
14-
optf = OptimizationFunction(rosenbrock, OptimizationBase.AutoForwardDiff())
15-
prob = OptimizationProblem(optf, x0)
16-
@time res = solve(prob, OptimizationLBFGS.LBFGS(), maxiters = 100)
17-
@test res.retcode == ReturnCode.Success
18-
19-
prob = OptimizationProblem(optf, x0, lb = [-1.0, -1.0], ub = [1.0, 1.0])
20-
@time res = solve(prob, OptimizationLBFGS.LBFGS(), maxiters = 100)
21-
@test res.retcode == ReturnCode.Success
22-
23-
function con2_c(res, x, p)
24-
res .= [x[1]^2 + x[2]^2, (x[2] * sin(x[1]) + x[1]) - 5]
25-
end
26-
27-
optf = OptimizationFunction(rosenbrock, OptimizationBase.AutoZygote(), cons = con2_c)
28-
prob = OptimizationProblem(optf, x0, lcons = [1.0, -Inf],
29-
ucons = [1.0, 0.0], lb = [-1.0, -1.0],
30-
ub = [1.0, 1.0])
31-
@time res = solve(prob, OptimizationLBFGS.LBFGS(), maxiters = 100)
32-
@test res.retcode == SciMLBase.ReturnCode.Success
33-
34-
x0 = (-pi):0.001:pi
35-
y0 = sin.(x0)
36-
data = MLUtils.DataLoader((x0, y0), batchsize = 126)
37-
function loss(coeffs, data)
38-
ypred = [evalpoly(data[1][i], coeffs) for i in eachindex(data[1])]
39-
return sum(abs2, ypred .- data[2])
40-
end
41-
42-
function cons1(res, coeffs, p = nothing)
43-
res[1] = coeffs[1] * coeffs[5] - 1
44-
return nothing
45-
end
46-
47-
optf = OptimizationFunction(loss, AutoSparseForwardDiff(), cons = cons1)
48-
callback = (st, l) -> (@show l; return false)
49-
50-
initpars = rand(5)
51-
l0 = optf(initpars, (x0, y0))
52-
prob = OptimizationProblem(optf, initpars, (x0, y0), lcons = [-Inf], ucons = [0.5],
53-
lb = [-10.0, -10.0, -10.0, -10.0, -10.0], ub = [10.0, 10.0, 10.0, 10.0, 10.0])
54-
opt1 = solve(prob, OptimizationLBFGS.LBFGS(), maxiters = 1000, callback = callback)
55-
@test opt1.objective < l0
10+
@testset "OptimizationLBFGS.jl" begin
11+
x0 = zeros(2)
12+
rosenbrock(x, p = nothing) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2
13+
l1 = rosenbrock(x0)
14+
15+
optf = OptimizationFunction(rosenbrock, OptimizationBase.AutoForwardDiff())
16+
prob = OptimizationProblem(optf, x0)
17+
@time res = solve(prob, OptimizationLBFGS.LBFGS(), maxiters = 100)
18+
@test res.retcode == ReturnCode.Success
19+
20+
prob = OptimizationProblem(optf, x0, lb = [-1.0, -1.0], ub = [1.0, 1.0])
21+
@time res = solve(prob, OptimizationLBFGS.LBFGS(), maxiters = 100)
22+
@test res.retcode == ReturnCode.Success
23+
24+
function con2_c(res, x, p)
25+
res .= [x[1]^2 + x[2]^2, (x[2] * sin(x[1]) + x[1]) - 5]
26+
end
27+
28+
optf = OptimizationFunction(rosenbrock, OptimizationBase.AutoZygote(), cons = con2_c)
29+
prob = OptimizationProblem(optf, x0, lcons = [1.0, -Inf],
30+
ucons = [1.0, 0.0], lb = [-1.0, -1.0],
31+
ub = [1.0, 1.0])
32+
@time res = solve(prob, OptimizationLBFGS.LBFGS(), maxiters = 100)
33+
@test res.retcode == SciMLBase.ReturnCode.Success
34+
35+
x0 = (-pi):0.001:pi
36+
y0 = sin.(x0)
37+
data = MLUtils.DataLoader((x0, y0), batchsize = 126)
38+
function loss(coeffs, data)
39+
ypred = [evalpoly(data[1][i], coeffs) for i in eachindex(data[1])]
40+
return sum(abs2, ypred .- data[2])
41+
end
42+
43+
function cons1(res, coeffs, p = nothing)
44+
res[1] = coeffs[1] * coeffs[5] - 1
45+
return nothing
46+
end
47+
48+
optf = OptimizationFunction(loss, AutoSparseForwardDiff(), cons = cons1)
49+
callback = (st, l) -> (@show l; return false)
50+
51+
initpars = rand(5)
52+
l0 = optf(initpars, (x0, y0))
53+
prob = OptimizationProblem(optf, initpars, (x0, y0), lcons = [-Inf], ucons = [0.5],
54+
lb = [-10.0, -10.0, -10.0, -10.0, -10.0], ub = [10.0, 10.0, 10.0, 10.0, 10.0])
55+
opt1 = solve(prob, OptimizationLBFGS.LBFGS(), maxiters = 1000, callback = callback)
56+
@test opt1.objective < l0
57+
end

0 commit comments

Comments
 (0)