@@ -32,27 +32,28 @@ keyword arguments for the `local_method` of a global optimizer are passed as a
3232
3333Over time, we hope to cover more of these keyword arguments under the common interface.
3434
35- If a common argument is not implemented for a optimizer, a warning will be shown .
35+ A warning will be shown if a common argument is not implemented for an optimizer.
3636
3737## Callback Functions
3838
39- The callback function `callback` is a function which is called after every optimizer
39+ The callback function `callback` is a function that is called after every optimizer
4040step. Its signature is:
4141
4242```julia
4343callback = (state, loss_val) -> false
4444```
4545
46- where `state` is a `OptimizationState` and stores information for the current
46+ where `state` is an `OptimizationState` and stores information for the current
4747iteration of the solver and `loss_val` is loss/objective value. For more
4848information about the fields of the `state` look at the `OptimizationState`
4949documentation. The callback should return a Boolean value, and the default
50- should be `false`, such that the optimization gets stopped if it returns `true`.
50+ should be `false`, so the optimization stops if it returns `true`.
5151
5252### Callback Example
5353
54- Here we show an example a callback function that plots the prediction at the current value of the optimization variables.
55- The loss function here returns the loss and the prediction i.e. the solution of the `ODEProblem` `prob`, so we can use the prediction in the callback.
54+ Here we show an example of a callback function that plots the prediction at the current value of the optimization variables.
55+ For a visualization callback, we would need the prediction at the current parameters i.e. the solution of the `ODEProblem` `prob`.
56+ So we call the `predict` function within the callback again.
5657
5758```julia
5859function predict(u)
6162
6263function loss(u, p)
6364 pred = predict(u)
64- sum(abs2, batch .- pred), pred
65+ sum(abs2, batch .- pred)
6566end
6667
6768callback = function (state, l; doplot = false) #callback function to observe training
0 commit comments