Skip to content

Commit 89e2d0c

Browse files
committed
Rewrite the progressbar part of OptimizationOptimisers
1 parent d32d3f9 commit 89e2d0c

File tree

2 files changed

+60
-63
lines changed

2 files changed

+60
-63
lines changed

lib/OptimizationOptimisers/Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@ version = "0.3.12"
55
[deps]
66
OptimizationBase = "bca83a33-5cc9-4baa-983d-23429ab6bcbb"
77
ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
8-
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
98
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
109
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
1110
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
11+
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
1212

1313
[extras]
1414
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
@@ -19,6 +19,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1919
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2020
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
2121
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
22+
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
2223

2324
[compat]
2425
julia = "1.10"
@@ -29,4 +30,4 @@ Optimisers = "0.2, 0.3, 0.4"
2930
Reexport = "1.2"
3031

3132
[targets]
32-
test = ["ComponentArrays", "ForwardDiff", "Lux", "MLDataDevices", "MLUtils", "Random", "Test", "Zygote"]
33+
test = ["ComponentArrays", "ForwardDiff", "Lux", "MLDataDevices", "MLUtils", "Random", "Test", "Zygote", "Printf"]

lib/OptimizationOptimisers/src/OptimizationOptimisers.jl

Lines changed: 57 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module OptimizationOptimisers
22

3-
using Reexport, Printf, ProgressLogging
3+
using Reexport, ProgressLogging, UUIDs
44
@reexport using Optimisers, OptimizationBase
55
using SciMLBase
66

@@ -95,77 +95,73 @@ function SciMLBase.__solve(cache::OptimizationBase.OptimizationCache{
9595
gevals = 0
9696
t0 = time()
9797
breakall = false
98-
begin
99-
for epoch in 1:epochs
100-
if breakall
101-
break
98+
progress_id = uuid4()
99+
for epoch in 1:epochs, d in data
100+
if cache.f.fg !== nothing && dataiterate
101+
x = cache.f.fg(G, θ, d)
102+
iterations += 1
103+
fevals += 1
104+
gevals += 1
105+
elseif dataiterate
106+
cache.f.grad(G, θ, d)
107+
x = cache.f(θ, d)
108+
iterations += 1
109+
fevals += 2
110+
gevals += 1
111+
elseif cache.f.fg !== nothing
112+
x = cache.f.fg(G, θ)
113+
iterations += 1
114+
fevals += 1
115+
gevals += 1
116+
else
117+
cache.f.grad(G, θ)
118+
x = cache.f(θ)
119+
iterations += 1
120+
fevals += 2
121+
gevals += 1
122+
end
123+
opt_state = OptimizationBase.OptimizationState(
124+
iter = iterations,
125+
u = θ,
126+
p = d,
127+
objective = x[1],
128+
grad = G,
129+
original = state)
130+
breakall = cache.callback(opt_state, x...)
131+
if !(breakall isa Bool)
132+
error("The callback should return a boolean `halt` for whether to stop the optimization process. Please see the `solve` documentation for information.")
133+
elseif breakall
134+
break
135+
end
136+
cache.progress &&
137+
@info ProgressLogging.Progress(progress_id, iterations / maxiters;
138+
name = "loss: $(round(first(first(x)); digits=3))")
139+
140+
if cache.solver_args.save_best
141+
if first(x)[1] < first(min_err)[1] #found a better solution
142+
min_opt = opt
143+
min_err = x
144+
min_θ = copy(θ)
102145
end
103-
for (i, d) in enumerate(data)
104-
if cache.f.fg !== nothing && dataiterate
105-
x = cache.f.fg(G, θ, d)
106-
iterations += 1
107-
fevals += 1
108-
gevals += 1
109-
elseif dataiterate
110-
cache.f.grad(G, θ, d)
111-
x = cache.f(θ, d)
112-
iterations += 1
113-
fevals += 2
114-
gevals += 1
115-
elseif cache.f.fg !== nothing
116-
x = cache.f.fg(G, θ)
117-
iterations += 1
118-
fevals += 1
119-
gevals += 1
120-
else
121-
cache.f.grad(G, θ)
122-
x = cache.f(θ)
123-
iterations += 1
124-
fevals += 2
125-
gevals += 1
126-
end
127-
opt_state = OptimizationBase.OptimizationState(
128-
iter = i + (epoch - 1) * length(data),
146+
if iterations == length(data) * epochs #Last iter, revert to best.
147+
opt = min_opt
148+
x = min_err
149+
θ = min_θ
150+
cache.f.grad(G, θ, d)
151+
opt_state = OptimizationBase.OptimizationState(iter = iterations,
129152
u = θ,
130153
p = d,
131154
objective = x[1],
132155
grad = G,
133156
original = state)
134157
breakall = cache.callback(opt_state, x...)
135-
if !(breakall isa Bool)
136-
error("The callback should return a boolean `halt` for whether to stop the optimization process. Please see the `solve` documentation for information.")
137-
elseif breakall
138-
break
139-
end
140-
msg = @sprintf("loss: %.3g", first(x)[1])
141-
cache.progress && ProgressLogging.@logprogress msg iterations/maxiters
142-
143-
if cache.solver_args.save_best
144-
if first(x)[1] < first(min_err)[1] #found a better solution
145-
min_opt = opt
146-
min_err = x
147-
min_θ = copy(θ)
148-
end
149-
if iterations == length(data) * epochs #Last iter, revert to best.
150-
opt = min_opt
151-
x = min_err
152-
θ = min_θ
153-
cache.f.grad(G, θ, d)
154-
opt_state = OptimizationBase.OptimizationState(iter = iterations,
155-
u = θ,
156-
p = d,
157-
objective = x[1],
158-
grad = G,
159-
original = state)
160-
breakall = cache.callback(opt_state, x...)
161-
break
162-
end
163-
end
164-
state, θ = Optimisers.update(state, θ, G)
158+
break
165159
end
166160
end
161+
state, θ = Optimisers.update(state, θ, G)
167162
end
168163

164+
cache.progress && @info ProgressLogging.Progress(progress_id; done = true)
169165
t1 = time()
170166
stats = OptimizationBase.OptimizationStats(; iterations,
171167
time = t1 - t0, fevals, gevals)

0 commit comments

Comments
 (0)