Skip to content

Commit 7c778ae

Browse files
Simplify calls to Optim.jl
Much of the complexity in the issues with the Optim.jl wrapper is simply because it doesn't treat Optim well. It makes things always complete, instead of simplifying the call. This simplifies the call, so the less you use the less machinery is required. In particular: * TwiceDifferentiable is only made and used if the optimizer needs to use Hessians. This fixes #859, fixes #893 * Only uses constraints and bounds when the user sets them. Fixes #863 and fixes #558
1 parent b4b1a07 commit 7c778ae

File tree

1 file changed

+39
-16
lines changed

1 file changed

+39
-16
lines changed

lib/OptimizationOptimJL/src/OptimizationOptimJL.jl

Lines changed: 39 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -393,14 +393,22 @@ function SciMLBase.__solve(cache::OptimizationCache{
393393
end
394394
end
395395
u0_type = eltype(cache.u0)
396-
optim_f = Optim.TwiceDifferentiable(_loss, gg, fg!, hh, cache.u0,
397-
real(zero(u0_type)),
398-
Optim.NLSolversBase.alloc_DF(cache.u0,
399-
real(zero(u0_type))),
400-
isnothing(cache.f.hess_prototype) ?
401-
Optim.NLSolversBase.alloc_H(cache.u0,
402-
real(zero(u0_type))) :
403-
convert.(u0_type, cache.f.hess_prototype))
396+
397+
optim_f = if SciMLBase.requireshessian(cache.opt)
398+
Optim.TwiceDifferentiable(_loss, gg, fg!, hh, cache.u0,
399+
real(zero(u0_type)),
400+
Optim.NLSolversBase.alloc_DF(cache.u0,
401+
real(zero(u0_type))),
402+
isnothing(cache.f.hess_prototype) ?
403+
Optim.NLSolversBase.alloc_H(cache.u0,
404+
real(zero(u0_type))) :
405+
convert.(u0_type, cache.f.hess_prototype))
406+
else
407+
Optim.OnceDifferentiable(_loss, gg, fg!, cache.u0,
408+
real(zero(u0_type)),
409+
Optim.NLSolversBase.alloc_DF(cache.u0,
410+
real(zero(u0_type))))
411+
end
404412

405413
cons_hl! = function (h, θ, λ)
406414
res = [similar(h) for i in 1:length(λ)]
@@ -412,15 +420,26 @@ function SciMLBase.__solve(cache::OptimizationCache{
412420

413421
lb = cache.lb === nothing ? [] : cache.lb
414422
ub = cache.ub === nothing ? [] : cache.ub
415-
if cache.f.cons !== nothing
416-
optim_fc = Optim.TwiceDifferentiableConstraints(cache.f.cons, cache.f.cons_j,
417-
cons_hl!,
418-
lb, ub,
419-
cache.lcons, cache.ucons)
423+
424+
optim_fc = if SciMLBase.requireshessian(opt)
425+
if cache.f.cons !== nothing
426+
Optim.TwiceDifferentiableConstraints(cache.f.cons, cache.f.cons_j,
427+
cons_hl!,
428+
lb, ub,
429+
cache.lcons, cache.ucons)
430+
else
431+
Optim.TwiceDifferentiableConstraints(lb, ub)
432+
end
420433
else
421-
optim_fc = Optim.TwiceDifferentiableConstraints(lb, ub)
434+
if cache.f.cons !== nothing
435+
Optim.OnceDifferentiableConstraints(cache.f.cons, cache.f.cons_j
436+
lb, ub,
437+
cache.lcons, cache.ucons)
438+
else
439+
Optim.OnceDifferentiableConstraints(lb, ub)
440+
end
422441
end
423-
442+
424443
opt_args = __map_optimizer_args(cache, cache.opt, callback = _cb,
425444
maxiters = cache.solver_args.maxiters,
426445
maxtime = cache.solver_args.maxtime,
@@ -429,7 +448,11 @@ function SciMLBase.__solve(cache::OptimizationCache{
429448
cache.solver_args...)
430449

431450
t0 = time()
432-
opt_res = Optim.optimize(optim_f, optim_fc, cache.u0, cache.opt, opt_args)
451+
if lb === nothing && ub === nothing && cache.f.cons === nothing
452+
opt_res = Optim.optimize(optim_f, cache.u0, cache.opt, opt_args)
453+
else
454+
opt_res = Optim.optimize(optim_f, optim_fc, cache.u0, cache.opt, opt_args)
455+
end
433456
t1 = time()
434457
opt_ret = Symbol(Optim.converged(opt_res))
435458
stats = Optimization.OptimizationStats(; iterations = opt_res.iterations,

0 commit comments

Comments
 (0)