@@ -29,6 +29,147 @@ function Base.showerror(io::IO, e::OptimizerMissingError)
2929 print (e. alg)
3030end
3131
32+ """
33+ ```julia
34+ solve(prob::OptimizationProblem, alg::AbstractOptimizationAlgorithm,
35+ args...; kwargs...)::OptimizationSolution
36+ ```
37+
38+ For information about the returned solution object, refer to the documentation for [`OptimizationSolution`](@ref)
39+
40+ ## Keyword Arguments
41+
42+ The arguments to `solve` are common across all of the optimizers.
43+ These common arguments are:
44+
45+ - `maxiters`: the maximum number of iterations
46+ - `maxtime`: the maximum amount of time (typically in seconds) the optimization runs for
47+ - `abstol`: absolute tolerance in changes of the objective value
48+ - `reltol`: relative tolerance in changes of the objective value
49+ - `callback`: a callback function
50+
51+ Some optimizer algorithms have special keyword arguments documented in the
52+ solver portion of the documentation and their respective documentation.
53+ These arguments can be passed as `kwargs...` to `solve`. Similarly, the special
54+ keyword arguments for the `local_method` of a global optimizer are passed as a
55+ `NamedTuple` to `local_options`.
56+
57+ Over time, we hope to cover more of these keyword arguments under the common interface.
58+
59+ A warning will be shown if a common argument is not implemented for an optimizer.
60+
61+ ## Callback Functions
62+
63+ The callback function `callback` is a function that is called after every optimizer
64+ step. Its signature is:
65+
66+ ```julia
67+ callback = (state, loss_val) -> false
68+ ```
69+
70+ where `state` is an `OptimizationState` and stores information for the current
71+ iteration of the solver and `loss_val` is loss/objective value. For more
72+ information about the fields of the `state` look at the `OptimizationState`
73+ documentation. The callback should return a Boolean value, and the default
74+ should be `false`, so the optimization stops if it returns `true`.
75+
76+ ### Callback Example
77+
78+ Here we show an example of a callback function that plots the prediction at the current value of the optimization variables.
79+ For a visualization callback, we would need the prediction at the current parameters i.e. the solution of the `ODEProblem` `prob`.
80+ So we call the `predict` function within the callback again.
81+
82+ ```julia
83+ function predict(u)
84+ Array(solve(prob, Tsit5(), p = u))
85+ end
86+
87+ function loss(u, p)
88+ pred = predict(u)
89+ sum(abs2, batch .- pred)
90+ end
91+
92+ callback = function (state, l; doplot = false) #callback function to observe training
93+ display(l)
94+ # plot current prediction against data
95+ if doplot
96+ pred = predict(state.u)
97+ pl = scatter(t, ode_data[1, :], label = "data")
98+ scatter!(pl, t, pred[1, :], label = "prediction")
99+ display(plot(pl))
100+ end
101+ return false
102+ end
103+ ```
104+
105+ If the chosen method is a global optimizer that employs a local optimization
106+ method, a similar set of common local optimizer arguments exists. Look at `MLSL` or `AUGLAG`
107+ from NLopt for an example. The common local optimizer arguments are:
108+
109+ - `local_method`: optimizer used for local optimization in global method
110+ - `local_maxiters`: the maximum number of iterations
111+ - `local_maxtime`: the maximum amount of time (in seconds) the optimization runs for
112+ - `local_abstol`: absolute tolerance in changes of the objective value
113+ - `local_reltol`: relative tolerance in changes of the objective value
114+ - `local_options`: `NamedTuple` of keyword arguments for local optimizer
115+ """
116+ function SciMLBase. solve (prob:: OptimizationProblem , alg, args... ;
117+ kwargs... ):: SciMLBase.AbstractOptimizationSolution
118+ if supports_opt_cache_interface (alg)
119+ SciMLBase. solve! (init (prob, alg, args... ; kwargs... ))
120+ else
121+ if prob. u0 != = nothing && ! isconcretetype (eltype (prob. u0))
122+ throw (NonConcreteEltypeError (eltype (prob. u0)))
123+ end
124+ _check_opt_alg (prob, alg; kwargs... )
125+ SciMLBase. __solve (prob, alg, args... ; kwargs... )
126+ end
127+ end
128+
129+ """
130+ ```julia
131+ init(prob::OptimizationProblem, alg::AbstractOptimizationAlgorithm, args...; kwargs...)
132+ ```
133+
134+ ## Keyword Arguments
135+
136+ The arguments to `init` are the same as to `solve` and common across all of the optimizers.
137+ These common arguments are:
138+
139+ - `maxiters` (the maximum number of iterations)
140+ - `maxtime` (the maximum of time the optimization runs for)
141+ - `abstol` (absolute tolerance in changes of the objective value)
142+ - `reltol` (relative tolerance in changes of the objective value)
143+ - `callback` (a callback function)
144+
145+ Some optimizer algorithms have special keyword arguments documented in the
146+ solver portion of the documentation and their respective documentation.
147+ These arguments can be passed as `kwargs...` to `init`.
148+
149+ See also [`solve(prob::OptimizationProblem, alg, args...; kwargs...)`](@ref)
150+ """
151+ function SciMLBase. init (prob:: OptimizationProblem , alg, args... ; kwargs... ):: SciMLBase.AbstractOptimizationCache
152+ if prob. u0 != = nothing && ! isconcretetype (eltype (prob. u0))
153+ throw (NonConcreteEltypeError (eltype (prob. u0)))
154+ end
155+ _check_opt_alg (prob:: OptimizationProblem , alg; kwargs... )
156+ cache = SciMLBase. __init (prob, alg, args... ; prob. kwargs... , kwargs... )
157+ return cache
158+ end
159+
160+ """
161+ ```julia
162+ solve!(cache::AbstractOptimizationCache)
163+ ```
164+
165+ Solves the given optimization cache.
166+
167+ See also [`init(prob::OptimizationProblem, alg, args...; kwargs...)`](@ref)
168+ """
169+ function SciMLBase. solve! (cache:: SciMLBase.AbstractOptimizationCache ):: SciMLBase.AbstractOptimizationSolution
170+ SciMLBase. __solve (cache)
171+ end
172+
32173# Algorithm compatibility checking function
33174function _check_opt_alg (prob:: SciMLBase.OptimizationProblem , alg; kwargs... )
34175 ! SciMLBase. allowsbounds (alg) && (! isnothing (prob. lb) || ! isnothing (prob. ub)) &&
64205# Base solver dispatch functions (these will be extended by specific solver packages)
65206supports_opt_cache_interface (alg) = false
66207
67- function __solve (cache:: SciMLBase.AbstractOptimizationCache ):: SciMLBase.AbstractOptimizationSolution
208+ function SciMLBase . __solve (cache:: SciMLBase.AbstractOptimizationCache ):: SciMLBase.AbstractOptimizationSolution
68209 throw (OptimizerMissingError (cache. opt))
69210end
70211
71- function __init (prob:: SciMLBase.OptimizationProblem , alg, args... ;
212+ function SciMLBase . __init (prob:: SciMLBase.OptimizationProblem , alg, args... ;
72213 kwargs... ):: SciMLBase.AbstractOptimizationCache
73214 throw (OptimizerMissingError (alg))
74215end
75216
76- function __solve (prob:: SciMLBase.OptimizationProblem , alg, args... ; kwargs... )
217+ function SciMLBase . __solve (prob:: SciMLBase.OptimizationProblem , alg, args... ; kwargs... )
77218 throw (OptimizerMissingError (alg))
78219end
0 commit comments