@@ -65,20 +65,14 @@ function call_manopt_optimizer(
6565 loss,
6666 gradF,
6767 x0;
68- stopping_criterion:: Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet} ,
69- evaluation:: AbstractEvaluationType = Manopt. AllocatingEvaluation (),
70- stepsize:: Stepsize = ArmijoLinesearch (M),
7168 kwargs... )
72- opts = gradient_descent (M,
69+ opts = Manopt . gradient_descent (M,
7370 loss,
7471 gradF,
7572 x0;
76- return_state = true ,
77- evaluation,
78- stepsize,
79- stopping_criterion,
80- kwargs... )
81- # we unwrap DebugOptions here
73+ return_state = true , # return the (full, decorated) solver state
74+ kwargs...
75+ )
8276 minimizer = Manopt. get_solver_result (opts)
8377 return (; minimizer = minimizer, minimum = loss (M, minimizer), options = opts)
8478end
@@ -90,13 +84,8 @@ function call_manopt_optimizer(M::ManifoldsBase.AbstractManifold, opt::NelderMea
9084 loss,
9185 gradF,
9286 x0;
93- stopping_criterion:: Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet} ,
94- kwargs... )
95- opts = NelderMead (M,
96- loss;
97- return_state = true ,
98- stopping_criterion,
9987 kwargs... )
88+ opts = NelderMead (M, loss; return_state = true , kwargs... )
10089 minimizer = Manopt. get_solver_result (opts)
10190 return (; minimizer = minimizer, minimum = loss (M, minimizer), options = opts)
10291end
@@ -109,19 +98,14 @@ function call_manopt_optimizer(M::ManifoldsBase.AbstractManifold,
10998 loss,
11099 gradF,
111100 x0;
112- stopping_criterion:: Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet} ,
113- evaluation:: AbstractEvaluationType = InplaceEvaluation (),
114- stepsize:: Stepsize = ArmijoLinesearch (M),
115101 kwargs... )
116- opts = conjugate_gradient_descent (M,
102+ opts = Manopt . conjugate_gradient_descent (M,
117103 loss,
118104 gradF,
119105 x0;
120106 return_state = true ,
121- evaluation,
122- stepsize,
123- stopping_criterion,
124- kwargs... )
107+ kwargs...
108+ )
125109 # we unwrap DebugOptions here
126110 minimizer = Manopt. get_solver_result (opts)
127111 return (; minimizer = minimizer, minimum = loss (M, minimizer), options = opts)
@@ -135,25 +119,10 @@ function call_manopt_optimizer(M::ManifoldsBase.AbstractManifold,
135119 loss,
136120 gradF,
137121 x0;
138- stopping_criterion:: Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet} ,
139- evaluation:: AbstractEvaluationType = InplaceEvaluation (),
140122 population_size:: Int = 100 ,
141- retraction_method:: AbstractRetractionMethod = default_retraction_method (M),
142- inverse_retraction_method:: AbstractInverseRetractionMethod = default_inverse_retraction_method (M),
143- vector_transport_method:: AbstractVectorTransportMethod = default_vector_transport_method (M),
144- kwargs... )
145- initial_population = vcat ([x0], [rand (M) for _ in 1 : (population_size - 1 )])
146- opts = particle_swarm (M,
147- loss;
148- x0 = initial_population,
149- n = population_size,
150- return_state = true ,
151- retraction_method,
152- inverse_retraction_method,
153- vector_transport_method,
154- stopping_criterion,
155123 kwargs... )
156- # we unwrap DebugOptions here
124+ swarm = [x0, [rand (M) for _ in 1 : (population_size - 1 )]. .. ]
125+ opts = particle_swarm (M, loss, swarm; return_state = true , kwargs... )
157126 minimizer = Manopt. get_solver_result (opts)
158127 return (; minimizer = minimizer, minimum = loss (M, minimizer), options = opts)
159128end
@@ -167,27 +136,9 @@ function call_manopt_optimizer(M::Manopt.AbstractManifold,
167136 loss,
168137 gradF,
169138 x0;
170- stopping_criterion:: Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet} ,
171- evaluation:: AbstractEvaluationType = InplaceEvaluation (),
172- retraction_method:: AbstractRetractionMethod = default_retraction_method (M),
173- vector_transport_method:: AbstractVectorTransportMethod = default_vector_transport_method (M),
174- stepsize = WolfePowellLinesearch (M;
175- retraction_method = retraction_method,
176- vector_transport_method = vector_transport_method,
177- linesearch_stopsize = 1e-12 ),
178139 kwargs...
179140)
180- opts = quasi_Newton (M,
181- loss,
182- gradF,
183- x0;
184- return_state = true ,
185- evaluation,
186- retraction_method,
187- vector_transport_method,
188- stepsize,
189- stopping_criterion,
190- kwargs... )
141+ opts = quasi_Newton (M, loss, gradF, x0; return_state = true , kwargs... )
191142 # we unwrap DebugOptions here
192143 minimizer = Manopt. get_solver_result (opts)
193144 return (; minimizer = minimizer, minimum = loss (M, minimizer), options = opts)
@@ -200,18 +151,8 @@ function call_manopt_optimizer(M::ManifoldsBase.AbstractManifold,
200151 loss,
201152 gradF,
202153 x0;
203- stopping_criterion:: Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet} ,
204- evaluation:: AbstractEvaluationType = InplaceEvaluation (),
205- retraction_method:: AbstractRetractionMethod = default_retraction_method (M),
206- vector_transport_method:: AbstractVectorTransportMethod = default_vector_transport_method (M),
207- basis = Manopt. DefaultOrthonormalBasis (),
208- kwargs... )
209- opt = cma_es (M,
210- loss,
211- x0;
212- return_state = true ,
213- stopping_criterion,
214154 kwargs... )
155+ opt = cma_es (M, loss, x0; return_state = true , kwargs... )
215156 # we unwrap DebugOptions here
216157 minimizer = Manopt. get_solver_result (opt)
217158 return (; minimizer = minimizer, minimum = loss (M, minimizer), options = opt)
@@ -224,21 +165,8 @@ function call_manopt_optimizer(M::ManifoldsBase.AbstractManifold,
224165 loss,
225166 gradF,
226167 x0;
227- stopping_criterion:: Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet} ,
228- evaluation:: AbstractEvaluationType = InplaceEvaluation (),
229- retraction_method:: AbstractRetractionMethod = default_retraction_method (M),
230- vector_transport_method:: AbstractVectorTransportMethod = default_vector_transport_method (M),
231- kwargs... )
232- opt = convex_bundle_method! (M,
233- loss,
234- gradF,
235- x0;
236- return_state = true ,
237- evaluation,
238- retraction_method,
239- vector_transport_method,
240- stopping_criterion,
241168 kwargs... )
169+ opt = convex_bundle_method (M, loss, gradF, x0; return_state = true , kwargs... )
242170 # we unwrap DebugOptions here
243171 minimizer = Manopt. get_solver_result (opt)
244172 return (; minimizer = minimizer, minimum = loss (M, minimizer), options = opt)
@@ -252,21 +180,13 @@ function call_manopt_optimizer(M::ManifoldsBase.AbstractManifold,
252180 gradF,
253181 x0;
254182 hessF = nothing ,
255- stopping_criterion:: Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet} ,
256- evaluation:: AbstractEvaluationType = InplaceEvaluation (),
257- retraction_method:: AbstractRetractionMethod = default_retraction_method (M),
258- kwargs... )
259- opt = adaptive_regularization_with_cubics (M,
260- loss,
261- gradF,
262- hessF,
263- x0;
264- return_state = true ,
265- evaluation,
266- retraction_method,
267- stopping_criterion,
268183 kwargs... )
269- # we unwrap DebugOptions here
184+
185+ opt = if isnothing (hessF)
186+ adaptive_regularization_with_cubics (M, loss, gradF, x0; return_state = true , kwargs... )
187+ else
188+ adaptive_regularization_with_cubics (M, loss, gradF, hessF, x0; return_state = true , kwargs... )
189+ end
270190 minimizer = Manopt. get_solver_result (opt)
271191 return (; minimizer = minimizer, minimum = loss (M, minimizer), options = opt)
272192end
@@ -279,20 +199,12 @@ function call_manopt_optimizer(M::ManifoldsBase.AbstractManifold,
279199 gradF,
280200 x0;
281201 hessF = nothing ,
282- stopping_criterion:: Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet} ,
283- evaluation:: AbstractEvaluationType = InplaceEvaluation (),
284- retraction_method:: AbstractRetractionMethod = default_retraction_method (M),
285- kwargs... )
286- opt = trust_regions (M,
287- loss,
288- gradF,
289- hessF,
290- x0;
291- return_state = true ,
292- evaluation,
293- retraction = retraction_method,
294- stopping_criterion,
295202 kwargs... )
203+ opt = if isnothing (hessF)
204+ trust_regions (M, loss, gradF, x0; return_state = true , kwargs... )
205+ else
206+ trust_regions (M, loss, gradF, hessF, x0; return_state = true , kwargs... )
207+ end
296208 # we unwrap DebugOptions here
297209 minimizer = Manopt. get_solver_result (opt)
298210 return (; minimizer = minimizer, minimum = loss (M, minimizer), options = opt)
@@ -305,21 +217,8 @@ function call_manopt_optimizer(M::ManifoldsBase.AbstractManifold,
305217 loss,
306218 gradF,
307219 x0;
308- stopping_criterion:: Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet} ,
309- evaluation:: AbstractEvaluationType = InplaceEvaluation (),
310- retraction_method:: AbstractRetractionMethod = default_retraction_method (M),
311- stepsize:: Stepsize = DecreasingStepsize (; length = 2.0 , shift = 2 ),
312- kwargs... )
313- opt = Frank_Wolfe_method (M,
314- loss,
315- gradF,
316- x0;
317- return_state = true ,
318- evaluation,
319- retraction_method,
320- stopping_criterion,
321- stepsize,
322220 kwargs... )
221+ opt = Frank_Wolfe_method (M, loss, gradF, x0; return_state = true , kwargs... )
323222 # we unwrap DebugOptions here
324223 minimizer = Manopt. get_solver_result (opt)
325224 return (; minimizer = minimizer, minimum = loss (M, minimizer), options = opt)
@@ -332,20 +231,22 @@ function SciMLBase.requiresgradient(opt::Union{
332231 AdaptiveRegularizationCubicOptimizer, TrustRegionsOptimizer})
333232 true
334233end
234+ # TODO : WHY? they both still accept not passing it
335235function SciMLBase. requireshessian (opt:: Union {
336236 AdaptiveRegularizationCubicOptimizer, TrustRegionsOptimizer})
337237 true
338238end
339239
340240function build_loss (f:: OptimizationFunction , prob, cb)
341- function (:: AbstractManifold , θ)
241+ return function (:: AbstractManifold , θ)
342242 x = f. f (θ, prob. p)
343243 cb (x, θ)
344244 __x = first (x)
345245 return prob. sense === Optimization. MaxSense ? - __x : __x
346246 end
347247end
348248
249+ # TODO : What does the “true” mean here?
349250function build_gradF (f:: OptimizationFunction{true} )
350251 function g (M:: AbstractManifold , G, θ)
351252 f. grad (G, θ)
@@ -356,6 +257,7 @@ function build_gradF(f::OptimizationFunction{true})
356257 f. grad (G, θ)
357258 return riemannian_gradient (M, θ, G)
358259 end
260+ return g
359261end
360262
361263function build_hessF (f:: OptimizationFunction{true} )
@@ -373,6 +275,7 @@ function build_hessF(f::OptimizationFunction{true})
373275 f. grad (G, θ)
374276 return riemannian_Hessian (M, θ, G, H, X)
375277 end
278+ return h
376279end
377280
378281function SciMLBase. __solve (cache:: OptimizationCache {
@@ -395,8 +298,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
395298 LC,
396299 UC,
397300 S,
398- O < :
399- AbstractManoptOptimizer,
301+ O <: AbstractManoptOptimizer ,
400302 D,
401303 P,
402304 C
@@ -418,6 +320,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
418320 u = θ,
419321 p = cache. p,
420322 objective = x[1 ])
323+ # TODO : What is this callback for?
421324 cb_call = cache. callback (opt_state, x... )
422325 if ! (cb_call isa Bool)
423326 error (" The callback should return a boolean `halt` for whether to stop the optimization process." )
@@ -448,10 +351,12 @@ function SciMLBase.__solve(cache::OptimizationCache{
448351 stopping_criterion = Manopt. StopAfterIteration (500 )
449352 end
450353
354+ # TODO : With the new keyword warnings we can not just always pass down hessF!
451355 opt_res = call_manopt_optimizer (manifold, cache. opt, _loss, gradF, cache. u0;
452356 solver_kwarg... , stopping_criterion = stopping_criterion, hessF)
453357
454358 asc = get_stopping_criterion (opt_res. options)
359+ # TODO : Switch to `has_converged` once that was released.
455360 opt_ret = Manopt. indicates_convergence (asc) ? ReturnCode. Success : ReturnCode. Failure
456361
457362 return SciMLBase. build_solution (cache,
0 commit comments