Skip to content

Commit 47a2481

Browse files
MOI vector lambda and iteration fixes in Optimisers
1 parent c526d71 commit 47a2481

File tree

4 files changed

+43
-22
lines changed

4 files changed

+43
-22
lines changed

lib/OptimizationMOI/src/nlp.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,7 @@ function MOI.eval_hessian_lagrangian(evaluator::MOIOptimizationNLPEvaluator{T},
375375
σ,
376376
μ) where {T}
377377
if evaluator.f.lag_h !== nothing
378-
evaluator.f.lag_h(h, x, σ, μ)
378+
evaluator.f.lag_h(h, x, σ, Vector(μ))
379379
return
380380
end
381381
if evaluator.f.hess === nothing

lib/OptimizationOptimisers/src/OptimizationOptimisers.jl

Lines changed: 33 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -42,27 +42,27 @@ function SciMLBase.__solve(cache::OptimizationCache{
4242
P,
4343
C
4444
}
45-
maxiters = if cache.solver_args.epochs === nothing
45+
if OptimizationBase.isa_dataiterator(cache.p)
46+
data = cache.p
47+
dataiterate = true
48+
else
49+
data = [cache.p]
50+
dataiterate = false
51+
end
52+
53+
epochs = if cache.solver_args.epochs === nothing
4654
if cache.solver_args.maxiters === nothing
47-
throw(ArgumentError("The number of epochs must be specified with either the epochs or maxiters kwarg."))
55+
throw(ArgumentError("The number of iterations must be specified with either the epochs or maxiters kwarg. Where maxiters = epochs*length(data)."))
4856
else
49-
cache.solver_args.maxiters
57+
cache.solver_args.maxiters / length(data)
5058
end
5159
else
5260
cache.solver_args.epochs
5361
end
5462

55-
maxiters = Optimization._check_and_convert_maxiters(maxiters)
56-
if maxiters === nothing
57-
throw(ArgumentError("The number of epochs must be specified as the epochs or maxiters kwarg."))
58-
end
59-
60-
if OptimizationBase.isa_dataiterator(cache.p)
61-
data = cache.p
62-
dataiterate = true
63-
else
64-
data = [cache.p]
65-
dataiterate = false
63+
epochs = Optimization._check_and_convert_maxiters(epochs)
64+
if epochs === nothing
65+
throw(ArgumentError("The number of iterations must be specified with either the epochs or maxiters kwarg. Where maxiters = epochs*length(data)."))
6666
end
6767

6868
opt = cache.opt
@@ -75,21 +75,35 @@ function SciMLBase.__solve(cache::OptimizationCache{
7575
min_θ = cache.u0
7676

7777
state = Optimisers.setup(opt, θ)
78-
78+
iterations = 0
79+
fevals = 0
80+
gevals = 0
7981
t0 = time()
8082
Optimization.@withprogress cache.progress name="Training" begin
81-
for epoch in 1:maxiters
83+
for epoch in 1:epochs
8284
for (i, d) in enumerate(data)
8385
if cache.f.fg !== nothing && dataiterate
8486
x = cache.f.fg(G, θ, d)
87+
iterations += 1
88+
fevals += 1
89+
gevals += 1
8590
elseif dataiterate
8691
cache.f.grad(G, θ, d)
8792
x = cache.f(θ, d)
93+
iterations += 1
94+
fevals += 2
95+
gevals += 1
8896
elseif cache.f.fg !== nothing
8997
x = cache.f.fg(G, θ)
98+
iterations += 1
99+
fevals += 1
100+
gevals += 1
90101
else
91102
cache.f.grad(G, θ)
92103
x = cache.f(θ)
104+
iterations += 1
105+
fevals += 2
106+
gevals += 1
93107
end
94108
opt_state = Optimization.OptimizationState(
95109
iter = i + (epoch - 1) * length(data),
@@ -112,7 +126,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
112126
min_err = x
113127
min_θ = copy(θ)
114128
end
115-
if i == maxiters #Last iter, revert to best.
129+
if i == length(data) #Last iter, revert to best.
116130
opt = min_opt
117131
x = min_err
118132
θ = min_θ
@@ -132,10 +146,9 @@ function SciMLBase.__solve(cache::OptimizationCache{
132146
end
133147

134148
t1 = time()
135-
stats = Optimization.OptimizationStats(; iterations = maxiters,
136-
time = t1 - t0, fevals = maxiters, gevals = maxiters)
149+
stats = Optimization.OptimizationStats(; iterations,
150+
time = t1 - t0, fevals, gevals)
137151
SciMLBase.build_solution(cache, cache.opt, θ, first(x)[1], stats = stats)
138-
# here should be build_solution to create the output message
139152
end
140153

141154
end

lib/OptimizationOptimisers/test/runtests.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ using Zygote
2727

2828
sol = solve(prob, Optimisers.Adam(), maxiters = 1000)
2929
@test 10 * sol.objective < l1
30+
@test sol.stats.iterations == 1000
31+
@test sol.stats.fevals == 1000
32+
@test sol.stats.gevals == 1000
3033

3134
@testset "cache" begin
3235
objective(x, p) = (p[1] - x[1])^2
@@ -99,6 +102,10 @@ end
99102
res = Optimization.solve(prob, Optimisers.Adam(), callback = callback, epochs = 10000)
100103

101104
@test res.objective < 1e-4
105+
@test res.stats.iterations == 10000*length(data)
106+
@test res.stats.fevals == 10000*length(data)
107+
@test res.stats.gevals == 10000*length(data)
108+
102109

103110
using MLDataDevices
104111
data = CPUDevice()(data)

src/sophia.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,8 @@ function SciMLBase.__solve(cache::OptimizationCache{
8888
cache.f.grad(gₜ, θ)
8989
x = cache.f(θ)
9090
end
91-
opt_state = Optimization.OptimizationState(; iter = i + (epoch - 1) * length(data),
91+
opt_state = Optimization.OptimizationState(;
92+
iter = i + (epoch - 1) * length(data),
9293
u = θ,
9394
objective = first(x),
9495
grad = gₜ,

0 commit comments

Comments
 (0)