Skip to content
4 changes: 1 addition & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36"
OptimizationBase = "bca83a33-5cc9-4baa-983d-23429ab6bcbb"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Expand Down Expand Up @@ -58,7 +57,6 @@ OptimizationOptimisers = "0.3"
OrdinaryDiffEqTsit5 = "1"
Pkg = "1"
Printf = "1.10"
ProgressLogging = "0.1"
Random = "1.10"
Reexport = "1.2"
ReverseDiff = "1"
Expand Down Expand Up @@ -109,6 +107,6 @@ Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"

[targets]
test = ["Aqua", "BenchmarkTools", "Boltz", "ComponentArrays", "DiffEqFlux", "Enzyme", "FiniteDiff", "Flux", "ForwardDiff",
"Ipopt", "IterTools", "Lux", "MLUtils", "ModelingToolkit", "Optim", "OptimizationLBFGSB", "OptimizationMOI", "OptimizationOptimJL", "OptimizationOptimisers",
"Ipopt", "IterTools", "Lux", "MLUtils", "ModelingToolkit", "Optim", "OptimizationLBFGSB", "OptimizationMOI", "OptimizationOptimJL", "OptimizationOptimisers",
"OrdinaryDiffEqTsit5", "Pkg", "Random", "ReverseDiff", "SafeTestsets", "SciMLSensitivity", "SparseArrays",
"Symbolics", "Test", "Tracker", "Zygote", "Mooncake"]
10 changes: 6 additions & 4 deletions lib/OptimizationOptimisers/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@ name = "OptimizationOptimisers"
uuid = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
authors = ["Vaibhav Dixit <[email protected]> and contributors"]
version = "0.3.13"

[deps]
OptimizationBase = "bca83a33-5cc9-4baa-983d-23429ab6bcbb"
ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"

[extras]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Expand All @@ -19,14 +20,15 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"

[compat]
julia = "1.10"
OptimizationBase = "3"
ProgressLogging = "0.1"
SciMLBase = "2.58"
Optimisers = "0.2, 0.3, 0.4"
Reexport = "1.2"
Logging = "1.10"

[targets]
test = ["ComponentArrays", "ForwardDiff", "Lux", "MLDataDevices", "MLUtils", "Random", "Test", "Zygote"]
test = ["ComponentArrays", "ForwardDiff", "Lux", "MLDataDevices", "MLUtils", "Random", "Test", "Zygote", "Printf"]
121 changes: 59 additions & 62 deletions lib/OptimizationOptimisers/src/OptimizationOptimisers.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module OptimizationOptimisers

using Reexport, Printf, ProgressLogging
using Reexport, UUIDs, Logging
@reexport using Optimisers, OptimizationBase
using SciMLBase

Expand Down Expand Up @@ -95,77 +95,74 @@ function SciMLBase.__solve(cache::OptimizationBase.OptimizationCache{
gevals = 0
t0 = time()
breakall = false
begin
for epoch in 1:epochs
if breakall
break
progress_id = uuid4()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure about using a random UUID here. That as a symbol will intern. What's the reasoning?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

String or symbol? because we've never used a UUID there before, at least a string one. It at least before was a symbol, and we would get memory leaks if it was unique. So check if it's always a string?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for epoch in 1:epochs, d in data
if cache.f.fg !== nothing && dataiterate
x = cache.f.fg(G, θ, d)
iterations += 1
fevals += 1
gevals += 1
elseif dataiterate
cache.f.grad(G, θ, d)
x = cache.f(θ, d)
iterations += 1
fevals += 2
gevals += 1
elseif cache.f.fg !== nothing
x = cache.f.fg(G, θ)
iterations += 1
fevals += 1
gevals += 1
else
cache.f.grad(G, θ)
x = cache.f(θ)
iterations += 1
fevals += 2
gevals += 1
end
opt_state = OptimizationBase.OptimizationState(
iter = iterations,
u = θ,
p = d,
objective = x[1],
grad = G,
original = state)
breakall = cache.callback(opt_state, x...)
if !(breakall isa Bool)
error("The callback should return a boolean `halt` for whether to stop the optimization process. Please see the `solve` documentation for information.")
elseif breakall
break
end
if cache.progress
message = "Loss: $(round(first(first(x)); digits = 3))"
@logmsg(LogLevel(-1), "Optimization", _id=progress_id,
message=message, progress=iterations / maxiters)
end
if cache.solver_args.save_best
if first(x)[1] < first(min_err)[1] #found a better solution
min_opt = opt
min_err = x
min_θ = copy(θ)
end
for (i, d) in enumerate(data)
if cache.f.fg !== nothing && dataiterate
x = cache.f.fg(G, θ, d)
iterations += 1
fevals += 1
gevals += 1
elseif dataiterate
cache.f.grad(G, θ, d)
x = cache.f(θ, d)
iterations += 1
fevals += 2
gevals += 1
elseif cache.f.fg !== nothing
x = cache.f.fg(G, θ)
iterations += 1
fevals += 1
gevals += 1
else
cache.f.grad(G, θ)
x = cache.f(θ)
iterations += 1
fevals += 2
gevals += 1
end
opt_state = OptimizationBase.OptimizationState(
iter = i + (epoch - 1) * length(data),
if iterations == length(data) * epochs #Last iter, revert to best.
opt = min_opt
x = min_err
θ = min_θ
cache.f.grad(G, θ, d)
opt_state = OptimizationBase.OptimizationState(iter = iterations,
u = θ,
p = d,
objective = x[1],
grad = G,
original = state)
breakall = cache.callback(opt_state, x...)
if !(breakall isa Bool)
error("The callback should return a boolean `halt` for whether to stop the optimization process. Please see the `solve` documentation for information.")
elseif breakall
break
end
msg = @sprintf("loss: %.3g", first(x)[1])
#cache.progress && ProgressLogging.@logprogress msg iterations/maxiters

if cache.solver_args.save_best
if first(x)[1] < first(min_err)[1] #found a better solution
min_opt = opt
min_err = x
min_θ = copy(θ)
end
if iterations == length(data) * epochs #Last iter, revert to best.
opt = min_opt
x = min_err
θ = min_θ
cache.f.grad(G, θ, d)
opt_state = OptimizationBase.OptimizationState(iter = iterations,
u = θ,
p = d,
objective = x[1],
grad = G,
original = state)
breakall = cache.callback(opt_state, x...)
break
end
end
state, θ = Optimisers.update(state, θ, G)
break
end
end
state, θ = Optimisers.update(state, θ, G)
end

cache.progress && @logmsg(LogLevel(-1), "Optimization",
_id=progress_id, message="Done", progress=1.0)
t1 = time()
stats = OptimizationBase.OptimizationStats(; iterations,
time = t1 - t0, fevals, gevals)
Expand Down
2 changes: 1 addition & 1 deletion src/Optimization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ if !isdefined(Base, :get_extension)
using Requires
end

using Logging, ProgressLogging, ConsoleProgressMonitor, TerminalLoggers, LoggingExtras
using Logging, ConsoleProgressMonitor, TerminalLoggers, LoggingExtras
using ArrayInterface, Base.Iterators, SparseArrays, LinearAlgebra

import OptimizationBase: instantiate_function, OptimizationCache, ReInitCache
Expand Down
Loading