diff --git a/docs/src/optimization_packages/optimization.md b/docs/src/optimization_packages/optimization.md index ddd3bf062..f38ba9a04 100644 --- a/docs/src/optimization_packages/optimization.md +++ b/docs/src/optimization_packages/optimization.md @@ -76,17 +76,18 @@ ps_ca = ComponentArray(ps) smodel = StatefulLuxLayer{true}(model, nothing, st) function callback(state, l) - state.iter % 25 == 1 && @show "Iteration: %5d, Loss: %.6e\n" state.iter l + state.iter % 25 == 1 && @show "Iteration: $(state.iter), Loss: $l" return l < 1e-1 ## Terminate if loss is small end function loss(ps, data) - ypred = [smodel([data[1][i]], ps)[1] for i in eachindex(data[1])] - return sum(abs2, ypred .- data[2]) + x_batch, y_batch = data + ypred = [smodel([x_batch[i]], ps)[1] for i in eachindex(x_batch)] + return sum(abs2, ypred .- y_batch) end optf = OptimizationFunction(loss, AutoZygote()) prob = OptimizationProblem(optf, ps_ca, data) -res = Optimization.solve(prob, Optimization.Sophia(), callback = callback) +res = Optimization.solve(prob, Optimization.Sophia(), callback = callback, epochs = 100) ```