Skip to content

Commit 020c967

Browse files
testing and format
1 parent 80d5671 commit 020c967

File tree

3 files changed

+52
-46
lines changed

3 files changed

+52
-46
lines changed

.github/workflows/CI.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@ jobs:
2424
- OptimizationEvolutionary
2525
- OptimizationFlux
2626
- OptimizationGCMAES
27+
- OptimizationManopt
2728
- OptimizationMetaheuristics
2829
- OptimizationMOI
2930
- OptimizationMultistartOptimization
3031
- OptimizationNLopt
31-
#- OptimizationNonconvex
3232
- OptimizationNOMAD
3333
- OptimizationOptimJL
3434
- OptimizationOptimisers

lib/OptimizationManopt/src/OptimizationManopt.jl

Lines changed: 48 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -22,26 +22,32 @@ function __map_optimizer_args!(cache::OptimizationCache,
2222
abstol::Union{Number, Nothing} = nothing,
2323
reltol::Union{Number, Nothing} = nothing,
2424
kwargs...)
25-
2625
solver_kwargs = (; kwargs...)
2726

2827
if !isnothing(maxiters)
29-
solver_kwargs = (; solver_kwargs..., stopping_criterion = [Manopt.StopAfterIteration(maxiters)])
28+
solver_kwargs = (;
29+
solver_kwargs..., stopping_criterion = [Manopt.StopAfterIteration(maxiters)])
3030
end
3131

3232
if !isnothing(maxtime)
3333
if haskey(solver_kwargs, :stopping_criterion)
34-
solver_kwargs = (; solver_kwargs..., stopping_criterion = push!(solver_kwargs.stopping_criterion, Manopt.StopAfterTime(maxtime)))
34+
solver_kwargs = (; solver_kwargs...,
35+
stopping_criterion = push!(
36+
solver_kwargs.stopping_criterion, Manopt.StopAfterTime(maxtime)))
3537
else
36-
solver_kwargs = (; solver_kwargs..., stopping_criterion = [Manopt.StopAfter(maxtime)])
38+
solver_kwargs = (;
39+
solver_kwargs..., stopping_criterion = [Manopt.StopAfter(maxtime)])
3740
end
3841
end
3942

4043
if !isnothing(abstol)
4144
if haskey(solver_kwargs, :stopping_criterion)
42-
solver_kwargs = (; solver_kwargs..., stopping_criterion = push!(solver_kwargs.stopping_criterion, Manopt.StopWhenChangeLess(abstol)))
45+
solver_kwargs = (; solver_kwargs...,
46+
stopping_criterion = push!(
47+
solver_kwargs.stopping_criterion, Manopt.StopWhenChangeLess(abstol)))
4348
else
44-
solver_kwargs = (; solver_kwargs..., stopping_criterion = [Manopt.StopWhenChangeLess(abstol)])
49+
solver_kwargs = (;
50+
solver_kwargs..., stopping_criterion = [Manopt.StopWhenChangeLess(abstol)])
4551
end
4652
end
4753

@@ -54,7 +60,8 @@ end
5460
## gradient descent
5561
struct GradientDescentOptimizer <: AbstractManoptOptimizer end
5662

