Skip to content

Commit 449d084

Browse files
Merge pull request #1053 from ChrisRackauckas-Claude/add-solve-interface
Complete solve interface migration from SciMLBase to OptimizationBase
2 parents 2bea7e8 + 2bce34b commit 449d084

File tree

3 files changed

+197
-45
lines changed

3 files changed

+197
-45
lines changed

lib/OptimizationBase/src/OptimizationBase.jl

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,19 @@ using Reexport
55
@reexport using SciMLBase, ADTypes
66

77
using ArrayInterface, Base.Iterators, SparseArrays, LinearAlgebra
8-
import SciMLBase: OptimizationProblem,
8+
import SciMLBase: solve, init, solve!, __init, __solve,
9+
OptimizationProblem,
910
OptimizationFunction, ObjSense,
10-
MaxSense, MinSense, OptimizationStats
11+
MaxSense, MinSense, OptimizationStats,
12+
allowsbounds, requiresbounds,
13+
allowsconstraints, requiresconstraints,
14+
allowscallback, requiresgradient,
15+
requireshessian, requiresconsjac,
16+
requiresconshess, supports_opt_cache_interface
1117
export ObjSense, MaxSense, MinSense
18+
export allowsbounds, requiresbounds, allowsconstraints, requiresconstraints,
19+
allowscallback, requiresgradient, requireshessian,
20+
requiresconsjac, requiresconshess, supports_opt_cache_interface
1221

1322
using FastClosures
1423

@@ -24,15 +33,14 @@ Base.length(::NullData) = 0
2433
include("adtypes.jl")
2534
include("symify.jl")
2635
include("cache.jl")
36+
include("solve.jl")
2737
include("OptimizationDIExt.jl")
2838
include("OptimizationDISparseExt.jl")
2939
include("function.jl")
30-
include("solve.jl")
3140
include("utils.jl")
3241
include("state.jl")
3342

34-
export solve, OptimizationCache, DEFAULT_CALLBACK, DEFAULT_DATA,
35-
IncompatibleOptimizerError, OptimizerMissingError, _check_opt_alg,
36-
supports_opt_cache_interface
43+
export solve, OptimizationCache, DEFAULT_CALLBACK, DEFAULT_DATA
44+
export IncompatibleOptimizerError, OptimizerMissingError
3745

3846
end

lib/OptimizationBase/src/solve.jl

Lines changed: 174 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
# This file contains the top level solve interface functionality moved from SciMLBase.jl
2-
# These functions provide the core optimization solving interface
3-
41
struct IncompatibleOptimizerError <: Exception
52
err::String
63
end
@@ -9,70 +6,214 @@ function Base.showerror(io::IO, e::IncompatibleOptimizerError)
96
print(io, e.err)
107
end
118

