@@ -40,17 +40,14 @@ The callback function `callback` is a function which is called after every optim
4040step. Its signature is:
4141
4242```julia
43- callback = (state, loss_val, other_args ) -> false
43+ callback = (state, loss_val) -> false
4444```
4545
4646where `state` is a `OptimizationState` and stores information for the current
4747iteration of the solver and `loss_val` is loss/objective value. For more
4848information about the fields of the `state` look at the `OptimizationState`
49- documentation. The `other_args` can be the extra things returned from the
50- optimization `f`. This allows for saving values from the optimization and
51- using them for plotting and display without recalculating. The callback should
52- return a Boolean value, and the default should be `false`, such that the
53- optimization gets stopped if it returns `true`.
49+ documentation. The callback should return a Boolean value, and the default
50+ should be `false`, such that the optimization gets stopped if it returns `true`.
5451
5552### Callback Example
5653
@@ -67,10 +64,11 @@ function loss(u, p)
6764 sum(abs2, batch .- pred), pred
6865end
6966
70- callback = function (state, l, pred ; doplot = false) #callback function to observe training
67+ callback = function (state, l; doplot = false) #callback function to observe training
7168 display(l)
7269 # plot current prediction against data
7370 if doplot
71+ pred = predict(state.u)
7472 pl = scatter(t, ode_data[1, :], label = "data")
7573 scatter!(pl, t, pred[1, :], label = "prediction")
7674 display(plot(pl))
0 commit comments