Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions lib/OptimizationBase/src/OptimizationBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,19 @@ using Reexport
@reexport using SciMLBase, ADTypes

using ArrayInterface, Base.Iterators, SparseArrays, LinearAlgebra
import SciMLBase: OptimizationProblem,
import SciMLBase: solve, init, solve!, __init, __solve,
OptimizationProblem,
OptimizationFunction, ObjSense,
MaxSense, MinSense, OptimizationStats
MaxSense, MinSense, OptimizationStats,
allowsbounds, requiresbounds,
allowsconstraints, requiresconstraints,
allowscallback, requiresgradient,
requireshessian, requiresconsjac,
requiresconshess, supports_opt_cache_interface
export ObjSense, MaxSense, MinSense
export allowsbounds, requiresbounds, allowsconstraints, requiresconstraints,
allowscallback, requiresgradient, requireshessian,
requiresconsjac, requiresconshess, supports_opt_cache_interface

using FastClosures

Expand All @@ -24,15 +33,14 @@ Base.length(::NullData) = 0
include("adtypes.jl")
include("symify.jl")
include("cache.jl")
include("solve.jl")
include("OptimizationDIExt.jl")
include("OptimizationDISparseExt.jl")
include("function.jl")
include("solve.jl")
include("utils.jl")
include("state.jl")

export solve, OptimizationCache, DEFAULT_CALLBACK, DEFAULT_DATA,
IncompatibleOptimizerError, OptimizerMissingError, _check_opt_alg,
supports_opt_cache_interface
export solve, OptimizationCache, DEFAULT_CALLBACK, DEFAULT_DATA
export IncompatibleOptimizerError, OptimizerMissingError

end
207 changes: 174 additions & 33 deletions lib/OptimizationBase/src/solve.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
# This file contains the top level solve interface functionality moved from SciMLBase.jl
# These functions provide the core optimization solving interface

"""
    IncompatibleOptimizerError(err::String)

Exception raised when an optimizer cannot handle the features of the problem it was
given (or the problem lacks something the optimizer needs). `err` holds the
human-readable explanation that is printed when the error is shown.
"""
struct IncompatibleOptimizerError <: Exception
err::String
end
Expand All @@ -9,70 +6,214 @@ function Base.showerror(io::IO, e::IncompatibleOptimizerError)
print(io, e.err)
end

const OPTIMIZER_MISSING_ERROR_MESSAGE = """
Optimization algorithm not found. Either the chosen algorithm is not a valid solver
choice for the `OptimizationProblem`, or the Optimization solver library is not loaded.
Make sure that you have loaded an appropriate OptimizationBase.jl solver library, for example,
`solve(prob,Optim.BFGS())` requires `using OptimizationOptimJL` and
`solve(prob,Adam())` requires `using OptimizationOptimisers`.
"""
```julia
solve(prob::OptimizationProblem, alg::AbstractOptimizationAlgorithm,
args...; kwargs...)::OptimizationSolution
```

