Skip to content

Commit 2a78d49

Browse files
Add constraints support for NLopt
1 parent 977630e commit 2a78d49

File tree

3 files changed

+106
-12
lines changed

3 files changed

+106
-12
lines changed

lib/OptimizationNLopt/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ version = "0.2.2"
66
[deps]
77
NLopt = "76087f3c-5699-56af-9a33-bf431cd00edd"
88
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
9+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
910
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1011

1112
[compat]

lib/OptimizationNLopt/src/OptimizationNLopt.jl

Lines changed: 61 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,24 @@ function SciMLBase.requiresconsjac(opt::NLopt.Algorithm) #https://github.com/Jul
3636
end
3737
end
3838

39+
function SciMLBase.allowsconstraints(opt::NLopt.Algorithm)
40+
str_opt = string(opt)
41+
if occursin("AUGLAG", str_opt) || occursin("CCSA", str_opt) || occursin("MMA", str_opt) || occursin("COBYLA", str_opt) || occursin("ISRES", str_opt) || occursin("AGS", str_opt) || occursin("ORIG_DIRECT", str_opt) || occursin("SLSQP", str_opt)
42+
return true
43+
else
44+
return false
45+
end
46+
end
47+
48+
function SciMLBase.__init(prob::SciMLBase.OptimizationProblem, opt::NLopt.Algorithm,
49+
data = Optimization.DEFAULT_DATA; cons_tol = 1e-6,
50+
callback = (args...) -> (false),
51+
progress = false, kwargs...)
52+
return OptimizationCache(prob, opt, data; cons_tol, callback, progress,
53+
kwargs...)
54+
end
55+
56+
3957
function __map_optimizer_args!(cache::OptimizationCache, opt::NLopt.Opt;
4058
callback = nothing,
4159
maxiters::Union{Number, Nothing} = nothing,
@@ -76,7 +94,9 @@ function __map_optimizer_args!(cache::OptimizationCache, opt::NLopt.Opt;
7694

7795
# add optimiser options from kwargs
7896
for j in kwargs
79-
eval(Meta.parse("NLopt." * string(j.first) * "!"))(opt, j.second)
97+
if j.first != :cons_tol
98+
eval(Meta.parse("NLopt." * string(j.first) * "!"))(opt, j.second)
99+
end
80100
end
81101

82102
if cache.ub !== nothing
@@ -170,14 +190,18 @@ function SciMLBase.__solve(cache::OptimizationCache{
170190
return x[1]
171191
end
172192

173-
fg! = function (θ, G)
174-
if length(G) > 0
175-
cache.f.grad(G, θ)
193+
if !hasfield(typeof(cache.f), :fg) || cache.f.fg === nothing
194+
fg! = function (θ, G)
195+
if length(G) > 0
196+
cache.f.grad(G, θ)
197+
end
198+
return _loss(θ)
176199
end
177-
178-
return _loss(θ)
200+
else
201+
fg! = cache.f.fg
179202
end
180203

204+
181205
opt_setup = if isa(cache.opt, NLopt.Opt)
182206
if ndims(cache.opt) != length(cache.u0)
183207
error("Passed NLopt.Opt optimization dimension does not match OptimizationProblem dimension.")
@@ -193,6 +217,37 @@ function SciMLBase.__solve(cache::OptimizationCache{
193217
NLopt.min_objective!(opt_setup, fg!)
194218
end
195219

220+
if cache.f.cons !== nothing
221+
eqinds = map((y) -> y[1]==y[2], zip(cache.lcons, cache.ucons))
222+
ineqinds = map((y) -> y[1]!=y[2], zip(cache.lcons, cache.ucons))
223+
if sum(ineqinds) > 0
224+
ineqcons = function (res, θ, J)
225+
cons_cache = zeros(eltype(res), sum(eqinds)+sum(ineqinds))
226+
cache.f.cons(cons_cache, θ)
227+
res .= @view(cons_cache[ineqinds])
228+
if length(J) > 0
229+
Jcache = zeros(eltype(J), sum(ineqinds)+sum(eqinds), length(θ))
230+
cache.f.cons_j(Jcache, θ)
231+
J .= @view(Jcache[ineqinds, :])'
232+
end
233+
end
234+
NLopt.inequality_constraint!(opt_setup, ineqcons, [cache.solver_args.cons_tol for i in 1:sum(ineqinds)])
235+
end
236+
if sum(eqinds) > 0
237+
eqcons = function (res, θ, J)
238+
cons_cache = zeros(eltype(res), sum(eqinds)+sum(ineqinds))
239+
cache.f.cons(cons_cache, θ)
240+
res .= @view(cons_cache[eqinds])
241+
if length(J) > 0
242+
Jcache = zeros(eltype(res), sum(eqinds)+sum(ineqinds), length(θ))
243+
cache.f.cons_j(Jcache, θ)
244+
J .= @view(Jcache[eqinds, :])'
245+
end
246+
end
247+
NLopt.equality_constraint!(opt_setup, eqcons, [cache.solver_args.cons_tol for i in 1:sum(eqinds)])
248+
end
249+
end
250+
196251
maxiters = Optimization._check_and_convert_maxiters(cache.solver_args.maxiters)
197252
maxtime = Optimization._check_and_convert_maxtime(cache.solver_args.maxtime)
198253

lib/OptimizationNLopt/test/runtests.jl

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using OptimizationNLopt, Optimization, Zygote
2-
using Test
2+
using Test, Random
33

44
@testset "OptimizationNLopt.jl" begin
55
rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
@@ -16,7 +16,7 @@ using Test
1616
optprob = OptimizationFunction(rosenbrock, Optimization.AutoZygote())
1717
prob = OptimizationProblem(optprob, x0, _p)
1818

19-
sol = solve(prob, NLopt.Opt(:LN_BOBYQA, 2))
19+
sol = solve(prob, NLopt.Opt(:LD_LBFGS, 2))
2020
@test sol.retcode == ReturnCode.Success
2121
@test 10 * sol.objective < l1
2222

@@ -26,10 +26,6 @@ using Test
2626
@test sol.retcode == ReturnCode.Success
2727
@test 10 * sol.objective < l1
2828

29-
sol = solve(prob, NLopt.Opt(:LD_LBFGS, 2))
30-
@test sol.retcode == ReturnCode.Success
31-
@test 10 * sol.objective < l1
32-
3329
sol = solve(prob, NLopt.Opt(:G_MLSL_LDS, 2), local_method = NLopt.Opt(:LD_LBFGS, 2),
3430
maxiters = 10000)
3531
@test sol.retcode == ReturnCode.MaxIters
@@ -82,4 +78,46 @@ using Test
8278
#nlopt gives the last best not the one where callback stops
8379
@test sol.objective < 0.8
8480
end
81+
82+
@testset "constrained" begin
83+
cons = (res, x, p) -> res .= [x[1]^2 + x[2]^2 - 1.0]
84+
x0 = zeros(2)
85+
optprob = OptimizationFunction(rosenbrock, Optimization.AutoZygote();
86+
cons = cons)
87+
prob = OptimizationProblem(optprob, x0, _p, lcons = [0.0], ucons = [0.0])
88+
sol = solve(prob, NLopt.LN_COBYLA())
89+
@test sol.retcode == ReturnCode.Success
90+
@test 10 * sol.objective < l1
91+
92+
Random.seed!(1)
93+
prob = OptimizationProblem(optprob, rand(2), _p,
94+
lcons = [0.0], ucons = [0.0])
95+
96+
sol = solve(prob, NLopt.LD_SLSQP())
97+
@test sol.retcode == ReturnCode.Success
98+
@test 10 * sol.objective < l1
99+
100+
Random.seed!(1)
101+
prob = OptimizationProblem(optprob, rand(2), _p,
102+
lcons = [0.0], ucons = [0.0])
103+
sol = solve(prob, NLopt.AUGLAG(), local_method = NLopt.LD_LBFGS())
104+
@test sol.retcode == ReturnCode.Success
105+
@test 10 * sol.objective < l1
106+
107+
function con2_c(res, x, p)
108+
res .= [x[1]^2 + x[2]^2 - 1.0, x[2] * sin(x[1]) - x[1] - 2.0]
109+
end
110+
111+
optprob = OptimizationFunction(rosenbrock, Optimization.AutoForwardDiff();cons = con2_c)
112+
Random.seed!(1)
113+
prob = OptimizationProblem(optprob, rand(2), _p, lcons = [0.0, -Inf], ucons = [0.0, 0.0])
114+
sol = solve(prob, NLopt.LD_AUGLAG(), local_method = NLopt.LD_LBFGS())
115+
@test sol.retcode == ReturnCode.Success
116+
@test 10 * sol.objective < l1
117+
118+
prob = OptimizationProblem(optprob, rand(2), _p, lcons = [-Inf, -Inf], ucons = [0.0, 0.0], lb = [-1.0, -1.0], ub = [1.0, 1.0])
119+
sol = solve(prob, NLopt.GN_ISRES(), maxiters = 1000)
120+
@test sol.retcode == ReturnCode.MaxIters
121+
@test 10 * sol.objective < l1
122+
end
85123
end

0 commit comments

Comments
 (0)