Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,14 @@ The callback function `callback` is a function which is called after every optim
step. Its signature is:

```julia
callback = (state, loss_val, other_args) -> false
callback = (state, loss_val) -> false
```

where `state` is a `OptimizationState` and stores information for the current
iteration of the solver and `loss_val` is loss/objective value. For more
information about the fields of the `state` look at the `OptimizationState`
documentation. The `other_args` can be the extra things returned from the
optimization `f`. This allows for saving values from the optimization and
using them for plotting and display without recalculating. The callback should
return a Boolean value, and the default should be `false`, such that the
optimization gets stopped if it returns `true`.
documentation. The callback should return a Boolean value, and the default
should be `false`, such that the optimization gets stopped if it returns `true`.

### Callback Example

Expand All @@ -67,10 +64,11 @@ function loss(u, p)
sum(abs2, batch .- pred), pred
end

callback = function (state, l, pred; doplot = false) #callback function to observe training
callback = function (state, l; doplot = false) #callback function to observe training
display(l)
# plot current prediction against data
if doplot
pred = predict(state.u)
pl = scatter(t, ode_data[1, :], label = "data")
scatter!(pl, t, pred[1, :], label = "prediction")
display(plot(pl))
Expand Down
Loading