Skip to content

Commit 2c6927d

Browse files
Remove regex matching and use strings for retcode determination
1 parent 3cf9340 commit 2c6927d

File tree

5 files changed

+58
-40
lines changed

5 files changed

+58
-40
lines changed

lib/OptimizationManopt/src/OptimizationManopt.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -394,8 +394,10 @@ function SciMLBase.__solve(cache::OptimizationCache{
394394
local x, cur, state
395395

396396
manifold = haskey(cache.solver_args, :manifold) ? cache.solver_args[:manifold] : nothing
397-
gradF = haskey(cache.solver_args, :riemannian_grad) ? cache.solver_args[:riemannian_grad] : nothing
398-
hessF = haskey(cache.solver_args, :riemannian_hess) ? cache.solver_args[:riemannian_hess] : nothing
397+
gradF = haskey(cache.solver_args, :riemannian_grad) ?
398+
cache.solver_args[:riemannian_grad] : nothing
399+
hessF = haskey(cache.solver_args, :riemannian_hess) ?
400+
cache.solver_args[:riemannian_hess] : nothing
399401

400402
if manifold === nothing
401403
throw(ArgumentError("Manifold not specified in the problem for e.g. `OptimizationProblem(f, x, p; manifold = SymmetricPositiveDefinite(5))`."))

src/lbfgsb.jl

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -171,9 +171,11 @@ function SciMLBase.__solve(cache::OptimizationCache{
171171
n = length(cache.u0)
172172

173173
if cache.lb === nothing
174-
optimizer, bounds = LBFGSB._opt_bounds(n, cache.opt.m, [-Inf for i in 1:n], [Inf for i in 1:n])
174+
optimizer, bounds = LBFGSB._opt_bounds(
175+
n, cache.opt.m, [-Inf for i in 1:n], [Inf for i in 1:n])
175176
else
176-
optimizer, bounds = LBFGSB._opt_bounds(n, cache.opt.m, solver_kwargs.lb, solver_kwargs.ub)
177+
optimizer, bounds = LBFGSB._opt_bounds(
178+
n, cache.opt.m, solver_kwargs.lb, solver_kwargs.ub)
177179
end
178180

179181
solver_kwargs = Base.structdiff(solver_kwargs, (; lb = nothing, ub = nothing))
@@ -182,7 +184,8 @@ function SciMLBase.__solve(cache::OptimizationCache{
182184
prev_eqcons .= cons_tmp[eq_inds]
183185
prevβ .= copy(β)
184186

185-
res = optimizer(_loss, aug_grad, θ, bounds; solver_kwargs..., m = cache.opt.m, pgtol = sqrt(ϵ), maxiter = maxiters / 100)
187+
res = optimizer(_loss, aug_grad, θ, bounds; solver_kwargs...,
188+
m = cache.opt.m, pgtol = sqrt(ϵ), maxiter = maxiters / 100)
186189
# @show res[2]
187190
# @show res[1]
188191
# @show cons_tmp
@@ -211,7 +214,8 @@ function SciMLBase.__solve(cache::OptimizationCache{
211214
stats = Optimization.OptimizationStats(; iterations = maxiters,
212215
time = 0.0, fevals = maxiters, gevals = maxiters)
213216
return SciMLBase.build_solution(
214-
cache, cache.opt, res[2], cache.f(res[2], cache.p)[1], stats = stats, retcode = opt_ret)
217+
cache, cache.opt, res[2], cache.f(res[2], cache.p)[1],
218+
stats = stats, retcode = opt_ret)
215219
else
216220
_loss = function (θ)
217221
x = cache.f(θ, cache.p)
@@ -226,16 +230,19 @@ function SciMLBase.__solve(cache::OptimizationCache{
226230
n = length(cache.u0)
227231

228232
if cache.lb === nothing
229-
optimizer, bounds= LBFGSB._opt_bounds(n, cache.opt.m, [-Inf for i in 1:n], [Inf for i in 1:n])
233+
optimizer, bounds = LBFGSB._opt_bounds(
234+
n, cache.opt.m, [-Inf for i in 1:n], [Inf for i in 1:n])
230235
else
231-
optimizer, bounds= LBFGSB._opt_bounds(n, cache.opt.m, solver_kwargs.lb, solver_kwargs.ub)
236+
optimizer, bounds = LBFGSB._opt_bounds(
237+
n, cache.opt.m, solver_kwargs.lb, solver_kwargs.ub)
232238
end
233239

234240
solver_kwargs = Base.structdiff(solver_kwargs, (; lb = nothing, ub = nothing))
235241

236242
t0 = time()
237243

238-
res = optimizer(_loss, cache.f.grad, cache.u0, bounds; m = cache.opt.m, solver_kwargs...)
244+
res = optimizer(
245+
_loss, cache.f.grad, cache.u0, bounds; m = cache.opt.m, solver_kwargs...)
239246

240247
# Extract the task message from the result
241248
stop_reason = task_message_to_string(optimizer.task)
@@ -247,6 +254,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
247254
stats = Optimization.OptimizationStats(; iterations = maxiters,
248255
time = t1 - t0, fevals = maxiters, gevals = maxiters)
249256

250-
return SciMLBase.build_solution(cache, cache.opt, res[2], res[1], stats = stats, retcode = opt_ret)
257+
return SciMLBase.build_solution(cache, cache.opt, res[2], res[1], stats = stats,
258+
retcode = opt_ret, original = optimizer)
251259
end
252260
end

src/utils.jl

Lines changed: 26 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -67,46 +67,45 @@ function check_pkg_version(pkg::String, ver::String;
6767
pkg_info[pkg].version > VersionNumber(ver)
6868
end
6969

70-
7170
# RetCode handling for BBO and others.
7271
using SciMLBase: ReturnCode
7372

7473
# Define a dictionary to map regular expressions to ReturnCode values
7574
const STOP_REASON_MAP = Dict(
76-
r"Delta fitness .* below tolerance .*" => ReturnCode.Success,
77-
r"Fitness .* within tolerance .* of optimum" => ReturnCode.Success,
78-
r"CONVERGENCE: NORM_OF_PROJECTED_GRADIENT_<=_PGTOL" => ReturnCode.Success,
79-
r"Unrecognized stop reason: CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH" => ReturnCode.Success,
80-
r"Terminated" => ReturnCode.Terminated,
81-
r"MaxIters|MAXITERS_EXCEED|Max number of steps .* reached" => ReturnCode.MaxIters,
82-
r"MaxTime|TIME_LIMIT" => ReturnCode.MaxTime,
83-
r"Max time" => ReturnCode.MaxTime,
84-
r"DtLessThanMin" => ReturnCode.DtLessThanMin,
85-
r"Unstable" => ReturnCode.Unstable,
86-
r"InitialFailure" => ReturnCode.InitialFailure,
87-
r"ConvergenceFailure|ITERATION_LIMIT" => ReturnCode.ConvergenceFailure,
88-
r"Infeasible|INFEASIBLE|DUAL_INFEASIBLE|LOCALLY_INFEASIBLE|INFEASIBLE_OR_UNBOUNDED" => ReturnCode.Infeasible,
89-
r"STOP: TOTAL NO. of ITERATIONS REACHED LIMIT" => ReturnCode.MaxIters,
90-
r"STOP: TOTAL NO. of f AND g EVALUATIONS EXCEEDS LIMIT" => ReturnCode.MaxIters,
91-
r"STOP: ABNORMAL_TERMINATION_IN_LNSRCH" => ReturnCode.Unstable,
92-
r"STOP: ERROR INPUT DATA" => ReturnCode.InitialFailure,
93-
r"STOP: FTOL.TOO.SMALL" => ReturnCode.ConvergenceFailure,
94-
r"STOP: GTOL.TOO.SMALL" => ReturnCode.ConvergenceFailure,
95-
r"STOP: XTOL.TOO.SMALL" => ReturnCode.ConvergenceFailure,
96-
r"STOP: TERMINATION" => ReturnCode.Terminated,
97-
r"Optimization completed" => ReturnCode.Success,
98-
r"Convergence achieved" => ReturnCode.Success
75+
"Delta fitness .* below tolerance .*" => ReturnCode.Success,
76+
"Fitness .* within tolerance .* of optimum" => ReturnCode.Success,
77+
"CONVERGENCE: NORM_OF_PROJECTED_GRADIENT_<=_PGTOL" => ReturnCode.Success,
78+
"CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH" => ReturnCode.Success,
79+
"Terminated" => ReturnCode.Terminated,
80+
"MaxIters|MAXITERS_EXCEED|Max number of steps .* reached" => ReturnCode.MaxIters,
81+
"MaxTime|TIME_LIMIT" => ReturnCode.MaxTime,
82+
"Max time" => ReturnCode.MaxTime,
83+
"DtLessThanMin" => ReturnCode.DtLessThanMin,
84+
"Unstable" => ReturnCode.Unstable,
85+
"InitialFailure" => ReturnCode.InitialFailure,
86+
"ConvergenceFailure|ITERATION_LIMIT" => ReturnCode.ConvergenceFailure,
87+
"Infeasible|INFEASIBLE|DUAL_INFEASIBLE|LOCALLY_INFEASIBLE|INFEASIBLE_OR_UNBOUNDED" => ReturnCode.Infeasible,
88+
"STOP: TOTAL NO. of ITERATIONS REACHED LIMIT" => ReturnCode.MaxIters,
89+
"STOP: TOTAL NO. of f AND g EVALUATIONS EXCEEDS LIMIT" => ReturnCode.MaxIters,
90+
"STOP: ABNORMAL_TERMINATION_IN_LNSRCH" => ReturnCode.Unstable,
91+
"STOP: ERROR INPUT DATA" => ReturnCode.InitialFailure,
92+
"STOP: FTOL.TOO.SMALL" => ReturnCode.ConvergenceFailure,
93+
"STOP: GTOL.TOO.SMALL" => ReturnCode.ConvergenceFailure,
94+
"STOP: XTOL.TOO.SMALL" => ReturnCode.ConvergenceFailure,
95+
"STOP: TERMINATION" => ReturnCode.Terminated,
96+
"Optimization completed" => ReturnCode.Success,
97+
"Convergence achieved" => ReturnCode.Success
9998
)
10099

101100
# Function to deduce ReturnCode from a stop_reason string using the dictionary
102101
function deduce_retcode(stop_reason::String)
103102
for (pattern, retcode) in STOP_REASON_MAP
104103
if occursin(pattern, stop_reason)
105-
return retcode
104+
return retcode
106105
end
107106
end
108-
@warn "Unrecognized stop reason: $stop_reason. Defaulting to ReturnCode.Failure."
109-
return ReturnCode.Failure
107+
@warn "Unrecognized stop reason: $stop_reason. Defaulting to ReturnCode.Default."
108+
return ReturnCode.Default
110109
end
111110

112111
# Function to deduce ReturnCode from a Symbol
@@ -141,4 +140,3 @@ function deduce_retcode(retcode::Symbol)
141140
return ReturnCode.Failure
142141
end
143142
end
144-

test/ADtests.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,12 +202,15 @@ prob = OptimizationProblem(optf, x0)
202202

203203
sol = solve(prob, Optim.BFGS())
204204
@test 10 * sol.objective < l1
205+
@test sol.retcode == ReturnCode.Success
205206

206207
sol = solve(prob, Optim.Newton())
207208
@test 10 * sol.objective < l1
209+
@test sol.retcode == ReturnCode.Success
208210

209211
sol = solve(prob, Optim.KrylovTrustRegion())
210212
@test 10 * sol.objective < l1
213+
@test sol.retcode == ReturnCode.Success
211214

212215
optf = OptimizationFunction(rosenbrock, Optimization.AutoZygote())
213216
optprob = Optimization.instantiate_function(optf, x0, Optimization.AutoZygote(), nothing)
@@ -403,10 +406,12 @@ for consf in [cons, con2_c]
403406
prob1 = OptimizationProblem(optf1, [0.3, 0.5], lb = [0.2, 0.4], ub = [0.6, 0.8],
404407
lcons = lcons, ucons = ucons)
405408
sol1 = solve(prob1, Optim.IPNewton())
409+
@test sol1.retcode == ReturnCode.Success
406410
optf2 = OptimizationFunction(rosenbrock, Optimization.AutoForwardDiff(); cons = consf)
407411
prob2 = OptimizationProblem(optf2, [0.3, 0.5], lb = [0.2, 0.4], ub = [0.6, 0.8],
408412
lcons = lcons, ucons = ucons)
409413
sol2 = solve(prob2, Optim.IPNewton())
414+
@test sol2.retcode == ReturnCode.Success
410415
@test sol1.objectivesol2.objective rtol=1e-4
411416
@test sol1.u sol2.u
412417
res = Array{Float64}(undef, length(lcons))
@@ -421,9 +426,11 @@ for consf in [cons, con2_c]
421426
optf1 = OptimizationFunction(rosenbrock, Optimization.AutoFiniteDiff(); cons = consf)
422427
prob1 = OptimizationProblem(optf1, [0.5, 0.5], lcons = lcons, ucons = ucons)
423428
sol1 = solve(prob1, Optim.IPNewton())
429+
@test sol1.retcode == ReturnCode.Success
424430
optf2 = OptimizationFunction(rosenbrock, Optimization.AutoForwardDiff(); cons = consf)
425431
prob2 = OptimizationProblem(optf2, [0.5, 0.5], lcons = lcons, ucons = ucons)
426432
sol2 = solve(prob2, Optim.IPNewton())
433+
@test sol2.retcode == ReturnCode.Success
427434
@test sol1.objectivesol2.objective rtol=1e-4
428435
@test sol1.usol2.u rtol=1e-4
429436
res = Array{Float64}(undef, length(lcons))

test/lbfgsb.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@ x0 = zeros(2)
66
rosenbrock(x, p = nothing) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2
77
l1 = rosenbrock(x0)
88

9-
optf = OptimizationFunction(rosenbrock, AutoEnzyme())
9+
optf = OptimizationFunction(rosenbrock, AutoForwardDiff())
1010
prob = OptimizationProblem(optf, x0)
1111
@time res = solve(prob, Optimization.LBFGS(), maxiters = 100)
12+
@test res.retcode == Optimization.SciMLBase.ReturnCode.Success
1213

13-
prob = OptimizationProblem(optf, x0, lb = [-1.0, -1.0], ub = [1.0, 1.0])
14+
prob = OptimizationProblem(optf, x0, lb = [-1.0, -1.0], ub = [1.0, 1.0])
1415
@time res = solve(prob, Optimization.LBFGS(), maxiters = 100)
16+
@test res.retcode == Optimization.SciMLBase.ReturnCode.Success
1517

1618
function con2_c(res, x, p)
1719
res .= [x[1]^2 + x[2]^2, (x[2] * sin(x[1]) + x[1]) - 5]
@@ -22,3 +24,4 @@ prob = OptimizationProblem(optf, x0, lcons = [1.0, -Inf],
2224
ucons = [1.0, 0.0], lb = [-1.0, -1.0],
2325
ub = [1.0, 1.0])
2426
@time res = solve(prob, Optimization.LBFGS(), maxiters = 100)
27+
@test res.retcode == Optimization.SciMLBase.ReturnCode.MaxIters

0 commit comments

Comments
 (0)