12-
const OPTIMIZER_MISSING_ERROR_MESSAGE = """
13-
Optimization algorithm not found. Either the chosen algorithm is not a valid solver
14-
choice for the `OptimizationProblem`, or the Optimization solver library is not loaded.
15-
Make sure that you have loaded an appropriate OptimizationBase.jl solver library, for example,
16-
`solve(prob,Optim.BFGS())` requires `using OptimizationOptimJL` and
17-
`solve(prob,Adam())` requires `using OptimizationOptimisers`.
9+
"""
10+
```julia
11+
solve(prob::OptimizationProblem, alg::AbstractOptimizationAlgorithm,
12+
args...; kwargs...)::OptimizationSolution
13+
```
1814
19-
For more information, see the OptimizationBase.jl documentation: <https://docs.sciml.ai/Optimization/stable/>.
20-
"""
15+
For information about the returned solution object, refer to the documentation for [`OptimizationSolution`](@ref)
2116
22-
struct OptimizerMissingError <: Exception
23-
alg::Any
17+
## Keyword Arguments
18+
19+
The arguments to `solve` are common across all of the optimizers.
20+
These common arguments are:
21+
22+
- `maxiters`: the maximum number of iterations
23+
- `maxtime`: the maximum amount of time (typically in seconds) the optimization runs for
24+
- `abstol`: absolute tolerance in changes of the objective value
25+
- `reltol`: relative tolerance in changes of the objective value
26+
- `callback`: a callback function
27+
28+
Some optimizer algorithms have special keyword arguments documented in the
29+
solver portion of the documentation and their respective documentation.
30+
These arguments can be passed as `kwargs...` to `solve`. Similarly, the special
31+
keyword arguments for the `local_method` of a global optimizer are passed as a
32+
`NamedTuple` to `local_options`.
33+
34+
Over time, we hope to cover more of these keyword arguments under the common interface.
35+
36+
A warning will be shown if a common argument is not implemented for an optimizer.
37+
38+
## Callback Functions
39+
40+
The callback function `callback` is a function that is called after every optimizer
41+
step. Its signature is:
42+
43+
```julia
44+
callback = (state, loss_val) -> false
45+
```
46+
47+
where `state` is an `OptimizationState` and stores information for the current
48+
iteration of the solver and `loss_val` is loss/objective value. For more
49+
information about the fields of the `state` look at the `OptimizationState`
50+
documentation. The callback should return a Boolean value, and the default
51+
should be `false`, so the optimization stops if it returns `true`.
52+
53+
### Callback Example
54+
55+
Here we show an example of a callback function that plots the prediction at the current value of the optimization variables.
56+
For a visualization callback, we would need the prediction at the current parameters i.e. the solution of the `ODEProblem` `prob`.
57+
So we call the `predict` function within the callback again.
58+
59+
```julia
60+
function predict(u)
61+
Array(solve(prob, Tsit5(), p = u))
2462
end
2563
26-
function Base.showerror(io::IO, e::OptimizerMissingError)
27-
println(io, OPTIMIZER_MISSING_ERROR_MESSAGE)
28-
print(io, "Chosen Optimizer: ")
29-
print(e.alg)
64+
function loss(u, p)
65+
pred = predict(u)
66+
sum(abs2, batch .- pred)
67+
end
68+
69+
callback = function (state, l; doplot = false) #callback function to observe training
70+
display(l)
71+
# plot current prediction against data
72+
if doplot
73+
pred = predict(state.u)
74+
pl = scatter(t, ode_data[1, :], label = "data")
75+
scatter!(pl, t, pred[1, :], label = "prediction")
76+
display(plot(pl))
77+
end
78+
return false
79+
end
80+
```
81+
82+
If the chosen method is a global optimizer that employs a local optimization
83+
method, a similar set of common local optimizer arguments exists. Look at `MLSL` or `AUGLAG`
84+
from NLopt for an example. The common local optimizer arguments are:
85+
86+
- `local_method`: optimizer used for local optimization in global method
87+
- `local_maxiters`: the maximum number of iterations
88+
- `local_maxtime`: the maximum amount of time (in seconds) the optimization runs for
89+
- `local_abstol`: absolute tolerance in changes of the objective value
90+
- `local_reltol`: relative tolerance in changes of the objective value
91+
- `local_options`: `NamedTuple` of keyword arguments for local optimizer
92+
"""
93+
function solve(prob::SciMLBase.OptimizationProblem, alg, args...;
94+
kwargs...)::SciMLBase.AbstractOptimizationSolution
95+
if supports_opt_cache_interface(alg)
96+
solve!(init(prob, alg, args...; kwargs...))
97+
else
98+
if prob.u0 !== nothing && !isconcretetype(eltype(prob.u0))
99+
throw(SciMLBase.NonConcreteEltypeError(eltype(prob.u0)))
100+
end
101+
_check_opt_alg(prob, alg; kwargs...)
102+
__solve(prob, alg, args...; kwargs...)
103+
end
104+
end
105+
106+
function solve(
107+
prob::SciMLBase.EnsembleProblem{T}, args...; kwargs...) where {T <:
108+
SciMLBase.OptimizationProblem}
109+
return __solve(prob, args...; kwargs...)
30110
end
31111