57-
function call_manopt_optimizer(M::ManifoldsBase.AbstractManifold, opt::GradientDescentOptimizer,
63+
function call_manopt_optimizer(
64+
M::ManifoldsBase.AbstractManifold, opt::GradientDescentOptimizer,
5865
loss,
5966
gradF,
6067
x0;
@@ -84,7 +91,6 @@ function call_manopt_optimizer(M::ManifoldsBase.AbstractManifold, opt::NelderMea
8491
x0;
8592
stopping_criterion::Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet},
8693
kwargs...)
87-
8894
opts = NelderMead(M,
8995
loss;
9096
return_state = true,
@@ -96,7 +102,7 @@ end
96102
## conjugate gradient descent
97103
struct ConjugateGradientDescentOptimizer <: AbstractManoptOptimizer end
98104

99-
function call_manopt_optimizer(M::ManifoldsBase.AbstractManifold,
105+
function call_manopt_optimizer(M::ManifoldsBase.AbstractManifold,
100106
opt::ConjugateGradientDescentOptimizer,
101107
loss,
102108
gradF,
@@ -105,7 +111,6 @@ function call_manopt_optimizer(M::ManifoldsBase.AbstractManifold,
105111
evaluation::AbstractEvaluationType = InplaceEvaluation(),
106112
stepsize::Stepsize = ArmijoLinesearch(M),
107113
kwargs...)
108-
109114
opts = conjugate_gradient_descent(M,
110115
loss,
111116
gradF,
@@ -134,7 +139,6 @@ function call_manopt_optimizer(M::ManifoldsBase.AbstractManifold,
134139
inverse_retraction_method::AbstractInverseRetractionMethod = default_inverse_retraction_method(M),
135140
vector_transport_method::AbstractVectorTransportMethod = default_vector_transport_method(M),
136141
kwargs...)
137-
138142
initial_population = vcat([x0], [rand(M) for _ in 1:(population_size - 1)])
139143
opts = particle_swarm(M,
140144
loss;
@@ -168,8 +172,7 @@ function call_manopt_optimizer(M::Manopt.AbstractManifold,
168172
vector_transport_method = vector_transport_method,
169173
linesearch_stopsize = 1e-12),
170174
kwargs...
171-
)
172-
175+
)
173176
opts = quasi_Newton(M,
174177
loss,
175178
gradF,
@@ -214,30 +217,30 @@ end
214217
# 3) add callbacks to Manopt.jl
215218

216219
function SciMLBase.__solve(cache::OptimizationCache{
217-
F,
218-
RC,
219-
LB,
220-
UB,
221-
LC,
222-
UC,
223-
S,
224-
O,
225-
D,
226-
P,
227-
C
220+
F,
221+
RC,
222+
LB,
223+
UB,
224+
LC,
225+
UC,
226+
S,
227+
O,
228+
D,
229+
P,
230+
C
228231
}) where {
229-
F,
230-
RC,
231-
LB,
232-
UB,
233-
LC,
234-
UC,
235-
S,
236-
O <:
237-
AbstractManoptOptimizer,
238-
D,
239-
P,
240-
C
232+
F,
233+
RC,
234+
LB,
235+
UB,
236+
LC,
237+
UC,
238+
S,
239+
O <:
240+
AbstractManoptOptimizer,
241+
D,
242+
P,
243+
C
241244
}
242245
local x, cur, state
243246

@@ -272,11 +275,11 @@ function SciMLBase.__solve(cache::OptimizationCache{
272275
end
273276
end
274277
solver_kwarg = __map_optimizer_args!(cache, cache.opt, callback = _cb,
275-
maxiters = maxiters,
276-
maxtime = cache.solver_args.maxtime,
277-
abstol = cache.solver_args.abstol,
278-
reltol = cache.solver_args.reltol;
279-
)
278+
maxiters = maxiters,
279+
maxtime = cache.solver_args.maxtime,
280+
abstol = cache.solver_args.abstol,
281+
reltol = cache.solver_args.reltol;
282+
)
280283

281284
_loss = build_loss(cache.f, cache, _cb)
282285

@@ -288,11 +291,13 @@ function SciMLBase.__solve(cache::OptimizationCache{
288291
stopping_criterion = Manopt.StopAfterIteration(500)
289292
end
290293

291-
opt_res = call_manopt_optimizer(manifold, cache.opt, _loss, gradF, cache.u0; solver_kwarg..., stopping_criterion=stopping_criterion)
294+
opt_res = call_manopt_optimizer(manifold, cache.opt, _loss, gradF, cache.u0;
295+
solver_kwarg..., stopping_criterion = stopping_criterion)
292296

293297
asc = get_active_stopping_criteria(opt_res.options.stop)
294298

295-
opt_ret = any(Manopt.indicates_convergence, asc) ? ReturnCode.Success : ReturnCode.Failure
299+
opt_ret = any(Manopt.indicates_convergence, asc) ? ReturnCode.Success :
300+
ReturnCode.Failure
296301

297302
return SciMLBase.build_solution(cache,
298303
cache.opt,

lib/OptimizationManopt/test/runtests.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ end
3737
opt = OptimizationManopt.GradientDescentOptimizer()
3838

3939
optprob_forwarddiff = OptimizationFunction(rosenbrock, Optimization.AutoEnzyme())
40-
prob_forwarddiff = OptimizationProblem(optprob_forwarddiff, x0, p; manifold = R2, stepsize = stepsize)
40+
prob_forwarddiff = OptimizationProblem(
41+
optprob_forwarddiff, x0, p; manifold = R2, stepsize = stepsize)
4142
sol = Optimization.solve(prob_forwarddiff, opt)
4243
@test sol.minimum < 0.2
4344

@@ -132,5 +133,5 @@ end
132133
opt = OptimizationManopt.GradientDescentOptimizer()
133134
@time sol = Optimization.solve(prob, opt)
134135

135-
@test sol.u q atol = 1e-2
136+
@test sol.uq atol=1e-2
136137
end

0 commit comments

Comments
 (0)