Skip to content

Commit 11f85c5

Browse files
Reuse constraint evaluations
1 parent 22c7188 commit 11f85c5

File tree

1 file changed

+33
-12
lines changed

1 file changed

+33
-12
lines changed

lib/OptimizationNLopt/src/OptimizationNLopt.jl

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -232,29 +232,50 @@ function SciMLBase.__solve(cache::OptimizationCache{
232232
if cache.f.cons !== nothing
233233
eqinds = map((y) -> y[1] == y[2], zip(cache.lcons, cache.ucons))
234234
ineqinds = map((y) -> y[1] != y[2], zip(cache.lcons, cache.ucons))
235+
cons_cache = zeros(eltype(res), sum(eqinds) + sum(ineqinds))
236+
thetacache = rand(size(cache.u0))
237+
Jthetacache = rand(size(cache.u0))
238+
Jcache = zeros(eltype(J), sum(ineqinds) + sum(eqinds), length(θ))
239+
evalcons = function (θ, ineqoreq)
240+
if thetacache != θ
241+
cache.f.cons(cons_cache, θ)
242+
thetacache = copy(θ)
243+
end
244+
if ineqoreq == :eq
245+
return @view(cons_cache[eqinds])
246+
else
247+
return @view(cons_cache[ineqinds])
248+
end
249+
end
250+
251+
evalconj = function (θ, ineqoreq)
252+
if Jthetacache != θ
253+
cache.f.cons_j(Jcache, θ)
254+
Jthetacache = copy(θ)
255+
end
256+
257+
if ineqoreq == :eq
258+
return @view(Jcache[eqinds, :])'
259+
else
260+
return @view(Jcache[ineqinds, :])'
261+
end
262+
end
263+
235264
if sum(ineqinds) > 0
236265
ineqcons = function (res, θ, J)
237-
cons_cache = zeros(eltype(res), sum(eqinds) + sum(ineqinds))
238-
cache.f.cons(cons_cache, θ)
239-
res .= @view(cons_cache[ineqinds])
266+
res .= copy(evalcons(θ, :ineq))
240267
if length(J) > 0
241-
Jcache = zeros(eltype(J), sum(ineqinds) + sum(eqinds), length(θ))
242-
cache.f.cons_j(Jcache, θ)
243-
J .= @view(Jcache[ineqinds, :])'
268+
J .= copy(evalconj(θ, :ineq))
244269
end
245270
end
246271
NLopt.inequality_constraint!(
247272
opt_setup, ineqcons, [cache.solver_args.cons_tol for i in 1:sum(ineqinds)])
248273
end
249274
if sum(eqinds) > 0
250275
eqcons = function (res, θ, J)
251-
cons_cache = zeros(eltype(res), sum(eqinds) + sum(ineqinds))
252-
cache.f.cons(cons_cache, θ)
253-
res .= @view(cons_cache[eqinds])
276+
res .= copy(evalcons(θ, :eq))
254277
if length(J) > 0
255-
Jcache = zeros(eltype(res), sum(eqinds) + sum(ineqinds), length(θ))
256-
cache.f.cons_j(Jcache, θ)
257-
J .= @view(Jcache[eqinds, :])'
278+
J .= copy(evalconj(θ, :eq))
258279
end
259280
end
260281
NLopt.equality_constraint!(

0 commit comments

Comments
 (0)