diff --git a/src/solve.jl b/src/solve.jl index a3a16d61c..cf4b2fd50 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -40,17 +40,14 @@ The callback function `callback` is a function which is called after every optim step. Its signature is: ```julia -callback = (state, loss_val, other_args) -> false +callback = (state, loss_val) -> false ``` where `state` is a `OptimizationState` and stores information for the current iteration of the solver and `loss_val` is loss/objective value. For more information about the fields of the `state` look at the `OptimizationState` -documentation. The `other_args` can be the extra things returned from the -optimization `f`. This allows for saving values from the optimization and -using them for plotting and display without recalculating. The callback should -return a Boolean value, and the default should be `false`, such that the -optimization gets stopped if it returns `true`. +documentation. The callback should return a Boolean value, and the default +should be `false`, such that the optimization gets stopped if it returns `true`. ### Callback Example @@ -67,10 +64,11 @@ function loss(u, p) sum(abs2, batch .- pred), pred end -callback = function (state, l, pred; doplot = false) #callback function to observe training +callback = function (state, l; doplot = false) #callback function to observe training display(l) # plot current prediction against data if doplot + pred = predict(state.u) pl = scatter(t, ode_data[1, :], label = "data") scatter!(pl, t, pred[1, :], label = "prediction") display(plot(pl))