@@ -22,26 +22,32 @@ function __map_optimizer_args!(cache::OptimizationCache,
2222 abstol:: Union{Number, Nothing} = nothing ,
2323 reltol:: Union{Number, Nothing} = nothing ,
2424 kwargs... )
25-
2625 solver_kwargs = (; kwargs... )
2726
2827 if ! isnothing (maxiters)
29- solver_kwargs = (; solver_kwargs... , stopping_criterion = [Manopt. StopAfterIteration (maxiters)])
28+ solver_kwargs = (;
29+ solver_kwargs... , stopping_criterion = [Manopt. StopAfterIteration (maxiters)])
3030 end
3131
3232 if ! isnothing (maxtime)
3333 if haskey (solver_kwargs, :stopping_criterion )
34- solver_kwargs = (; solver_kwargs... , stopping_criterion = push! (solver_kwargs. stopping_criterion, Manopt. StopAfterTime (maxtime)))
34+ solver_kwargs = (; solver_kwargs... ,
35+ stopping_criterion = push! (
36+ solver_kwargs. stopping_criterion, Manopt. StopAfterTime (maxtime)))
3537 else
36- solver_kwargs = (; solver_kwargs... , stopping_criterion = [Manopt. StopAfter (maxtime)])
38+ solver_kwargs = (;
39+ solver_kwargs... , stopping_criterion = [Manopt. StopAfter (maxtime)])
3740 end
3841 end
3942
4043 if ! isnothing (abstol)
4144 if haskey (solver_kwargs, :stopping_criterion )
42- solver_kwargs = (; solver_kwargs... , stopping_criterion = push! (solver_kwargs. stopping_criterion, Manopt. StopWhenChangeLess (abstol)))
45+ solver_kwargs = (; solver_kwargs... ,
46+ stopping_criterion = push! (
47+ solver_kwargs. stopping_criterion, Manopt. StopWhenChangeLess (abstol)))
4348 else
44- solver_kwargs = (; solver_kwargs... , stopping_criterion = [Manopt. StopWhenChangeLess (abstol)])
49+ solver_kwargs = (;
50+ solver_kwargs... , stopping_criterion = [Manopt. StopWhenChangeLess (abstol)])
4551 end
4652 end
4753
5460# # gradient descent
5561struct GradientDescentOptimizer <: AbstractManoptOptimizer end
5662
57- function call_manopt_optimizer (M:: ManifoldsBase.AbstractManifold , opt:: GradientDescentOptimizer ,
63+ function call_manopt_optimizer (
64+ M:: ManifoldsBase.AbstractManifold , opt:: GradientDescentOptimizer ,
5865 loss,
5966 gradF,
6067 x0;
@@ -84,7 +91,6 @@ function call_manopt_optimizer(M::ManifoldsBase.AbstractManifold, opt::NelderMea
8491 x0;
8592 stopping_criterion:: Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet} ,
8693 kwargs... )
87-
8894 opts = NelderMead (M,
8995 loss;
9096 return_state = true ,
96102# # conjugate gradient descent
97103struct ConjugateGradientDescentOptimizer <: AbstractManoptOptimizer end
98104
99- function call_manopt_optimizer (M:: ManifoldsBase.AbstractManifold ,
105+ function call_manopt_optimizer (M:: ManifoldsBase.AbstractManifold ,
100106 opt:: ConjugateGradientDescentOptimizer ,
101107 loss,
102108 gradF,
@@ -105,7 +111,6 @@ function call_manopt_optimizer(M::ManifoldsBase.AbstractManifold,
105111 evaluation:: AbstractEvaluationType = InplaceEvaluation (),
106112 stepsize:: Stepsize = ArmijoLinesearch (M),
107113 kwargs... )
108-
109114 opts = conjugate_gradient_descent (M,
110115 loss,
111116 gradF,
@@ -134,7 +139,6 @@ function call_manopt_optimizer(M::ManifoldsBase.AbstractManifold,
134139 inverse_retraction_method:: AbstractInverseRetractionMethod = default_inverse_retraction_method (M),
135140 vector_transport_method:: AbstractVectorTransportMethod = default_vector_transport_method (M),
136141 kwargs... )
137-
138142 initial_population = vcat ([x0], [rand (M) for _ in 1 : (population_size - 1 )])
139143 opts = particle_swarm (M,
140144 loss;
@@ -168,8 +172,7 @@ function call_manopt_optimizer(M::Manopt.AbstractManifold,
168172 vector_transport_method = vector_transport_method,
169173 linesearch_stopsize = 1e-12 ),
170174 kwargs...
171- )
172-
175+ )
173176 opts = quasi_Newton (M,
174177 loss,
175178 gradF,
@@ -214,30 +217,30 @@ end
214217# 3) add callbacks to Manopt.jl
215218
216219function SciMLBase. __solve (cache:: OptimizationCache {
217- F,
218- RC,
219- LB,
220- UB,
221- LC,
222- UC,
223- S,
224- O,
225- D,
226- P,
227- C
220+ F,
221+ RC,
222+ LB,
223+ UB,
224+ LC,
225+ UC,
226+ S,
227+ O,
228+ D,
229+ P,
230+ C
228231}) where {
229- F,
230- RC,
231- LB,
232- UB,
233- LC,
234- UC,
235- S,
236- O < :
237- AbstractManoptOptimizer,
238- D,
239- P,
240- C
232+ F,
233+ RC,
234+ LB,
235+ UB,
236+ LC,
237+ UC,
238+ S,
239+ O < :
240+ AbstractManoptOptimizer,
241+ D,
242+ P,
243+ C
241244}
242245 local x, cur, state
243246
@@ -272,11 +275,11 @@ function SciMLBase.__solve(cache::OptimizationCache{
272275 end
273276 end
274277 solver_kwarg = __map_optimizer_args! (cache, cache. opt, callback = _cb,
275- maxiters = maxiters,
276- maxtime = cache. solver_args. maxtime,
277- abstol = cache. solver_args. abstol,
278- reltol = cache. solver_args. reltol;
279- )
278+ maxiters = maxiters,
279+ maxtime = cache. solver_args. maxtime,
280+ abstol = cache. solver_args. abstol,
281+ reltol = cache. solver_args. reltol;
282+ )
280283
281284 _loss = build_loss (cache. f, cache, _cb)
282285
@@ -288,11 +291,13 @@ function SciMLBase.__solve(cache::OptimizationCache{
288291 stopping_criterion = Manopt. StopAfterIteration (500 )
289292 end
290293
291- opt_res = call_manopt_optimizer (manifold, cache. opt, _loss, gradF, cache. u0; solver_kwarg... , stopping_criterion= stopping_criterion)
294+ opt_res = call_manopt_optimizer (manifold, cache. opt, _loss, gradF, cache. u0;
295+ solver_kwarg... , stopping_criterion = stopping_criterion)
292296
293297 asc = get_active_stopping_criteria (opt_res. options. stop)
294298
295- opt_ret = any (Manopt. indicates_convergence, asc) ? ReturnCode. Success : ReturnCode. Failure
299+ opt_ret = any (Manopt. indicates_convergence, asc) ? ReturnCode. Success :
300+ ReturnCode. Failure
296301
297302 return SciMLBase. build_solution (cache,
298303 cache. opt,
0 commit comments