Skip to content

Commit 80d8465

Browse files
Merge pull request #649 from SciML/Vaibhavdixit02-patch-4
Pass state to OptimizationOptimisers callback
2 parents f9f7dfb + 7f2f357 commit 80d8465

File tree

4 files changed

+23
-7
lines changed

4 files changed

+23
-7
lines changed

lib/OptimizationOptimisers/src/OptimizationOptimisers.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,18 +63,18 @@ function SciMLBase.__solve(cache::OptimizationCache{
6363
Optimization.@withprogress cache.progress name="Training" begin
6464
for (i, d) in enumerate(data)
6565
cache.f.grad(G, θ, d...)
66-
x = cache.f(θ, cache.p, d...)
66+
x = (cache.f(θ, cache.p, d...), state, i)
6767
cb_call = cache.callback(θ, x...)
6868
if !(cb_call isa Bool)
69-
error("The callback should return a boolean `halt` for whether to stop the optimization process. Please see the sciml_train documentation for information.")
69+
error("The callback should return a boolean `halt` for whether to stop the optimization process. Please see the `solve` documentation for information.")
7070
elseif cb_call
7171
break
7272
end
73-
msg = @sprintf("loss: %.3g", x[1])
73+
msg = @sprintf("loss: %.3g", first(x)[1])
7474
cache.progress && ProgressLogging.@logprogress msg i/maxiters
7575

7676
if cache.solver_args.save_best
77-
if first(x) < first(min_err) #found a better solution
77+
if first(x)[1] < first(min_err)[1] #found a better solution
7878
min_opt = opt
7979
min_err = x
8080
min_θ = copy(θ)
@@ -93,7 +93,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
9393

9494
t1 = time()
9595

96-
SciMLBase.build_solution(cache, cache.opt, θ, x[1], solve_time = t1 - t0)
96+
SciMLBase.build_solution(cache, cache.opt, θ, first(x)[1], solve_time = t1 - t0)
9797
# here should be build_solution to create the output message
9898
end
9999

lib/OptimizationOptimisers/test/runtests.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,4 +50,20 @@ using Zygote
5050
sol = Optimization.solve!(cache)
5151
@test sol.u≈[2.0] atol=1e-3
5252
end
53+
54+
@testset "callback" begin
55+
rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
56+
x0 = zeros(2)
57+
_p = [1.0, 100.0]
58+
l1 = rosenbrock(x0, _p)
59+
60+
optprob = OptimizationFunction(rosenbrock, Optimization.AutoZygote())
61+
62+
prob = OptimizationProblem(optprob, x0, _p)
63+
function callback(θ, l, state, iter)
64+
Optimisers.adjust!(state, 0.1/iter)
65+
return false
66+
end
67+
sol = solve(prob, Optimisers.Adam(0.1), maxiters = 1000, progress = false, callback = callback)
68+
end
5369
end

test/diffeqfluxtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ function loss_neuralode(p)
8787
end
8888

8989
iter = 0
90-
callback = function (p, l, pred)
90+
callback = function (p, l, pred, args...)
9191
global iter
9292
iter += 1
9393

test/minibatch.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ function dudt_(u, p, t)
1818
ann(u, p, st)[1] .* u
1919
end
2020

21-
callback = function (p, l, pred; doplot = false) #callback function to observe training
21+
callback = function (p, l, pred, args...; doplot = false) #callback function to observe training
2222
display(l)
2323
# plot current prediction against data
2424
if doplot

0 commit comments

Comments
 (0)