Skip to content

Commit 4131720

Browse files
fix nested breaking and iteration counts
1 parent 47a2481 commit 4131720

File tree

2 files changed

+17
-10
lines changed

2 files changed

+17
-10
lines changed

lib/OptimizationOptimisers/src/OptimizationOptimisers.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,12 @@ function SciMLBase.__solve(cache::OptimizationCache{
7979
fevals = 0
8080
gevals = 0
8181
t0 = time()
82+
breakall = false
8283
Optimization.@withprogress cache.progress name="Training" begin
8384
for epoch in 1:epochs
85+
if breakall
86+
break
87+
end
8488
for (i, d) in enumerate(data)
8589
if cache.f.fg !== nothing && dataiterate
8690
x = cache.f.fg(G, θ, d)
@@ -111,10 +115,10 @@ function SciMLBase.__solve(cache::OptimizationCache{
111115
objective = x[1],
112116
grad = G,
113117
original = state)
114-
cb_call = cache.callback(opt_state, x...)
115-
if !(cb_call isa Bool)
118+
breakall = cache.callback(opt_state, x...)
119+
if !(breakall isa Bool)
116120
error("The callback should return a boolean `halt` for whether to stop the optimization process. Please see the `solve` documentation for information.")
117-
elseif cb_call
121+
elseif breakall
118122
break
119123
end
120124
msg = @sprintf("loss: %.3g", first(x)[1])
@@ -126,7 +130,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
126130
min_err = x
127131
min_θ = copy(θ)
128132
end
129-
if i == length(data) #Last iter, revert to best.
133+
if i == length(data)*epochs #Last iter, revert to best.
130134
opt = min_opt
131135
x = min_err
132136
θ = min_θ
@@ -136,7 +140,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
136140
objective = x[1],
137141
grad = G,
138142
original = state)
139-
cache.callback(opt_state, x...)
143+
breakall = cache.callback(opt_state, x...)
140144
break
141145
end
142146
end

lib/OptimizationOptimisers/test/runtests.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ end
7676
using Optimization, OptimizationOptimisers, Lux, Zygote, MLUtils, Random,
7777
ComponentArrays
7878

79-
x = rand(10000)
79+
x = rand(Float32, 10000)
8080
y = sin.(x)
8181
data = MLUtils.DataLoader((x, y), batchsize = 100)
8282

@@ -99,13 +99,16 @@ end
9999
optf = OptimizationFunction(loss, AutoZygote())
100100
prob = OptimizationProblem(optf, ps_ca, data)
101101

102-
res = Optimization.solve(prob, Optimisers.Adam(), callback = callback, epochs = 10000)
102+
res = Optimization.solve(prob, Optimisers.Adam(), epochs = 50)
103103

104104
@test res.objective < 1e-4
105-
@test res.stats.iterations == 10000*length(data)
106-
@test res.stats.fevals == 10000*length(data)
107-
@test res.stats.gevals == 10000*length(data)
105+
@test res.stats.iterations == 50*length(data)
106+
@test res.stats.fevals == 50*length(data)
107+
@test res.stats.gevals == 50*length(data)
108+
109+
res = Optimization.solve(prob, Optimisers.Adam(), callback = callback, epochs = 100)
108110

111+
@test res.objective < 1e-4
109112

110113
using MLDataDevices
111114
data = CPUDevice()(data)

0 commit comments

Comments
 (0)