
Commit 2841bbe

Add constraints with ForwardDiff support and support IPNewton (#47)
* Add constraints with ForwardDiff support and support IPNewton
* Fix CI failure and revert flux logging changes
* Use empty arrays as default bounds
* Add constraints kwargs to all AD backends to pass to constructor
* Add multiple constraints with IPNewton
* Add num_cons to all AD backend OptimizationFunction constructor
1 parent 85a7424 commit 2841bbe

File tree

4 files changed: +168 -48 lines

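In short, the commit lets an OptimizationFunction carry constraint functions (with ForwardDiff-generated derivatives) and routes Optim's IPNewton through them. Below is a minimal end-to-end sketch assembled from the test changes in test/rosenbrock.jl further down; the keyword names come from this diff, while the `rosenbrock` definition and the `using` line are assumptions for illustration.

```julia
# Minimal sketch of the constrained path added by this commit; illustrative, not canonical docs.
using GalacticOptim, Optim

rosenbrock(x, p) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2   # assumed objective, p unused
x0 = zeros(2)

# cons is the constraint function; its Jacobian/Hessian (cons_j, cons_h) are auto-generated via ForwardDiff.
optprob = OptimizationFunction(rosenbrock, x0, GalacticOptim.AutoForwardDiff();
                               cons = x -> x[1]^2 + x[2]^2, num_cons = 1)

# lcons/ucons are the new constraint bounds; IPNewton is Optim's interior-point Newton method.
prob = OptimizationProblem(optprob, x0, lcons = [-5.0], ucons = [10.0])
sol  = solve(prob, IPNewton())
```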

src/function.jl

Lines changed: 50 additions & 11 deletions
@@ -8,16 +8,21 @@ struct AutoZygote <: AbstractADType end
 struct AutoFiniteDiff <: AbstractADType end
 struct AutoModelingToolkit <: AbstractADType end

-struct OptimizationFunction{F,G,H,HV,K} <: AbstractOptimizationFunction
+struct OptimizationFunction{F,G,H,HV,C,CJ,CH,K} <: AbstractOptimizationFunction
     f::F
     grad::G
     hess::H
     hv::HV
     adtype::AbstractADType
+    cons::C
+    cons_j::CJ
+    cons_h::CH
+    num_cons::Int
     kwargs::K
 end

-function OptimizationFunction(f, x, ::AutoForwardDiff; grad=nothing,hess=nothing, p=DiffEqBase.NullParameters(), chunksize = 1, hv = nothing, kwargs...)
+function OptimizationFunction(f, x, ::AutoForwardDiff; grad=nothing, hess=nothing, cons = nothing, cons_j = nothing, cons_h = nothing,
+                              num_cons = 0, p=DiffEqBase.NullParameters(), chunksize = 1, hv = nothing, kwargs...)
     _f = θ -> f(θ,p)[1]
     if grad === nothing
         gradcfg = ForwardDiff.GradientConfig(_f, x, ForwardDiff.Chunk{chunksize}())
@@ -37,10 +42,41 @@ function OptimizationFunction(f, x, ::AutoForwardDiff; grad=nothing,hess=nothing
         end
     end

-    return OptimizationFunction{typeof(f),typeof(grad),typeof(hess),typeof(hv),typeof(kwargs)}(f,grad,hess,hv,AutoForwardDiff(),kwargs)
+    if cons !== nothing && cons_j === nothing
+        if num_cons == 1
+            cjconfig = ForwardDiff.JacobianConfig(cons, x, ForwardDiff.Chunk{chunksize}())
+            cons_j = (res,θ) -> ForwardDiff.jacobian!(res, cons, θ, cjconfig)
+        else
+            cons_j = function (res, θ)
+                for i in 1:num_cons
+                    cjconfig = ForwardDiff.JacobianConfig(x -> cons(x)[i], θ, ForwardDiff.Chunk{chunksize}())
+                    ForwardDiff.jacobian!(res[i], x -> cons(x)[i], θ, cjconfig, Val{false}())
+                end
+            end
+        end
+    end
+
+    if cons !== nothing && cons_h === nothing
+        if num_cons == 1
+            cons_h = function (res, θ)
+                hess_config_cache = ForwardDiff.HessianConfig(cons, θ, ForwardDiff.Chunk{chunksize}())
+                ForwardDiff.hessian!(res, cons, θ, hess_config_cache)
+            end
+        else
+            cons_h = function (res, θ)
+                for i in 1:num_cons
+                    hess_config_cache = ForwardDiff.HessianConfig(x -> cons(x)[i], θ, ForwardDiff.Chunk{chunksize}())
+                    ForwardDiff.hessian!(res[i], x -> cons(x)[i], θ, hess_config_cache, Val{false}())
+                end
+            end
+        end
+    end
+
+    return OptimizationFunction{typeof(f),typeof(grad),typeof(hess),typeof(hv),typeof(cons),typeof(cons_j),typeof(cons_h),typeof(kwargs)}(f,grad,hess,hv,AutoForwardDiff(),cons,cons_j,cons_h,num_cons,kwargs)
 end

-function OptimizationFunction(f, x, ::AutoZygote; grad=nothing, hess=nothing, p=DiffEqBase.NullParameters(), hv = nothing, kwargs...)
+function OptimizationFunction(f, x, ::AutoZygote; grad=nothing, hess=nothing, cons = nothing, cons_j = nothing, cons_h = nothing,
+                              num_cons = 0, p=DiffEqBase.NullParameters(), hv = nothing, kwargs...)
     _f = θ -> f(θ,p)[1]
     if grad === nothing
         grad = (res,θ) -> res isa DiffResults.DiffResult ? DiffResults.gradient!(res, Zygote.gradient(_f, θ)[1]) : res .= Zygote.gradient(_f, θ)[1]
@@ -68,10 +104,11 @@ function OptimizationFunction(f, x, ::AutoZygote; grad=nothing, hess=nothing, p=
             H .= getindex.(ForwardDiff.partials.(DiffResults.gradient(res)),1)
         end
     end
-    return OptimizationFunction{typeof(f),typeof(grad),typeof(hess),typeof(hv),typeof(kwargs)}(f,grad,hess,hv,AutoZygote(),kwargs)
+    return OptimizationFunction{typeof(f),typeof(grad),typeof(hess),typeof(hv),typeof(cons),typeof(cons_j),typeof(cons_h),typeof(kwargs)}(f,grad,hess,hv,AutoZygote(),cons,cons_j,cons_h,num_cons,kwargs)
 end

-function OptimizationFunction(f, x, ::AutoReverseDiff; grad=nothing,hess=nothing, p=DiffEqBase.NullParameters(), hv = nothing, kwargs...)
+function OptimizationFunction(f, x, ::AutoReverseDiff; grad=nothing,hess=nothing, cons = nothing, cons_j = nothing, cons_h = nothing,
+                              num_cons = 0, p=DiffEqBase.NullParameters(), hv = nothing, kwargs...)
     _f = θ -> f(θ,p)[1]
     if grad === nothing
         grad = (res,θ) -> ReverseDiff.gradient!(res, _f, θ, ReverseDiff.GradientConfig(θ))
@@ -100,11 +137,12 @@ function OptimizationFunction(f, x, ::AutoReverseDiff; grad=nothing,hess=nothing
         end
     end

-    return OptimizationFunction{typeof(f),typeof(grad),typeof(hess),typeof(hv),typeof(kwargs)}(f,grad,hess,hv,AutoReverseDiff(),kwargs)
+    return OptimizationFunction{typeof(f),typeof(grad),typeof(hess),typeof(hv),typeof(cons),typeof(cons_j),typeof(cons_h),typeof(kwargs)}(f,grad,hess,hv,AutoReverseDiff(),cons,cons_j,cons_h,num_cons,kwargs)
 end


-function OptimizationFunction(f, x, ::AutoTracker; grad=nothing,hess=nothing, p=DiffEqBase.NullParameters(), hv = nothing, kwargs...)
+function OptimizationFunction(f, x, ::AutoTracker; grad=nothing,hess=nothing, cons = nothing, cons_j = nothing, cons_h = nothing,
+                              num_cons = 0, p=DiffEqBase.NullParameters(), hv = nothing, kwargs...)
     _f = θ -> f(θ,p)[1]
     if grad === nothing
         grad = (res,θ) -> res isa DiffResults.DiffResult ? DiffResults.gradient!(res, Tracker.data(Tracker.gradient(_f, θ)[1])) : res .= Tracker.data(Tracker.gradient(_f, θ)[1])
@@ -119,10 +157,11 @@ function OptimizationFunction(f, x, ::AutoTracker; grad=nothing,hess=nothing, p=
     end


-    return OptimizationFunction{typeof(f),typeof(grad),typeof(hess),typeof(hv),typeof(kwargs)}(f,grad,hess,hv,AutoTracker(),kwargs)
+    return OptimizationFunction{typeof(f),typeof(grad),typeof(hess),typeof(hv),typeof(cons),typeof(cons_j),typeof(cons_h),typeof(kwargs)}(f,grad,hess,hv,AutoTracker(),cons,cons_j,cons_h,num_cons,kwargs)
 end

-function OptimizationFunction(f, x, adtype::AutoFiniteDiff; grad=nothing,hess=nothing, p=DiffEqBase.NullParameters(), hv = nothing, fdtype = :forward, fdhtype = :hcentral, kwargs...)
+function OptimizationFunction(f, x, adtype::AutoFiniteDiff; grad=nothing,hess=nothing, cons = nothing, cons_j = nothing, cons_h = nothing,
+                              num_cons = 0, p=DiffEqBase.NullParameters(), hv = nothing, fdtype = :forward, fdhtype = :hcentral, kwargs...)
     _f = θ -> f(θ,p)[1]
     if grad === nothing
         grad = (res,θ) -> FiniteDiff.finite_difference_gradient!(res, _f, θ, FiniteDiff.GradientCache(res, x, Val{fdtype}))
@@ -140,5 +179,5 @@ function OptimizationFunction(f, x, adtype::AutoFiniteDiff; grad=nothing,hess=no
         end
     end

-    return OptimizationFunction{typeof(f),typeof(grad),typeof(hess),typeof(hv),typeof(kwargs)}(f,grad,hess,hv,adtype,kwargs)
+    return OptimizationFunction{typeof(f),typeof(grad),typeof(hess),typeof(hv),typeof(cons),typeof(cons_j),typeof(cons_h),typeof(kwargs)}(f,grad,hess,hv,adtype,cons,cons_j,cons_h,num_cons,kwargs)
 end
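For orientation, here is what the auto-generated `cons_j` and `cons_h` closures compute for a single constraint. This is a standalone ForwardDiff sketch, not the package's code path; the constraint and evaluation point are made up for illustration.

```julia
# Standalone ForwardDiff sketch of the quantities the AutoForwardDiff constructor
# wires into cons_j (constraint Jacobian) and cons_h (constraint Hessian).
using ForwardDiff

cons(θ) = θ[1]^2 + θ[2]^2          # one scalar constraint, as in the tests below
θ = [1.0, 2.0]

J = ForwardDiff.gradient(cons, θ)   # row of the constraint Jacobian: [2θ₁, 2θ₂] = [2.0, 4.0]
H = ForwardDiff.hessian(cons, θ)    # constraint Hessian: 2I
```

For `num_cons > 1`, the generated closures differentiate each component `x -> cons(x)[i]` separately and write into `res[i]`, so callers supply one result buffer per constraint.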

src/problem.jl

Lines changed: 5 additions & 3 deletions
@@ -1,13 +1,15 @@
 abstract type AbstractOptimizationProblem end

-struct OptimizationProblem{F,X,P,B,K} <: AbstractOptimizationProblem
+struct OptimizationProblem{F,X,P,B,LC,UC,K} <: AbstractOptimizationProblem
     f::F
     x::X
     p::P
     lb::B
     ub::B
+    lcons::LC
+    ucons::UC
     kwargs::K
-    function OptimizationProblem(f, x; p=DiffEqBase.NullParameters(), lb = nothing, ub = nothing, kwargs...)
-        new{typeof(f), typeof(x), typeof(p), typeof(lb), typeof(kwargs)}(f, x, p, lb, ub, kwargs)
+    function OptimizationProblem(f, x; p=DiffEqBase.NullParameters(), lb = [], ub = [], lcons = [], ucons = [], kwargs...)
+        new{typeof(f), typeof(x), typeof(p), typeof(lb), typeof(lcons), typeof(ucons), typeof(kwargs)}(f, x, p, lb, ub, lcons, ucons, kwargs)
     end
 end
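A small sketch of the patched constructor, assuming GalacticOptim is loaded; it shows the new `lcons`/`ucons` fields and the empty-array defaults for bounds. The `rosenbrock` definition is illustrative.

```julia
# Sketch of the updated OptimizationProblem constructor (assumes `using GalacticOptim`).
rosenbrock(x, p) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2
prob = OptimizationProblem(rosenbrock, zeros(2), lcons = [-5.0], ucons = [10.0])

prob.lb, prob.ub        # ([], []) — unset bounds are now empty arrays instead of `nothing`
prob.lcons, prob.ucons  # ([-5.0], [10.0]) — constraint bounds stored for constrained solvers
```

The switch from `nothing` to empty arrays is what lets downstream code test `length(prob.lb) > 0`, as the NLopt change in src/solve.jl below does.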

src/solve.jl

Lines changed: 89 additions & 32 deletions
@@ -26,38 +26,35 @@ function update!(opt, xs::Flux.Zygote.Params, gs)
     end
 end

-maybe_with_logger(f, logger) = logger === nothing ? f() : Logging.with_logger(f, logger)
-
-function default_logger(logger)
-    Logging.min_enabled_level(logger) ≤ ProgressLogging.ProgressLevel && return nothing
-
-    if Sys.iswindows() || (isdefined(Main, :IJulia) && Main.IJulia.inited)
-        progresslogger = ConsoleProgressMonitor.ProgressLogger()
-    else
-        progresslogger = TerminalLoggers.TerminalLogger()
-    end
-
-    logger1 = LoggingExtras.EarlyFilteredLogger(progresslogger) do log
-        log.level == ProgressLogging.ProgressLevel
-    end
-    logger2 = LoggingExtras.EarlyFilteredLogger(logger) do log
-        log.level != ProgressLogging.ProgressLevel
-    end
-
-    LoggingExtras.TeeLogger(logger1, logger2)
+maybe_with_logger(f, logger) = logger === nothing ? f() : Logging.with_logger(f, logger)
+
+function default_logger(logger)
+    Logging.min_enabled_level(logger) ≤ ProgressLogging.ProgressLevel && return nothing
+    if Sys.iswindows() || (isdefined(Main, :IJulia) && Main.IJulia.inited)
+        progresslogger = ConsoleProgressMonitor.ProgressLogger()
+    else
+        progresslogger = TerminalLoggers.TerminalLogger()
+    end
+    logger1 = LoggingExtras.EarlyFilteredLogger(progresslogger) do log
+        log.level == ProgressLogging.ProgressLevel
+    end
+    logger2 = LoggingExtras.EarlyFilteredLogger(logger) do log
+        log.level != ProgressLogging.ProgressLevel
+    end
+    LoggingExtras.TeeLogger(logger1, logger2)
 end

 macro withprogress(progress, exprs...)
-    quote
-        if $progress
-            $maybe_with_logger($default_logger($Logging.current_logger())) do
-                $ProgressLogging.@withprogress $(exprs...)
-            end
-        else
-            $(exprs[end])
-        end
-    end |> esc
-end
+    quote
+        if $progress
+            $maybe_with_logger($default_logger($Logging.current_logger())) do
+                $ProgressLogging.@withprogress $(exprs...)
+            end
+        else
+            $(exprs[end])
+        end
+    end |> esc
+end

 function __solve(prob::OptimizationProblem, opt;cb = (args...) -> (false), maxiters = 1000, progress = true, save_best = true, kwargs...)

@@ -224,6 +221,66 @@ function __solve(prob::OptimizationProblem, opt::Union{Optim.Fminbox,Optim.SAMIN
     Optim.optimize(optim_f, prob.lb, prob.ub, prob.x, opt, Optim.Options(;extended_trace = true, callback = _cb, iterations = maxiters, kwargs...))
 end

+
+function __solve(prob::OptimizationProblem, opt::Optim.ConstrainedOptimizer;cb = (args...) -> (false), maxiters = 1000, kwargs...)
+    local x
+
+    function _cb(trace)
+        cb_call = cb(decompose_trace(trace).metadata["x"],x...)
+        if !(typeof(cb_call) <: Bool)
+            error("The callback should return a boolean `halt` for whether to stop the optimization process.")
+        end
+        cb_call
+    end
+
+    if prob.f isa OptimizationFunction
+        _loss = function(θ)
+            x = prob.f.f(θ, prob.p)
+            return x[1]
+        end
+        fg! = function (G,θ)
+            if G !== nothing
+                prob.f.grad(G, θ)
+            end
+            return _loss(θ)
+        end
+        optim_f = TwiceDifferentiable(_loss, prob.f.grad, fg!, prob.f.hess, prob.x)
+
+        cons! = (res, θ) -> res .= prob.f.cons(θ);
+
+        cons_j! = function(J, x)
+            if prob.f.num_cons > 1
+                res = [zeros(1,size(J,2)) for i in 1:size(J,1)]
+                prob.f.cons_j(res, x)
+                J = vcat(res...)
+            else
+                prob.f.cons_j(J, x)
+            end
+        end
+
+        cons_hl! = function (h, θ, λ)
+            if prob.f.num_cons > 1
+                res = [similar(h) for i in 1:length(λ)]
+                prob.f.cons_h(res, θ)
+                h .= zeros(size(h))
+                for i in 1:length(λ)
+                    h += λ[i]*res[i]
+                end
+            else
+                prob.f.cons_h(h, θ)
+                h += λ[1]*h
+            end
+
+        end
+        optim_fc = TwiceDifferentiableConstraints(cons!, cons_j!, cons_hl!, prob.lb, prob.ub, prob.lcons, prob.ucons)
+    else
+        error("Use OptimizationFunction to pass the derivatives or automatically generate them with one of the autodiff backends")
+    end
+
+    Optim.optimize(optim_f, optim_fc, prob.x, opt, Optim.Options(;extended_trace = true, callback = _cb, iterations = maxiters, kwargs...))
+end
+
+
 function __init__()
     @require BlackBoxOptim="a134a8b2-14d6-55f6-9291-3336d3ab0209" begin
         decompose_trace(opt::BlackBoxOptim.OptRunController) = BlackBoxOptim.best_candidate(opt)
@@ -318,10 +375,10 @@ function __init__()
         NLopt.min_objective!(opt, _loss)
     end

-    if prob.ub !== nothing
-        NLopt.upper_bounds!(opt, prob.ub)
+    if length(prob.ub) > 0
+        NLopt.upper_bounds!(opt, prob.ub)
     end
-    if prob.lb !== nothing
+    if length(prob.lb) > 0
         NLopt.lower_bounds!(opt, prob.lb)
     end
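For context on what the new `ConstrainedOptimizer` method assembles, here is the corresponding raw Optim.jl usage with hand-written constraint derivatives. The objective, the constraint, its hard-coded derivatives, and the names (`con_c!`, `con_j!`, `con_h!`) are illustrative only; in the package these callables are generated from `prob.f.cons`, `cons_j`, and `cons_h` rather than written by hand.

```julia
# Hedged sketch of the Optim.jl objects built by the new IPNewton path:
# a TwiceDifferentiable objective plus TwiceDifferentiableConstraints.
using Optim, ForwardDiff

rosen(x) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2
g!(G, x) = ForwardDiff.gradient!(G, rosen, x)
h!(H, x) = ForwardDiff.hessian!(H, rosen, x)

con_c!(c, x) = (c[1] = x[1]^2 + x[2]^2; c)                  # single constraint c(x)
con_j!(J, x) = (J[1, 1] = 2x[1]; J[1, 2] = 2x[2]; J)         # its 1×2 Jacobian
con_h!(H, x, λ) = (H[1, 1] += 2λ[1]; H[2, 2] += 2λ[1]; H)    # add λ-weighted constraint Hessian

df  = TwiceDifferentiable(rosen, g!, h!, zeros(2))
dfc = TwiceDifferentiableConstraints(con_c!, con_j!, con_h!,
                                     Float64[], Float64[],   # no variable bounds
                                     [-5.0], [10.0])         # -5 ≤ c(x) ≤ 10
res = Optim.optimize(df, dfc, zeros(2), IPNewton())
```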

test/rosenbrock.jl

Lines changed: 24 additions & 2 deletions
@@ -24,8 +24,7 @@ prob = OptimizationProblem(rosenbrock, x0)
 sol = solve(prob, NelderMead())
 @test 10*sol.minimum < l1

-
-optprob = OptimizationFunction(rosenbrock, x0, GalacticOptim.AutoZygote())
+optprob = OptimizationFunction(rosenbrock, x0, GalacticOptim.AutoForwardDiff();cons= x -> x[1]^2 + x[2]^2, num_cons = 1)

 prob = OptimizationProblem(optprob, x0)
 sol = solve(prob, BFGS())
@@ -37,6 +36,29 @@ sol = solve(prob, Newton())
 sol = solve(prob, Optim.KrylovTrustRegion())
 @test 10*sol.minimum < l1

+prob = OptimizationProblem(optprob, x0, lcons = [-Inf], ucons = [Inf])
+sol = solve(prob, IPNewton())
+@test 10*sol.minimum < l1
+
+prob = OptimizationProblem(optprob, x0, lcons = [-5.0], ucons = [10.0])
+sol = solve(prob, IPNewton())
+@test 10*sol.minimum < l1
+
+prob = OptimizationProblem(optprob, x0, lcons = [0.0], ucons = [0.0], lb = [-500.0,-500.0], ub=[-50.0,-50.0])
+sol = solve(prob, IPNewton())
+@test sol.minimum < l1
+
+function con2_c(x)
+    [x[1]^2 + x[2]^2, x[2]*sin(x[1])-x[1]]
+end
+
+optprob = OptimizationFunction(rosenbrock, x0, GalacticOptim.AutoForwardDiff();cons= con2_c, num_cons = 2)
+prob = OptimizationProblem(optprob, x0, lcons = [-Inf,-Inf], ucons = [Inf,Inf])
+sol = solve(prob, IPNewton())
+@test 10*sol.minimum < l1
+
+optprob = OptimizationFunction(rosenbrock, x0, GalacticOptim.AutoZygote())
+prob = OptimizationProblem(optprob, x0)
 sol = solve(prob, ADAM(), progress = false)
 @test 10*sol.minimum < l1
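One usage note on the new IPNewton path, as a hedged sketch rather than documented API: the `cb` keyword is forwarded to the `_cb` trace wrapper shown in src/solve.jl above, receives the current iterate plus the objective's return value, and must return a Bool halt flag. The argument layout and forwarding through `solve` are assumed from that code.

```julia
# Hedged sketch: a callback for the IPNewton path, reusing the constrained
# AutoForwardDiff `optprob` defined earlier in these tests.
cb = (θ, args...) -> begin
    @info "IPNewton iterate" θ
    false                          # keep iterating; return true to halt
end

prob = OptimizationProblem(optprob, x0, lcons = [-5.0], ucons = [10.0])
sol  = solve(prob, IPNewton(), cb = cb, maxiters = 200)
```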
