Skip to content

Commit ac4d740

Browse files
Merge pull request #897 from prbzrg/fix-892
define `maxiters` similar to `epochs`
2 parents 4d00737 + 5926cd8 commit ac4d740

File tree

2 files changed

+43
-15
lines changed

2 files changed

+43
-15
lines changed

lib/OptimizationOptimisers/src/OptimizationOptimisers.jl

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,13 @@ SciMLBase.requiresgradient(opt::AbstractRule) = true
99
SciMLBase.allowsfg(opt::AbstractRule) = true
1010

1111
function SciMLBase.__init(
12-
prob::SciMLBase.OptimizationProblem, opt::AbstractRule; save_best = true,
13-
callback = (args...) -> (false), epochs = nothing,
14-
progress = false, kwargs...)
15-
return OptimizationCache(prob, opt; save_best, callback, progress, epochs,
16-
kwargs...)
12+
prob::SciMLBase.OptimizationProblem, opt::AbstractRule;
13+
callback = (args...) -> (false),
14+
epochs::Union{Number, Nothing} = nothing,
15+
maxiters::Union{Number, Nothing} = nothing,
16+
save_best::Bool = true, progress::Bool = false, kwargs...)
17+
return OptimizationCache(prob, opt; callback, epochs, maxiters,
18+
save_best, progress, kwargs...)
1719
end
1820

1921
function SciMLBase.__solve(cache::OptimizationCache{
@@ -50,20 +52,27 @@ function SciMLBase.__solve(cache::OptimizationCache{
5052
dataiterate = false
5153
end
5254

53-
epochs = if cache.solver_args.epochs === nothing
54-
if cache.solver_args.maxiters === nothing
55-
throw(ArgumentError("The number of iterations must be specified with either the epochs or maxiters kwarg. Where maxiters = epochs*length(data)."))
55+
epochs, maxiters = if isnothing(cache.solver_args.maxiters) &&
56+
isnothing(cache.solver_args.epochs)
57+
throw(ArgumentError("The number of iterations must be specified with either the epochs or maxiters kwarg. Where maxiters = epochs * length(data)."))
58+
elseif !isnothing(cache.solver_args.maxiters) &&
59+
!isnothing(cache.solver_args.epochs)
60+
if cache.solver_args.maxiters == cache.solver_args.epochs * length(data)
61+
cache.solver_args.epochs, cache.solver_args.maxiters
5662
else
57-
cache.solver_args.maxiters / length(data)
63+
throw(ArgumentError("Both maxiters and epochs were passed but maxiters != epochs * length(data)."))
5864
end
59-
else
60-
cache.solver_args.epochs
65+
elseif isnothing(cache.solver_args.maxiters)
66+
cache.solver_args.epochs, cache.solver_args.epochs * length(data)
67+
elseif isnothing(cache.solver_args.epochs)
68+
cache.solver_args.maxiters / length(data), cache.solver_args.maxiters
6169
end
62-
6370
epochs = Optimization._check_and_convert_maxiters(epochs)
64-
if epochs === nothing
65-
throw(ArgumentError("The number of iterations must be specified with either the epochs or maxiters kwarg. Where maxiters = epochs*length(data)."))
66-
end
71+
maxiters = Optimization._check_and_convert_maxiters(maxiters)
72+
73+
# At this point, both of them should be fine; but, let's assert it.
74+
@assert (!isnothing(epochs)&&!isnothing(maxiters) &&
75+
(maxiters == epochs * length(data))) "The number of iterations must be specified with either the epochs or maxiters kwarg. Where maxiters = epochs * length(data)."
6776

6877
opt = cache.opt
6978
θ = copy(cache.u0)

lib/OptimizationOptimisers/test/runtests.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,25 @@ using Lux, MLUtils, Random, ComponentArrays, Printf, MLDataDevices
3232
@test sol.stats.fevals == 1000
3333
@test sol.stats.gevals == 1000
3434

35+
@testset "epochs & maxiters" begin
36+
optprob = SciMLBase.OptimizationFunction(
37+
(u, data) -> sum(u) + sum(data), Optimization.AutoZygote())
38+
prob = SciMLBase.OptimizationProblem(
39+
optprob, ones(2), MLUtils.DataLoader(ones(2, 2)))
40+
@test_throws ArgumentError("The number of iterations must be specified with either the epochs or maxiters kwarg. Where maxiters = epochs * length(data).") solve(
41+
prob, Optimisers.Adam())
42+
@test_throws ArgumentError("Both maxiters and epochs were passed but maxiters != epochs * length(data).") solve(
43+
prob, Optimisers.Adam(), epochs = 2, maxiters = 2)
44+
sol = solve(prob, Optimisers.Adam(), epochs = 2)
45+
@test sol.stats.iterations == 4
46+
sol = solve(prob, Optimisers.Adam(), maxiters = 2)
47+
@test sol.stats.iterations == 2
48+
sol = solve(prob, Optimisers.Adam(), epochs = 2, maxiters = 4)
49+
@test sol.stats.iterations == 4
50+
@test_throws AssertionError("The number of iterations must be specified with either the epochs or maxiters kwarg. Where maxiters = epochs * length(data).") solve(
51+
prob, Optimisers.Adam(), maxiters = 3)
52+
end
53+
3554
@testset "cache" begin
3655
objective(x, p) = (p[1] - x[1])^2
3756
x0 = zeros(1)

0 commit comments

Comments
 (0)