Skip to content

Commit b0fc056

Browse files
Update solve.jl callback docstring
1 parent 7c017b1 commit b0fc056

File tree

1 file changed

+5
-7
lines changed

1 file changed

+5
-7
lines changed

src/solve.jl

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,17 +40,14 @@ The callback function `callback` is a function which is called after every optim
4040
step. Its signature is:
4141
4242
```julia
43-
callback = (state, loss_val, other_args) -> false
43+
callback = (state, loss_val) -> false
4444
```
4545
4646
where `state` is a `OptimizationState` and stores information for the current
4747
iteration of the solver and `loss_val` is loss/objective value. For more
4848
information about the fields of the `state` look at the `OptimizationState`
49-
documentation. The `other_args` can be the extra things returned from the
50-
optimization `f`. This allows for saving values from the optimization and
51-
using them for plotting and display without recalculating. The callback should
52-
return a Boolean value, and the default should be `false`, such that the
53-
optimization gets stopped if it returns `true`.
49+
documentation. The callback should return a Boolean value, and the default
50+
should be `false`, such that the optimization gets stopped if it returns `true`.
5451
5552
### Callback Example
5653
@@ -67,10 +64,11 @@ function loss(u, p)
6764
sum(abs2, batch .- pred), pred
6865
end
6966
70-
callback = function (state, l, pred; doplot = false) #callback function to observe training
67+
callback = function (state, l; doplot = false) #callback function to observe training
7168
display(l)
7269
# plot current prediction against data
7370
if doplot
71+
pred = predict(state.u)
7472
pl = scatter(t, ode_data[1, :], label = "data")
7573
scatter!(pl, t, pred[1, :], label = "prediction")
7674
display(plot(pl))

0 commit comments

Comments
 (0)