Skip to content

Commit 8212ea3

Browse files
Merge pull request #832 from SciML/nloptreusecons
NLopt: Reuse constraint evaluations
2 parents 38e8219 + 4fd3cae commit 8212ea3

File tree

6 files changed

+41
-17
lines changed

6 files changed

+41
-17
lines changed

NEWS.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# v4 Breaking changes
22

3-
The main change in this breaking release has been the way mini-batching is handled. The data argument in the solve call and the implicit iteration of that in the callback has been removed,
4-
the stochastic solvers (Optimisers.jl and Sophia) now handle it explicitly. You would now pass in a DataLoader to OptimziationProblem as the second argument to the objective etc (p) if you
3+
The main change in this breaking release has been the way mini-batching is handled. The data argument in the solve call and the implicit iteration of that in the callback has been removed,
4+
the stochastic solvers (Optimisers.jl and Sophia) now handle it explicitly. You would now pass in a DataLoader to OptimziationProblem as the second argument to the objective etc (p) if you
55
want to do minibatching, else for full batch just pass in the full data.

lib/OptimizationNLopt/src/OptimizationNLopt.jl

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -232,29 +232,50 @@ function SciMLBase.__solve(cache::OptimizationCache{
232232
if cache.f.cons !== nothing
233233
eqinds = map((y) -> y[1] == y[2], zip(cache.lcons, cache.ucons))
234234
ineqinds = map((y) -> y[1] != y[2], zip(cache.lcons, cache.ucons))
235+
cons_cache = zeros(eltype(cache.u0), sum(eqinds) + sum(ineqinds))
236+
thetacache = rand(size(cache.u0))
237+
Jthetacache = rand(size(cache.u0))
238+
Jcache = zeros(eltype(cache.u0), sum(ineqinds) + sum(eqinds), length(cache.u0))
239+
evalcons = function (θ, ineqoreq)
240+
if thetacache != θ
241+
cache.f.cons(cons_cache, θ)
242+
thetacache = copy(θ)
243+
end
244+
if ineqoreq == :eq
245+
return @view(cons_cache[eqinds])
246+
else
247+
return @view(cons_cache[ineqinds])
248+
end
249+
end
250+
251+
evalconj = function (θ, ineqoreq)
252+
if Jthetacache != θ
253+
cache.f.cons_j(Jcache, θ)
254+
Jthetacache = copy(θ)
255+
end
256+
257+
if ineqoreq == :eq
258+
return @view(Jcache[eqinds, :])'
259+
else
260+
return @view(Jcache[ineqinds, :])'
261+
end
262+
end
263+
235264
if sum(ineqinds) > 0
236265
ineqcons = function (res, θ, J)
237-
cons_cache = zeros(eltype(res), sum(eqinds) + sum(ineqinds))
238-
cache.f.cons(cons_cache, θ)
239-
res .= @view(cons_cache[ineqinds])
266+
res .= copy(evalcons(θ, :ineq))
240267
if length(J) > 0
241-
Jcache = zeros(eltype(J), sum(ineqinds) + sum(eqinds), length(θ))
242-
cache.f.cons_j(Jcache, θ)
243-
J .= @view(Jcache[ineqinds, :])'
268+
J .= copy(evalconj(θ, :ineq))
244269
end
245270
end
246271
NLopt.inequality_constraint!(
247272
opt_setup, ineqcons, [cache.solver_args.cons_tol for i in 1:sum(ineqinds)])
248273
end
249274
if sum(eqinds) > 0
250275
eqcons = function (res, θ, J)
251-
cons_cache = zeros(eltype(res), sum(eqinds) + sum(ineqinds))
252-
cache.f.cons(cons_cache, θ)
253-
res .= @view(cons_cache[eqinds])
276+
res .= copy(evalcons(θ, :eq))
254277
if length(J) > 0
255-
Jcache = zeros(eltype(res), sum(eqinds) + sum(ineqinds), length(θ))
256-
cache.f.cons_j(Jcache, θ)
257-
J .= @view(Jcache[eqinds, :])'
278+
J .= copy(evalconj(θ, :eq))
258279
end
259280
end
260281
NLopt.equality_constraint!(

lib/OptimizationNLopt/test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ using Test, Random
117117
# @test sol.retcode == ReturnCode.Success
118118
@test 10 * sol.objective < l1
119119

120+
Random.seed!(1)
120121
prob = OptimizationProblem(optprob, [0.5, 0.5], _p, lcons = [-Inf, -Inf],
121122
ucons = [0.0, 0.0], lb = [-1.0, -1.0], ub = [1.0, 1.0])
122123
sol = solve(prob, NLopt.GN_ISRES(), maxiters = 1000)

lib/OptimizationOptimisers/src/OptimizationOptimisers.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,8 @@ function SciMLBase.__solve(cache::OptimizationCache{
9393
cache.f.grad(G, θ)
9494
x = cache.f(θ)
9595
end
96-
opt_state = Optimization.OptimizationState(iter = i + (epoch-1)*length(data),
96+
opt_state = Optimization.OptimizationState(
97+
iter = i + (epoch - 1) * length(data),
9798
u = θ,
9899
objective = x[1],
99100
grad = G,

lib/OptimizationOptimisers/test/runtests.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,8 @@ using Zygote
7070
end
7171

7272
@testset "Minibatching" begin
73-
using Optimization, OptimizationOptimisers, Lux, Zygote, MLUtils, Random, ComponentArrays
73+
using Optimization, OptimizationOptimisers, Lux, Zygote, MLUtils, Random,
74+
ComponentArrays
7475

7576
x = rand(10000)
7677
y = sin.(x)

test/minibatch.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ optprob = OptimizationProblem(optfun, pp, train_loader)
6060

6161
res1 = Optimization.solve(optprob,
6262
Optimization.Sophia(), callback = callback,
63-
maxiters = 1000)
63+
maxiters = 2000)
6464
@test 10res1.objective < l1
6565

6666
optfun = OptimizationFunction(loss_adjoint,

0 commit comments

Comments
 (0)