32-
# Algorithm compatibility checking function
33112
function _check_opt_alg(prob::SciMLBase.OptimizationProblem, alg; kwargs...)
34-
!SciMLBase.allowsbounds(alg) && (!isnothing(prob.lb) || !isnothing(prob.ub)) &&
113+
!allowsbounds(alg) && (!isnothing(prob.lb) || !isnothing(prob.ub)) &&
35114
throw(IncompatibleOptimizerError("The algorithm $(typeof(alg)) does not support box constraints. Either remove the `lb` or `ub` bounds passed to `OptimizationProblem` or use a different algorithm."))
36-
SciMLBase.requiresbounds(alg) && isnothing(prob.lb) &&
115+
requiresbounds(alg) && isnothing(prob.lb) &&
37116
throw(IncompatibleOptimizerError("The algorithm $(typeof(alg)) requires box constraints. Either pass `lb` and `ub` bounds to `OptimizationProblem` or use a different algorithm."))
38-
!SciMLBase.allowsconstraints(alg) && !isnothing(prob.f.cons) &&
117+
!allowsconstraints(alg) && !isnothing(prob.f.cons) &&
39118
throw(IncompatibleOptimizerError("The algorithm $(typeof(alg)) does not support constraints. Either remove the `cons` function passed to `OptimizationFunction` or use a different algorithm."))
40-
SciMLBase.requiresconstraints(alg) && isnothing(prob.f.cons) &&
119+
requiresconstraints(alg) && isnothing(prob.f.cons) &&
41120
throw(IncompatibleOptimizerError("The algorithm $(typeof(alg)) requires constraints, pass them with the `cons` kwarg in `OptimizationFunction`."))
42121
# Check that if constraints are present and the algorithm supports constraints, both lcons and ucons are provided
43-
SciMLBase.allowsconstraints(alg) && !isnothing(prob.f.cons) &&
122+
allowsconstraints(alg) && !isnothing(prob.f.cons) &&
44123
(isnothing(prob.lcons) || isnothing(prob.ucons)) &&
45124
throw(ArgumentError("Constrained optimization problem requires both `lcons` and `ucons` to be provided to OptimizationProblem. " *
46125
"Example: OptimizationProblem(optf, u0, p; lcons=[-Inf], ucons=[0.0])"))
47-
!SciMLBase.allowscallback(alg) && haskey(kwargs, :callback) &&
126+
!allowscallback(alg) && haskey(kwargs, :callback) &&
48127
throw(IncompatibleOptimizerError("The algorithm $(typeof(alg)) does not support callbacks, remove the `callback` keyword argument from the `solve` call."))
49-
SciMLBase.requiresgradient(alg) &&
128+
requiresgradient(alg) &&
50129
!(prob.f isa SciMLBase.AbstractOptimizationFunction) &&
51130
throw(IncompatibleOptimizerError("The algorithm $(typeof(alg)) requires gradients, hence use `OptimizationFunction` to generate them with an automatic differentiation backend e.g. `OptimizationFunction(f, AutoForwardDiff())` or pass it in with `grad` kwarg."))
52-
SciMLBase.requireshessian(alg) &&
131+
requireshessian(alg) &&
53132
!(prob.f isa SciMLBase.AbstractOptimizationFunction) &&
54133
throw(IncompatibleOptimizerError("The algorithm $(typeof(alg)) requires hessians, hence use `OptimizationFunction` to generate them with an automatic differentiation backend e.g. `OptimizationFunction(f, AutoFiniteDiff(); kwargs...)` or pass them in with `hess` kwarg."))
55-
SciMLBase.requiresconsjac(alg) &&
134+
requiresconsjac(alg) &&
56135
!(prob.f isa SciMLBase.AbstractOptimizationFunction) &&
57136
throw(IncompatibleOptimizerError("The algorithm $(typeof(alg)) requires constraint jacobians, hence use `OptimizationFunction` to generate them with an automatic differentiation backend e.g. `OptimizationFunction(f, AutoFiniteDiff(); kwargs...)` or pass them in with `cons` kwarg."))
58-
SciMLBase.requiresconshess(alg) &&
137+
requiresconshess(alg) &&
59138
!(prob.f isa SciMLBase.AbstractOptimizationFunction) &&
60139
throw(IncompatibleOptimizerError("The algorithm $(typeof(alg)) requires constraint hessians, hence use `OptimizationFunction` to generate them with an automatic differentiation backend e.g. `OptimizationFunction(f, AutoFiniteDiff(), AutoFiniteDiff(hess=true); kwargs...)` or pass them in with `cons` kwarg."))
61140
return
62141
end
63142

64-
# Base solver dispatch functions (these will be extended by specific solver packages)
65-
supports_opt_cache_interface(alg) = false
143+
const OPTIMIZER_MISSING_ERROR_MESSAGE = """
144+
Optimization algorithm not found. Either the chosen algorithm is not a valid solver
145+
choice for the `OptimizationProblem`, or the Optimization solver library is not loaded.
146+
Make sure that you have loaded an appropriate Optimization.jl solver library, for example,
147+
`solve(prob,Optim.BFGS())` requires `using OptimizationOptimJL` and
148+
`solve(prob,Adam())` requires `using OptimizationOptimisers`.
149+
150+
For more information, see the Optimization.jl documentation: <https://docs.sciml.ai/Optimization/stable/>.
151+
"""
66152

67-
function __solve(cache::SciMLBase.AbstractOptimizationCache)::SciMLBase.AbstractOptimizationSolution
68-
throw(OptimizerMissingError(cache.opt))
153+
struct OptimizerMissingError <: Exception
154+
alg::Any
69155
end
70156