For more information, see the OptimizationBase.jl documentation: <https://docs.sciml.ai/Optimization/stable/>.
"""
For information about the returned solution object, refer to the documentation for [`OptimizationSolution`](@ref)

struct OptimizerMissingError <: Exception
alg::Any
## Keyword Arguments

The arguments to `solve` are common across all of the optimizers.
These common arguments are:

- `maxiters`: the maximum number of iterations
- `maxtime`: the maximum amount of time (typically in seconds) the optimization runs for
- `abstol`: absolute tolerance in changes of the objective value
- `reltol`: relative tolerance in changes of the objective value
- `callback`: a callback function

Some optimizer algorithms have special keyword arguments documented in the
solver portion of the documentation and their respective documentation.
These arguments can be passed as `kwargs...` to `solve`. Similarly, the special
keyword arguments for the `local_method` of a global optimizer are passed as a
`NamedTuple` to `local_options`.

Over time, we hope to cover more of these keyword arguments under the common interface.

A warning will be shown if a common argument is not implemented for an optimizer.

## Callback Functions

The callback function `callback` is a function that is called after every optimizer
step. Its signature is:

```julia
callback = (state, loss_val) -> false
```

where `state` is an `OptimizationState` that stores information for the current
iteration of the solver and `loss_val` is the loss/objective value. For more
information about the fields of `state`, see the `OptimizationState`
documentation. The callback should return a Boolean value: returning `true` stops
the optimization, so a callback should return `false` by default.

### Callback Example

Here we show an example of a callback function that plots the prediction at the current value of the optimization variables.
For a visualization callback, we would need the prediction at the current parameters i.e. the solution of the `ODEProblem` `prob`.
So we call the `predict` function within the callback again.

```julia
function predict(u)
Array(solve(prob, Tsit5(), p = u))
end

function Base.showerror(io::IO, e::OptimizerMissingError)
println(io, OPTIMIZER_MISSING_ERROR_MESSAGE)
print(io, "Chosen Optimizer: ")
print(e.alg)
function loss(u, p)
pred = predict(u)
sum(abs2, batch .- pred)
end

callback = function (state, l; doplot = false) #callback function to observe training
display(l)
# plot current prediction against data
if doplot
pred = predict(state.u)
pl = scatter(t, ode_data[1, :], label = "data")
scatter!(pl, t, pred[1, :], label = "prediction")
display(plot(pl))
end
return false
end
```

If the chosen method is a global optimizer that employs a local optimization
method, a similar set of common local optimizer arguments exists. Look at `MLSL` or `AUGLAG`
from NLopt for an example. The common local optimizer arguments are:

- `local_method`: optimizer used for local optimization in global method
- `local_maxiters`: the maximum number of iterations
- `local_maxtime`: the maximum amount of time (in seconds) the optimization runs for
- `local_abstol`: absolute tolerance in changes of the objective value
- `local_reltol`: relative tolerance in changes of the objective value
- `local_options`: `NamedTuple` of keyword arguments for local optimizer
"""
function solve(prob::SciMLBase.OptimizationProblem, alg, args...;
        kwargs...)::SciMLBase.AbstractOptimizationSolution
    # Algorithms that implement the cache interface are routed through `init`/`solve!`;
    # `init` runs the same input validation that is performed below.
    supports_opt_cache_interface(alg) &&
        return solve!(init(prob, alg, args...; kwargs...))

    # Direct path: validate the initial guess and the algorithm, then dispatch.
    u0 = prob.u0
    if !(u0 === nothing || isconcretetype(eltype(u0)))
        throw(SciMLBase.NonConcreteEltypeError(eltype(u0)))
    end
    _check_opt_alg(prob, alg; kwargs...)
    return __solve(prob, alg, args...; kwargs...)
end

function solve(prob::SciMLBase.EnsembleProblem{T}, args...;
        kwargs...) where {T <: SciMLBase.OptimizationProblem}
    # Ensembles of optimization problems are forwarded unchanged to `__solve`;
    # no algorithm compatibility check is performed at this level.
    ensemble_solution = __solve(prob, args...; kwargs...)
    return ensemble_solution
end

# Algorithm compatibility checking function
"""
    _check_opt_alg(prob::SciMLBase.OptimizationProblem, alg; kwargs...)

Validate that `alg` can be used to solve `prob` with the given `solve` keyword
arguments. Returns `nothing` when every check passes.

The algorithm traits (`allowsbounds`, `requiresbounds`, `allowsconstraints`,
`requiresconstraints`, `allowscallback`, `requiresgradient`, `requireshessian`,
`requiresconsjac`, `requiresconshess`) are compared against what the problem provides.

# Throws
- `IncompatibleOptimizerError` if the problem uses a feature `alg` does not support, or
  lacks one that `alg` requires.
- `ArgumentError` if a constraint function is present and supported, but `lcons` or
  `ucons` is missing from the problem.
"""
function _check_opt_alg(prob::SciMLBase.OptimizationProblem, alg; kwargs...)
    # Each trait is queried exactly once per check; the short-circuit order is preserved
    # so problem fields are only inspected when the trait makes them relevant.
    !allowsbounds(alg) && (!isnothing(prob.lb) || !isnothing(prob.ub)) &&
        throw(IncompatibleOptimizerError("The algorithm $(typeof(alg)) does not support box constraints. Either remove the `lb` or `ub` bounds passed to `OptimizationProblem` or use a different algorithm."))
    requiresbounds(alg) && isnothing(prob.lb) &&
        throw(IncompatibleOptimizerError("The algorithm $(typeof(alg)) requires box constraints. Either pass `lb` and `ub` bounds to `OptimizationProblem` or use a different algorithm."))
    !allowsconstraints(alg) && !isnothing(prob.f.cons) &&
        throw(IncompatibleOptimizerError("The algorithm $(typeof(alg)) does not support constraints. Either remove the `cons` function passed to `OptimizationFunction` or use a different algorithm."))
    requiresconstraints(alg) && isnothing(prob.f.cons) &&
        throw(IncompatibleOptimizerError("The algorithm $(typeof(alg)) requires constraints, pass them with the `cons` kwarg in `OptimizationFunction`."))
    # Check that if constraints are present and the algorithm supports constraints, both lcons and ucons are provided
    allowsconstraints(alg) && !isnothing(prob.f.cons) &&
        (isnothing(prob.lcons) || isnothing(prob.ucons)) &&
        throw(ArgumentError("Constrained optimization problem requires both `lcons` and `ucons` to be provided to OptimizationProblem. " *
                            "Example: OptimizationProblem(optf, u0, p; lcons=[-Inf], ucons=[0.0])"))
    !allowscallback(alg) && haskey(kwargs, :callback) &&
        throw(IncompatibleOptimizerError("The algorithm $(typeof(alg)) does not support callbacks, remove the `callback` keyword argument from the `solve` call."))
    requiresgradient(alg) &&
        !(prob.f isa SciMLBase.AbstractOptimizationFunction) &&
        throw(IncompatibleOptimizerError("The algorithm $(typeof(alg)) requires gradients, hence use `OptimizationFunction` to generate them with an automatic differentiation backend e.g. `OptimizationFunction(f, AutoForwardDiff())` or pass it in with `grad` kwarg."))
    requireshessian(alg) &&
        !(prob.f isa SciMLBase.AbstractOptimizationFunction) &&
        throw(IncompatibleOptimizerError("The algorithm $(typeof(alg)) requires hessians, hence use `OptimizationFunction` to generate them with an automatic differentiation backend e.g. `OptimizationFunction(f, AutoFiniteDiff(); kwargs...)` or pass them in with `hess` kwarg."))
    # The user-supplied constraint Jacobian/Hessian keywords are `cons_j`/`cons_h`;
    # `cons` is the constraint function itself.
    requiresconsjac(alg) &&
        !(prob.f isa SciMLBase.AbstractOptimizationFunction) &&
        throw(IncompatibleOptimizerError("The algorithm $(typeof(alg)) requires constraint jacobians, hence use `OptimizationFunction` to generate them with an automatic differentiation backend e.g. `OptimizationFunction(f, AutoFiniteDiff(); kwargs...)` or pass them in with `cons_j` kwarg."))
    requiresconshess(alg) &&
        !(prob.f isa SciMLBase.AbstractOptimizationFunction) &&
        throw(IncompatibleOptimizerError("The algorithm $(typeof(alg)) requires constraint hessians, hence use `OptimizationFunction` to generate them with an automatic differentiation backend e.g. `OptimizationFunction(f, AutoFiniteDiff(), AutoFiniteDiff(hess=true); kwargs...)` or pass them in with `cons_h` kwarg."))
    return nothing
