Skip to content

Commit bf33d1d

Browse files
Merge pull request #1049 from jClugstor/solve_dispatches
Add solve, solve!, init
2 parents cfae41b + b081fb3 commit bf33d1d

File tree

1 file changed

+141
-0
lines changed

1 file changed

+141
-0
lines changed

lib/OptimizationBase/src/solve.jl

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,147 @@ function Base.showerror(io::IO, e::OptimizerMissingError)
2929
print(e.alg)
3030
end
3131

32+
"""
33+
```julia
34+
solve(prob::OptimizationProblem, alg::AbstractOptimizationAlgorithm,
35+
args...; kwargs...)::OptimizationSolution
36+
```
37+
38+
For information about the returned solution object, refer to the documentation for [`OptimizationSolution`](@ref)
39+
40+
## Keyword Arguments
41+
42+
The arguments to `solve` are common across all of the optimizers.
43+
These common arguments are:
44+
45+
- `maxiters`: the maximum number of iterations
46+
- `maxtime`: the maximum amount of time (typically in seconds) the optimization runs for
47+
- `abstol`: absolute tolerance in changes of the objective value
48+
- `reltol`: relative tolerance in changes of the objective value
49+
- `callback`: a callback function
50+
51+
Some optimizer algorithms have special keyword arguments documented in the
52+
solver portion of the documentation and their respective documentation.
53+
These arguments can be passed as `kwargs...` to `solve`. Similarly, the special
54+
keyword arguments for the `local_method` of a global optimizer are passed as a
55+
`NamedTuple` to `local_options`.
56+
57+
Over time, we hope to cover more of these keyword arguments under the common interface.
58+
59+
A warning will be shown if a common argument is not implemented for an optimizer.
60+
61+
## Callback Functions
62+
63+
The callback function `callback` is a function that is called after every optimizer
64+
step. Its signature is:
65+
66+
```julia
67+
callback = (state, loss_val) -> false
68+
```
69+
70+
where `state` is an `OptimizationState` and stores information for the current
71+
iteration of the solver and `loss_val` is loss/objective value. For more
72+
information about the fields of the `state` look at the `OptimizationState`
73+
documentation. The callback should return a Boolean value, and the default
74+
should be `false`, so the optimization stops if it returns `true`.
75+
76+
### Callback Example
77+
78+
Here we show an example of a callback function that plots the prediction at the current value of the optimization variables.
79+
For a visualization callback, we would need the prediction at the current parameters i.e. the solution of the `ODEProblem` `prob`.
80+
So we call the `predict` function within the callback again.
81+
82+
```julia
83+
function predict(u)
84+
Array(solve(prob, Tsit5(), p = u))
85+
end
86+
87+
function loss(u, p)
88+
pred = predict(u)
89+
sum(abs2, batch .- pred)
90+
end
91+
92+
callback = function (state, l; doplot = false) #callback function to observe training
93+
display(l)
94+
# plot current prediction against data
95+
if doplot
96+
pred = predict(state.u)
97+
pl = scatter(t, ode_data[1, :], label = "data")
98+
scatter!(pl, t, pred[1, :], label = "prediction")
99+
display(plot(pl))
100+
end
101+
return false
102+
end
103+
```
104+
105+
If the chosen method is a global optimizer that employs a local optimization
106+
method, a similar set of common local optimizer arguments exists. Look at `MLSL` or `AUGLAG`
107+
from NLopt for an example. The common local optimizer arguments are:
108+
109+
- `local_method`: optimizer used for local optimization in global method
110+
- `local_maxiters`: the maximum number of iterations
111+
- `local_maxtime`: the maximum amount of time (in seconds) the optimization runs for
112+
- `local_abstol`: absolute tolerance in changes of the objective value
113+
- `local_reltol`: relative tolerance in changes of the objective value
114+
- `local_options`: `NamedTuple` of keyword arguments for local optimizer
115+
"""
116+
function solve(prob::OptimizationProblem, alg, args...;
117+
kwargs...)::AbstractOptimizationSolution
118+
if supports_opt_cache_interface(alg)
119+
solve!(init(prob, alg, args...; kwargs...))
120+
else
121+
if prob.u0 !== nothing && !isconcretetype(eltype(prob.u0))
122+
throw(NonConcreteEltypeError(eltype(prob.u0)))
123+
end
124+
_check_opt_alg(prob, alg; kwargs...)
125+
__solve(prob, alg, args...; kwargs...)
126+
end
127+
end
128+
129+
"""
130+
```julia
131+
init(prob::OptimizationProblem, alg::AbstractOptimizationAlgorithm, args...; kwargs...)
132+
```
133+
134+
## Keyword Arguments
135+
136+
The arguments to `init` are the same as to `solve` and common across all of the optimizers.
137+
These common arguments are:
138+
139+
- `maxiters` (the maximum number of iterations)
140+
- `maxtime` (the maximum of time the optimization runs for)
141+
- `abstol` (absolute tolerance in changes of the objective value)
142+
- `reltol` (relative tolerance in changes of the objective value)
143+
- `callback` (a callback function)
144+
145+
Some optimizer algorithms have special keyword arguments documented in the
146+
solver portion of the documentation and their respective documentation.
147+
These arguments can be passed as `kwargs...` to `init`.
148+
149+
See also [`solve(prob::OptimizationProblem, alg, args...; kwargs...)`](@ref)
150+
"""
151+
function init(prob::OptimizationProblem, alg, args...; kwargs...)::AbstractOptimizationCache
152+
if prob.u0 !== nothing && !isconcretetype(eltype(prob.u0))
153+
throw(NonConcreteEltypeError(eltype(prob.u0)))
154+
end
155+
_check_opt_alg(prob::OptimizationProblem, alg; kwargs...)
156+
cache = __init(prob, alg, args...; prob.kwargs..., kwargs...)
157+
return cache
158+
end
159+
160+
"""
161+
```julia
162+
solve!(cache::AbstractOptimizationCache)
163+
```
164+
165+
Solves the given optimization cache.
166+
167+
See also [`init(prob::OptimizationProblem, alg, args...; kwargs...)`](@ref)
168+
"""
169+
function solve!(cache::AbstractOptimizationCache)::AbstractOptimizationSolution
170+
__solve(cache)
171+
end
172+
32173
# Algorithm compatibility checking function
33174
function _check_opt_alg(prob::SciMLBase.OptimizationProblem, alg; kwargs...)
34175
!SciMLBase.allowsbounds(alg) && (!isnothing(prob.lb) || !isnothing(prob.ub)) &&

0 commit comments

Comments
 (0)