Skip to content

Commit 1ff1c36

Browse files
committed
reorder fields in OptimizationCache for cleaner dispatches
Since we always dispatch on the optimizer, it makes much more sense to have this as the first argument (and first type parameter).
1 parent 6fe78dc commit 1ff1c36

File tree

19 files changed

+167
-597
lines changed

19 files changed

+167
-597
lines changed

lib/OptimizationAuglag/src/OptimizationAuglag.jl

Lines changed: 3 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -59,32 +59,7 @@ function __map_optimizer_args(cache::OptimizationBase.OptimizationCache, opt::Au
5959
return mapped_args
6060
end
6161

62-
function SciMLBase.__solve(cache::OptimizationCache{
63-
F,
64-
RC,
65-
LB,
66-
UB,
67-
LC,
68-
UC,
69-
S,
70-
O,
71-
D,
72-
P,
73-
C
74-
}) where {
75-
F,
76-
RC,
77-
LB,
78-
UB,
79-
LC,
80-
UC,
81-
S,
82-
O <:
83-
AugLag,
84-
D,
85-
P,
86-
C
87-
}
62+
function SciMLBase.__solve(cache::OptimizationCache{O <: AugLag}) where {O}
8863
maxiters = OptimizationBase._check_and_convert_maxiters(cache.solver_args.maxiters)
8964

9065
local x
@@ -121,7 +96,8 @@ function SciMLBase.__solve(cache::OptimizationCache{
12196
if cache.callback(opt_state, x...)
12297
error("Optimization halted by callback.")
12398
end
124-
return x[1] + sum(@. λ * cons_tmp[eq_inds] + ρ / 2 * (cons_tmp[eq_inds] .^ 2)) + 1 / (2 * ρ) * sum((max.(Ref(0.0), μ .+.* cons_tmp[ineq_inds]))) .^ 2)
99+
return x[1] + sum(@. λ * cons_tmp[eq_inds] + ρ / 2 * (cons_tmp[eq_inds] .^ 2)) +
100+
1 / (2 * ρ) * sum((max.(Ref(0.0), μ .+.* cons_tmp[ineq_inds]))) .^ 2)
125101
end
126102

127103
prev_eqcons = zero(λ)

lib/OptimizationBBO/src/OptimizationBBO.jl

Lines changed: 7 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
module OptimizationBBO
22

33
using Reexport
4-
import OptimizationBase
5-
import OptimizationBase: SciMLBase
6-
import BlackBoxOptim
7-
import SciMLBase: MultiObjectiveOptimizationFunction
4+
using OptimizationBase
5+
using SciMLBase
6+
using BlackBoxOptim: BlackBoxOptim
87

98
abstract type BBO end
109

@@ -36,19 +35,19 @@ function decompose_trace(opt::BlackBoxOptim.OptRunController, progress)
3635
if iszero(max_time)
3736
# we stop at either convergence or max_steps
3837
n_steps = BlackBoxOptim.num_steps(opt)
39-
Base.@logmsg(Base.LogLevel(-1), msg, progress=n_steps/maxiters,
38+
Base.@logmsg(Base.LogLevel(-1), msg, progress=n_steps / maxiters,
4039
_id=:OptimizationBBO)
4140
else
4241
# we stop at either convergence or max_time
4342
elapsed = BlackBoxOptim.elapsed_time(opt)
44-
Base.@logmsg(Base.LogLevel(-1), msg, progress=elapsed/max_time,
43+
Base.@logmsg(Base.LogLevel(-1), msg, progress=elapsed / max_time,
4544
_id=:OptimizationBBO)
4645
end
4746
end
4847
return BlackBoxOptim.best_candidate(opt)
4948
end
5049

51-
function __map_optimizer_args(prob::OptimizationBase.OptimizationCache, opt::BBO;
50+
function __map_optimizer_args(prob::OptimizationCache, opt::BBO;
5251
callback = nothing,
5352
maxiters::Union{Number, Nothing} = nothing,
5453
maxtime::Union{Number, Nothing} = nothing,
@@ -96,32 +95,7 @@ function map_objective(obj::BlackBoxOptim.IndexedTupleFitness)
9695
obj.orig
9796
end
9897

99-
function SciMLBase.__solve(cache::OptimizationBase.OptimizationCache{
100-
F,
101-
RC,
102-
LB,
103-
UB,
104-
LC,
105-
UC,
106-
S,
107-
O,
108-
D,
109-
P,
110-
C
111-
}) where {
112-
F,
113-
RC,
114-
LB,
115-
UB,
116-
LC,
117-
UC,
118-
S,
119-
O <:
120-
BBO,
121-
D,
122-
P,
123-
C
124-
}
98+
function SciMLBase.__solve(cache::OptimizationCache{O <: BBO}) where {O}
12599
function _cb(trace)
126100
if cache.callback === OptimizationBase.DEFAULT_CALLBACK
127101
cb_call = false

lib/OptimizationBase/src/cache.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,16 @@ struct AnalysisResults{O, C}
55
constraints::C
66
end
77

8-
struct OptimizationCache{F, RC, LB, UB, LC, UC, S, O, P, C, M} <:
8+
struct OptimizationCache{O, F, RC, LB, UB, LC, UC, S, P, C, M} <:
99
SciMLBase.AbstractOptimizationCache
10+
opt::O
1011
f::F
1112
reinit_cache::RC
1213
lb::LB
1314
ub::UB
1415
lcons::LC
1516
ucons::UC
1617
sense::S
17-
opt::O
1818
progress::P
1919
callback::C
2020
manifold::M
@@ -46,13 +46,13 @@ function OptimizationCache(prob::SciMLBase.OptimizationProblem, opt;
4646
prob.f.adtype isa AutoZygote) &&
4747
(SciMLBase.requireshessian(opt) || SciMLBase.requiresconshess(opt) ||
4848
SciMLBase.requireslagh(opt))
49-
@warn "The selected optimization algorithm requires second order derivatives, but `SecondOrder` ADtype was not provided.
50-
So a `SecondOrder` with $(prob.f.adtype) for both inner and outer will be created, this can be suboptimal and not work in some cases so
49+
@warn "The selected optimization algorithm requires second order derivatives, but `SecondOrder` ADtype was not provided.
50+
So a `SecondOrder` with $(prob.f.adtype) for both inner and outer will be created, this can be suboptimal and not work in some cases so
5151
an explicit `SecondOrder` ADtype is recommended."
5252
elseif prob.f.adtype isa AutoZygote &&
5353
(SciMLBase.requiresconshess(opt) || SciMLBase.requireslagh(opt) ||
5454
SciMLBase.requireshessian(opt))
55-
@warn "The selected optimization algorithm requires second order derivatives, but `AutoZygote` ADtype was provided.
55+
@warn "The selected optimization algorithm requires second order derivatives, but `AutoZygote` ADtype was provided.
5656
So a `SecondOrder` with `AutoZygote` for inner and `AutoForwardDiff` for outer will be created, for choosing another pair
5757
an explicit `SecondOrder` ADtype is recommended."
5858
end
@@ -71,9 +71,9 @@ function OptimizationCache(prob::SciMLBase.OptimizationProblem, opt;
7171
cons_res = nothing
7272
end
7373

74-
return OptimizationCache(f, reinit_cache_passedon, prob.lb, prob.ub, prob.lcons,
74+
return OptimizationCache(opt, f, reinit_cache_passedon, prob.lb, prob.ub, prob.lcons,
7575
prob.ucons, prob.sense,
76-
opt, progress, callback, manifold, AnalysisResults(obj_res, cons_res),
76+
progress, callback, manifold, AnalysisResults(obj_res, cons_res),
7777
merge((; maxiters, maxtime, abstol, reltol),
7878
NamedTuple(kwargs)))
7979
end

lib/OptimizationCMAEvolutionStrategy/src/OptimizationCMAEvolutionStrategy.jl

Lines changed: 3 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ SciMLBase.requireshessian(::CMAEvolutionStrategyOpt) = false
2121
SciMLBase.requiresconsjac(::CMAEvolutionStrategyOpt) = false
2222
SciMLBase.requiresconshess(::CMAEvolutionStrategyOpt) = false
2323

24-
function __map_optimizer_args(prob::OptimizationBase.OptimizationCache, opt::CMAEvolutionStrategyOpt;
24+
function __map_optimizer_args(
25+
prob::OptimizationBase.OptimizationCache, opt::CMAEvolutionStrategyOpt;
2526
callback = nothing,
2627
maxiters::Union{Number, Nothing} = nothing,
2728
maxtime::Union{Number, Nothing} = nothing,
@@ -53,32 +54,7 @@ function __map_optimizer_args(prob::OptimizationBase.OptimizationCache, opt::CMA
5354
return mapped_args
5455
end
5556

56-
function SciMLBase.__solve(cache::OptimizationBase.OptimizationCache{
57-
F,
58-
RC,
59-
LB,
60-
UB,
61-
LC,
62-
UC,
63-
S,
64-
O,
65-
D,
66-
P,
67-
C
68-
}) where {
69-
F,
70-
RC,
71-
LB,
72-
UB,
73-
LC,
74-
UC,
75-
S,
76-
O <:
77-
CMAEvolutionStrategyOpt,
78-
D,
79-
P,
80-
C
81-
}
57+
function SciMLBase.__solve(cache::OptimizationCache{O <: CMAEvolutionStrategyOpt}) where {O}
8258
local x, cur, state
8359

8460
function _cb(opt, y, fvals, perm)

lib/OptimizationEvolutionary/src/OptimizationEvolutionary.jl

Lines changed: 5 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@ SciMLBase.allowsconstraints(opt::Evolutionary.AbstractOptimizer) = true
1010
SciMLBase.supports_opt_cache_interface(opt::Evolutionary.AbstractOptimizer) = true
1111
end
1212
@static if isdefined(OptimizationBase, :supports_opt_cache_interface)
13-
OptimizationBase.supports_opt_cache_interface(opt::Evolutionary.AbstractOptimizer) = true
13+
function OptimizationBase.supports_opt_cache_interface(opt::Evolutionary.AbstractOptimizer)
14+
true
15+
end
1416
end
1517
SciMLBase.requiresgradient(opt::Evolutionary.AbstractOptimizer) = false
1618
SciMLBase.requiresgradient(opt::Evolutionary.NSGA2) = false
@@ -76,32 +78,8 @@ function __map_optimizer_args(cache::OptimizationBase.OptimizationCache,
7678
return Evolutionary.Options(; mapped_args...)
7779
end
7880

79-
function SciMLBase.__solve(cache::OptimizationBase.OptimizationCache{
80-
F,
81-
RC,
82-
LB,
83-
UB,
84-
LC,
85-
UC,
86-
S,
87-
O,
88-
D,
89-
P,
90-
C
91-
}) where {
92-
F,
93-
RC,
94-
LB,
95-
UB,
96-
LC,
97-
UC,
98-
S,
99-
O <:
100-
Evolutionary.AbstractOptimizer,
101-
D,
102-
P,
103-
C
104-
}
81+
function SciMLBase.__solve(cache::OptimizationCache{O <:
82+
Evolutionary.AbstractOptimizer}) where {O}
10583
local x, cur, state
10684

10785
function _cb(trace)

lib/OptimizationGCMAES/src/OptimizationGCMAES.jl

Lines changed: 1 addition & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -61,32 +61,7 @@ function SciMLBase.__init(prob::SciMLBase.OptimizationProblem,
6161
kwargs...)
6262
end
6363

64-
function SciMLBase.__solve(cache::OptimizationBase.OptimizationCache{
65-
F,
66-
RC,
67-
LB,
68-
UB,
69-
LC,
70-
UC,
71-
S,
72-
O,
73-
D,
74-
P,
75-
C
76-
}) where {
77-
F,
78-
RC,
79-
LB,
80-
UB,
81-
LC,
82-
UC,
83-
S,
84-
O <:
85-
GCMAESOpt,
86-
D,
87-
P,
88-
C
89-
}
64+
function SciMLBase.__solve(cache::OptimizationCache{O <: GCMAESOpt}) where {O}
9065
local x
9166
local G = similar(cache.u0)
9267

lib/OptimizationLBFGSB/src/OptimizationLBFGSB.jl

Lines changed: 3 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -78,32 +78,7 @@ function __map_optimizer_args(cache::OptimizationBase.OptimizationCache, opt::LB
7878
return mapped_args
7979
end
8080

81-
function SciMLBase.__solve(cache::OptimizationBase.OptimizationCache{
82-
F,
83-
RC,
84-
LB,
85-
UB,
86-
LC,
87-
UC,
88-
S,
89-
O,
90-
D,
91-
P,
92-
C
93-
}) where {
94-
F,
95-
RC,
96-
LB,
97-
UB,
98-
LC,
99-
UC,
100-
S,
101-
O <:
102-
LBFGSB,
103-
D,
104-
P,
105-
C
106-
}
81+
function SciMLBase.__solve(cache::OptimizationCache{O <: LBFGSB}) where {O}
10782
maxiters = OptimizationBase._check_and_convert_maxiters(cache.solver_args.maxiters)
10883

10984
local x
@@ -142,7 +117,8 @@ function SciMLBase.__solve(cache::OptimizationBase.OptimizationCache{
142117
if cache.callback(opt_state, x...)
143118
error("Optimization halted by callback.")
144119
end
145-
return x[1] + sum(@. λ * cons_tmp[eq_inds] + ρ / 2 * (cons_tmp[eq_inds] .^ 2)) + 1 / (2 * ρ) * sum((max.(Ref(0.0), μ .+.* cons_tmp[ineq_inds]))) .^ 2)
120+
return x[1] + sum(@. λ * cons_tmp[eq_inds] + ρ / 2 * (cons_tmp[eq_inds] .^ 2)) +
121+
1 / (2 * ρ) * sum((max.(Ref(0.0), μ .+.* cons_tmp[ineq_inds]))) .^ 2)
146122
end
147123

148124
prev_eqcons = zero(λ)

0 commit comments

Comments
 (0)