end

# Base solver dispatch functions (these will be extended by specific solver packages)
supports_opt_cache_interface(alg) = false
const OPTIMIZER_MISSING_ERROR_MESSAGE = """
Optimization algorithm not found. Either the chosen algorithm is not a valid solver
choice for the `OptimizationProblem`, or the Optimization solver library is not loaded.
Make sure that you have loaded an appropriate Optimization.jl solver library, for example,
`solve(prob,Optim.BFGS())` requires `using OptimizationOptimJL` and
`solve(prob,Adam())` requires `using OptimizationOptimisers`.

For more information, see the Optimization.jl documentation: <https://docs.sciml.ai/Optimization/stable/>.
"""

function __solve(cache::SciMLBase.AbstractOptimizationCache)::SciMLBase.AbstractOptimizationSolution
throw(OptimizerMissingError(cache.opt))
"""
    OptimizerMissingError(alg)

Exception raised by the fallback `__init`/`__solve` methods when no solver-specific
method exists for `alg`. `alg` is the optimizer object that was requested; it is left
untyped (`Any`) so that arbitrary values, including `nothing`, can be reported.
"""
struct OptimizerMissingError <: Exception
alg::Any
end

# Render an `OptimizerMissingError`: the generic "solver library not loaded" help text
# followed by the optimizer that was requested.
function Base.showerror(io::IO, e::OptimizerMissingError)
    println(io, OPTIMIZER_MISSING_ERROR_MESSAGE)
    print(io, "Chosen Optimizer: ")
    # Write to `io` rather than stdout so the optimizer appears wherever the error is
    # rendered (e.g. `sprint(showerror, e)`, a log sink, or a test harness buffer).
    print(io, e.alg)
end

"""
```julia
init(prob::OptimizationProblem, alg::AbstractOptimizationAlgorithm, args...; kwargs...)
```

## Keyword Arguments

The arguments to `init` are the same as to `solve` and common across all of the optimizers.
These common arguments are:

- `maxiters` (the maximum number of iterations)
- `maxtime` (the maximum amount of time the optimization runs for)
- `abstol` (absolute tolerance in changes of the objective value)
- `reltol` (relative tolerance in changes of the objective value)
- `callback` (a callback function)

Some optimizer algorithms have special keyword arguments documented in the
solver portion of the documentation and their respective documentation.
These arguments can be passed as `kwargs...` to `init`.

See also [`solve(prob::OptimizationProblem, alg, args...; kwargs...)`](@ref)
"""
function init(prob::SciMLBase.OptimizationProblem, alg, args...;
        kwargs...)::SciMLBase.AbstractOptimizationCache
    # Reject initial guesses whose element type is not concrete.
    u0 = prob.u0
    if !(u0 === nothing || isconcretetype(eltype(u0)))
        throw(SciMLBase.NonConcreteEltypeError(eltype(u0)))
    end
    _check_opt_alg(prob, alg; kwargs...)
    # Problem-level kwargs are splatted first, so a keyword given at the call site
    # overrides the same keyword stored on the problem.
    return __init(prob, alg, args...; prob.kwargs..., kwargs...)
