Skip to content

Commit c08bed2

Browse files
Merge pull request #1060 from prbzrg/rw-progressbar
Rewrite the progressbar part of `OptimizationOptimisers`
2 parents d201417 + e690813 commit c08bed2

File tree

4 files changed

+66
-70
lines changed

4 files changed

+66
-70
lines changed

Project.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
1212
LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36"
1313
OptimizationBase = "bca83a33-5cc9-4baa-983d-23429ab6bcbb"
1414
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
15-
ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
1615
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1716
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1817
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
@@ -58,7 +57,6 @@ OptimizationOptimisers = "0.3"
5857
OrdinaryDiffEqTsit5 = "1"
5958
Pkg = "1"
6059
Printf = "1.10"
61-
ProgressLogging = "0.1"
6260
Random = "1.10"
6361
Reexport = "1.2"
6462
ReverseDiff = "1"
@@ -109,6 +107,6 @@ Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
109107

110108
[targets]
111109
test = ["Aqua", "BenchmarkTools", "Boltz", "ComponentArrays", "DiffEqFlux", "Enzyme", "FiniteDiff", "Flux", "ForwardDiff",
112-
"Ipopt", "IterTools", "Lux", "MLUtils", "ModelingToolkit", "Optim", "OptimizationLBFGSB", "OptimizationMOI", "OptimizationOptimJL", "OptimizationOptimisers",
110+
"Ipopt", "IterTools", "Lux", "MLUtils", "ModelingToolkit", "Optim", "OptimizationLBFGSB", "OptimizationMOI", "OptimizationOptimJL", "OptimizationOptimisers",
113111
"OrdinaryDiffEqTsit5", "Pkg", "Random", "ReverseDiff", "SafeTestsets", "SciMLSensitivity", "SparseArrays",
114112
"Symbolics", "Test", "Tracker", "Zygote", "Mooncake"]

lib/OptimizationOptimisers/Project.toml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@ name = "OptimizationOptimisers"
22
uuid = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
33
authors = ["Vaibhav Dixit <[email protected]> and contributors"]
44
version = "0.3.13"
5+
56
[deps]
67
OptimizationBase = "bca83a33-5cc9-4baa-983d-23429ab6bcbb"
7-
ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
8-
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
98
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
109
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
1110
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
11+
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
1212

1313
[extras]
1414
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
@@ -19,14 +19,15 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1919
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2020
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
2121
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
22+
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
2223

2324
[compat]
2425
julia = "1.10"
2526
OptimizationBase = "3"
26-
ProgressLogging = "0.1"
2727
SciMLBase = "2.58"
2828
Optimisers = "0.2, 0.3, 0.4"
2929
Reexport = "1.2"
30+
Logging = "1.10"
3031

3132
[targets]
32-
test = ["ComponentArrays", "ForwardDiff", "Lux", "MLDataDevices", "MLUtils", "Random", "Test", "Zygote"]
33+
test = ["ComponentArrays", "ForwardDiff", "Lux", "MLDataDevices", "MLUtils", "Random", "Test", "Zygote", "Printf"]

lib/OptimizationOptimisers/src/OptimizationOptimisers.jl

Lines changed: 59 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module OptimizationOptimisers
22

3-
using Reexport, Printf, ProgressLogging
3+
using Reexport, Logging
44
@reexport using Optimisers, OptimizationBase
55
using SciMLBase
66