157+
function Base.showerror(io::IO, e::OptimizerMissingError)
158+
println(io, OPTIMIZER_MISSING_ERROR_MESSAGE)
159+
print(io, "Chosen Optimizer: ")
160+
print(e.alg)
161+
end
162+
163+
"""
164+
```julia
165+
init(prob::OptimizationProblem, alg::AbstractOptimizationAlgorithm, args...; kwargs...)
166+
```
167+
168+
## Keyword Arguments
169+
170+
The arguments to `init` are the same as to `solve` and common across all of the optimizers.
171+
These common arguments are:
172+
173+
- `maxiters` (the maximum number of iterations)
174+
- `maxtime` (the maximum of time the optimization runs for)
175+
- `abstol` (absolute tolerance in changes of the objective value)
176+
- `reltol` (relative tolerance in changes of the objective value)
177+
- `callback` (a callback function)
178+
179+
Some optimizer algorithms have special keyword arguments documented in the
180+
solver portion of the documentation and their respective documentation.
181+
These arguments can be passed as `kwargs...` to `init`.
182+
183+
See also [`solve(prob::OptimizationProblem, alg, args...; kwargs...)`](@ref)
184+
"""
185+
function init(prob::SciMLBase.OptimizationProblem, alg, args...;
186+
kwargs...)::SciMLBase.AbstractOptimizationCache
187+
if prob.u0 !== nothing && !isconcretetype(eltype(prob.u0))
188+
throw(SciMLBase.NonConcreteEltypeError(eltype(prob.u0)))
189+
end
190+
_check_opt_alg(prob::SciMLBase.OptimizationProblem, alg; kwargs...)
191+
cache = __init(prob, alg, args...; prob.kwargs..., kwargs...)
192+
return cache
193+
end
194+
195+
"""
196+
```julia
197+
solve!(cache::AbstractOptimizationCache)
198+
```
199+
200+
Solves the given optimization cache.
201+
202+
See also [`init(prob::OptimizationProblem, alg, args...; kwargs...)`](@ref)
203+
"""
204+
function solve!(cache::SciMLBase.AbstractOptimizationCache)::SciMLBase.AbstractOptimizationSolution
205+
__solve(cache)
206+
end
207+
208+
# needs to be defined for each cache
209+
supports_opt_cache_interface(alg) = false
210+
function __solve(cache::SciMLBase.AbstractOptimizationCache)::SciMLBase.AbstractOptimizationSolution end
71211
function __init(prob::SciMLBase.OptimizationProblem, alg, args...;
72212
kwargs...)::SciMLBase.AbstractOptimizationCache
73213
throw(OptimizerMissingError(alg))
74214
end
75215

216+
# if no cache interface is supported at least the following method has to be defined
76217
function __solve(prob::SciMLBase.OptimizationProblem, alg, args...; kwargs...)
77218
throw(OptimizerMissingError(alg))
78-
end
219+
end
Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,29 @@
11
using OptimizationBase, Test
2+
3+
import OptimizationBase: allowscallback, requiresbounds, requiresconstraints
4+
25
prob = OptimizationProblem((x, p) -> sum(x), zeros(2))
36
@test_throws OptimizationBase.OptimizerMissingError solve(prob, nothing)
47

58
struct OptAlg end
69

7-
SciMLBase.allowscallback(::OptAlg) = false
10+
allowscallback(::OptAlg) = false
811
@test_throws OptimizationBase.IncompatibleOptimizerError solve(prob, OptAlg(),
912
callback = (args...) -> false)
1013

11-
SciMLBase.requiresbounds(::OptAlg) = true
14+
requiresbounds(::OptAlg) = true
1215
@test_throws OptimizationBase.IncompatibleOptimizerError solve(prob, OptAlg())
13-
SciMLBase.requiresbounds(::OptAlg) = false
16+
requiresbounds(::OptAlg) = false
1417

1518
prob = OptimizationProblem((x, p) -> sum(x), zeros(2), lb = [-1.0, -1.0], ub = [1.0, 1.0])
1619
@test_throws OptimizationBase.IncompatibleOptimizerError solve(prob, OptAlg()) #by default allowsbounds is false
1720

1821
cons = (res, x, p) -> (res .= [x[1]^2 + x[2]^2])
19-
optf = OptimizationFunction((x, p) -> sum(x), SciMLBase.NoAD(), cons = cons)
22+
optf = OptimizationFunction((x, p) -> sum(x), NoAD(), cons = cons)
2023
prob = OptimizationProblem(optf, zeros(2))
2124
@test_throws OptimizationBase.IncompatibleOptimizerError solve(prob, OptAlg()) #by default allowsconstraints is false
2225

23-
SciMLBase.requiresconstraints(::OptAlg) = true
24-
optf = OptimizationFunction((x, p) -> sum(x), SciMLBase.NoAD())
26+
requiresconstraints(::OptAlg) = true
27+
optf = OptimizationFunction((x, p) -> sum(x), NoAD())
2528
prob = OptimizationProblem(optf, zeros(2))
2629
@test_throws OptimizationBase.IncompatibleOptimizerError solve(prob, OptAlg())

0 commit comments

Comments
 (0)