end

"""
```julia
solve!(cache::AbstractOptimizationCache)
```

Solves the given optimization cache.

See also [`init(prob::OptimizationProblem, alg, args...; kwargs...)`](@ref)
"""
function solve!(cache::SciMLBase.AbstractOptimizationCache)::SciMLBase.AbstractOptimizationSolution
    # Thin public entry point: the actual work is done by the cache-specific `__solve`.
    solution = __solve(cache)
    return solution
end

# needs to be defined for each cache
supports_opt_cache_interface(alg) = false

# Fallback for caches whose solver package has not defined `__solve`.
# An empty body would return `nothing`, which cannot be converted to the declared
# `AbstractOptimizationSolution` return type and surfaces as an opaque conversion
# error, so raise the dedicated missing-optimizer error instead.
function __solve(cache::SciMLBase.AbstractOptimizationCache)::SciMLBase.AbstractOptimizationSolution
    # NOTE(review): assumes caches expose their algorithm as `opt`; the cache type is
    # reported when that property is absent.
    alg = hasproperty(cache, :opt) ? cache.opt : typeof(cache)
    throw(OptimizerMissingError(alg))
end
function __init(prob::SciMLBase.OptimizationProblem, alg, args...;
        kwargs...)::SciMLBase.AbstractOptimizationCache
    # Generic fallback: reaching this method means no more specific `__init` matched
    # `alg`, so report the optimizer as missing.
    missing_alg_error = OptimizerMissingError(alg)
    throw(missing_alg_error)
end

# if no cache interface is supported at least the following method has to be defined
function __solve(prob::SciMLBase.OptimizationProblem, alg, args...; kwargs...)
    # Generic fallback for algorithms without the cache interface: no more specific
    # `__solve` matched `alg`, so report the optimizer as missing.
    missing_alg_error = OptimizerMissingError(alg)
    throw(missing_alg_error)
end
end
15 changes: 9 additions & 6 deletions lib/OptimizationBase/test/solver_missing_error_messages.jl
Original file line number Diff line number Diff line change
@@ -1,26 +1,29 @@
using OptimizationBase, Test

import OptimizationBase: allowscallback, requiresbounds, requiresconstraints

prob = OptimizationProblem((x, p) -> sum(x), zeros(2))
@test_throws OptimizationBase.OptimizerMissingError solve(prob, nothing)

struct OptAlg end

SciMLBase.allowscallback(::OptAlg) = false
allowscallback(::OptAlg) = false
@test_throws OptimizationBase.IncompatibleOptimizerError solve(prob, OptAlg(),
callback = (args...) -> false)

SciMLBase.requiresbounds(::OptAlg) = true
requiresbounds(::OptAlg) = true
@test_throws OptimizationBase.IncompatibleOptimizerError solve(prob, OptAlg())
SciMLBase.requiresbounds(::OptAlg) = false
requiresbounds(::OptAlg) = false

prob = OptimizationProblem((x, p) -> sum(x), zeros(2), lb = [-1.0, -1.0], ub = [1.0, 1.0])
@test_throws OptimizationBase.IncompatibleOptimizerError solve(prob, OptAlg()) #by default allowsbounds is false

cons = (res, x, p) -> (res .= [x[1]^2 + x[2]^2])
optf = OptimizationFunction((x, p) -> sum(x), SciMLBase.NoAD(), cons = cons)
optf = OptimizationFunction((x, p) -> sum(x), NoAD(), cons = cons)
prob = OptimizationProblem(optf, zeros(2))
@test_throws OptimizationBase.IncompatibleOptimizerError solve(prob, OptAlg()) #by default allowsconstraints is false

SciMLBase.requiresconstraints(::OptAlg) = true
optf = OptimizationFunction((x, p) -> sum(x), SciMLBase.NoAD())
requiresconstraints(::OptAlg) = true
optf = OptimizationFunction((x, p) -> sum(x), NoAD())
prob = OptimizationProblem(optf, zeros(2))
@test_throws OptimizationBase.IncompatibleOptimizerError solve(prob, OptAlg())
Loading