@@ -95,77 +95,74 @@ function SciMLBase.__solve(cache::OptimizationBase.OptimizationCache{
9595
gevals = 0
9696
t0 = time()
9797
breakall = false
98-
begin
99-
for epoch in 1:epochs
100-
if breakall
101-
break
98+
progress_id = :OptimizationOptimizersJL
99+
for epoch in 1:epochs, d in data
100+
if cache.f.fg !== nothing && dataiterate
101+
x = cache.f.fg(G, θ, d)
102+
iterations += 1
103+
fevals += 1
104+
gevals += 1
105+
elseif dataiterate
106+
cache.f.grad(G, θ, d)
107+
x = cache.f(θ, d)
108+
iterations += 1
109+
fevals += 2
110+
gevals += 1
111+
elseif cache.f.fg !== nothing
112+
x = cache.f.fg(G, θ)
113+
iterations += 1
114+
fevals += 1
115+
gevals += 1
116+
else
117+
cache.f.grad(G, θ)
118+
x = cache.f(θ)
119+
iterations += 1
120+
fevals += 2
121+
gevals += 1
122+
end
123+
opt_state = OptimizationBase.OptimizationState(
124+
iter = iterations,
125+
u = θ,
126+
p = d,
127+
objective = x[1],
128+
grad = G,
129+
original = state)
130+
breakall = cache.callback(opt_state, x...)
131+
if !(breakall isa Bool)
132+
error("The callback should return a boolean `halt` for whether to stop the optimization process. Please see the `solve` documentation for information.")
133+
elseif breakall
134+
break
135+
end
136+
if cache.progress
137+
message = "Loss: $(round(first(first(x)); digits = 3))"
138+
@logmsg(LogLevel(-1), "Optimization", _id=progress_id,
139+
message=message, progress=iterations / maxiters)
140+
end
141+
if cache.solver_args.save_best
142+
if first(x)[1] < first(min_err)[1] #found a better solution
143+
min_opt = opt
144+
min_err = x
145+
min_θ = copy(θ)
102146
end
103-
for (i, d) in enumerate(data)
104-
if cache.f.fg !== nothing && dataiterate
105-
x = cache.f.fg(G, θ, d)
106-
iterations += 1
107-
fevals += 1
108-
gevals += 1
109-
elseif dataiterate
110-
cache.f.grad(G, θ, d)
111-
x = cache.f(θ, d)
112-
iterations += 1
113-
fevals += 2
114-
gevals += 1
115-
elseif cache.f.fg !== nothing
116-
x = cache.f.fg(G, θ)
117-
iterations += 1
118-
fevals += 1
119-
gevals += 1
120-
else
121-
cache.f.grad(G, θ)
122-
x = cache.f(θ)
123-
iterations += 1
124-
fevals += 2
125-
gevals += 1
126-
end
127-
opt_state = OptimizationBase.OptimizationState(
128-
iter = i + (epoch - 1) * length(data),
147+
if iterations == length(data) * epochs #Last iter, revert to best.
148+
opt = min_opt
149+
x = min_err
150+
θ = min_θ
151+
cache.f.grad(G, θ, d)
152+
opt_state = OptimizationBase.OptimizationState(iter = iterations,
129153
u = θ,
130154
p = d,
131155
objective = x[1],
132156
grad = G,
133157
original = state)
134158
breakall = cache.callback(opt_state, x...)
135-
if !(breakall isa Bool)
136-
error("The callback should return a boolean `halt` for whether to stop the optimization process. Please see the `solve` documentation for information.")
137-
elseif breakall
138-
break
139-
end
140-
msg = @sprintf("loss: %.3g", first(x)[1])
141-
#cache.progress && ProgressLogging.@logprogress msg iterations/maxiters
142-
143-
if cache.solver_args.save_best
144-
if first(x)[1] < first(min_err)[1] #found a better solution
145-
min_opt = opt
146-
min_err = x
147-
min_θ = copy(θ)
148-
end
149-
if iterations == length(data) * epochs #Last iter, revert to best.
150-
opt = min_opt
151-
x = min_err
152-
θ = min_θ
153-
cache.f.grad(G, θ, d)
154-
opt_state = OptimizationBase.OptimizationState(iter = iterations,
155-
u = θ,
156-
p = d,
157-
objective = x[1],
158-
grad = G,
159-
original = state)
160-
breakall = cache.callback(opt_state, x...)
161-
break
162-
end
163-
end
164-
state, θ = Optimisers.update(state, θ, G)
159+
break
165160
end
166161
end
162+
state, θ = Optimisers.update(state, θ, G)
167163
end
168-
164+
cache.progress && @logmsg(LogLevel(-1), "Optimization",
165+
_id=progress_id, message="Done", progress=1.0)
169166
t1 = time()
170167
stats = OptimizationBase.OptimizationStats(; iterations,
171168
time = t1 - t0, fevals, gevals)

src/Optimization.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ if !isdefined(Base, :get_extension)
1111
using Requires
1212
end
1313

14-
using Logging, ProgressLogging, ConsoleProgressMonitor, TerminalLoggers, LoggingExtras
14+
using Logging, ConsoleProgressMonitor, TerminalLoggers, LoggingExtras
1515
using ArrayInterface, Base.Iterators, SparseArrays, LinearAlgebra
1616

1717
import OptimizationBase: instantiate_function, OptimizationCache, ReInitCache

0 commit comments

Comments
 (0)