@@ -20,113 +20,112 @@ end
2020# # gradient descent
2121
2222struct GradientDescentOptimizer{
23- Teval <: AbstractEvaluationType ,
24- TM <: AbstractManifold ,
25- TLS <: Linesearch
26- } <: AbstractManoptOptimizer
23+ Teval <: AbstractEvaluationType ,
24+ TM <: AbstractManifold ,
25+ TLS <: Linesearch
26+ } <: AbstractManoptOptimizer
2727 M:: TM
2828 stepsize:: TLS
2929end
3030
3131function GradientDescentOptimizer (M:: AbstractManifold ;
32- eval:: AbstractEvaluationType = Manopt. AllocatingEvaluation (),
33- stepsize:: Stepsize = ArmijoLinesearch (M))
32+ eval:: AbstractEvaluationType = Manopt. AllocatingEvaluation (),
33+ stepsize:: Stepsize = ArmijoLinesearch (M))
3434 GradientDescentOptimizer {typeof(eval), typeof(M), typeof(stepsize)} (M, stepsize)
3535end
3636
3737function call_manopt_optimizer (opt:: GradientDescentOptimizer{Teval} ,
38- loss,
39- gradF,
40- x0,
41- stopping_criterion:: Union{Nothing, Manopt.StoppingCriterion} ) where {
42- Teval < :
43- AbstractEvaluationType
44- }
38+ loss,
39+ gradF,
40+ x0,
41+ stopping_criterion:: Union{Nothing, Manopt.StoppingCriterion} ) where {
42+ Teval < :
43+ AbstractEvaluationType
44+ }
4545 sckwarg = stopping_criterion_to_kwarg (stopping_criterion)
4646 opts = gradient_descent (opt. M,
47- loss,
48- gradF,
49- x0;
50- return_state = true ,
51- evaluation = Teval (),
52- stepsize = opt. stepsize,
53- sckwarg... )
47+ loss,
48+ gradF,
49+ x0;
50+ return_state = true ,
51+ evaluation = Teval (),
52+ stepsize = opt. stepsize,
53+ sckwarg... )
5454 # we unwrap DebugOptions here
5555 minimizer = Manopt. get_solver_result (opts)
5656 return (; minimizer = minimizer, minimum = loss (opt. M, minimizer), options = opts),
57- :who_knows
57+ :who_knows
5858end
5959
6060# # Nelder-Mead
6161
6262struct NelderMeadOptimizer{
63- TM <: AbstractManifold ,
64- } <: AbstractManoptOptimizer
63+ TM <: AbstractManifold ,
64+ } <: AbstractManoptOptimizer
6565 M:: TM
6666end
6767
68-
6968function call_manopt_optimizer (opt:: NelderMeadOptimizer ,
70- loss,
71- gradF,
72- x0,
73- stopping_criterion:: Union{Nothing, Manopt.StoppingCriterion} )
69+ loss,
70+ gradF,
71+ x0,
72+ stopping_criterion:: Union{Nothing, Manopt.StoppingCriterion} )
7473 sckwarg = stopping_criterion_to_kwarg (stopping_criterion)
7574
7675 opts = NelderMead (opt. M,
77- loss;
78- return_state = true ,
79- sckwarg... )
76+ loss;
77+ return_state = true ,
78+ sckwarg... )
8079 minimizer = Manopt. get_solver_result (opts)
8180 return (; minimizer = minimizer, minimum = loss (opt. M, minimizer), options = opts),
82- :who_knows
81+ :who_knows
8382end
8483
8584# # conjugate gradient descent
8685
8786struct ConjugateGradientDescentOptimizer{Teval <: AbstractEvaluationType ,
88- TM <: AbstractManifold , TLS <: Stepsize } < :
87+ TM <: AbstractManifold , TLS <: Stepsize } < :
8988 AbstractManoptOptimizer
9089 M:: TM
9190 stepsize:: TLS
9291end
9392
9493function ConjugateGradientDescentOptimizer (M:: AbstractManifold ;
95- eval:: AbstractEvaluationType = InplaceEvaluation (),
96- stepsize:: Stepsize = ArmijoLinesearch (M))
94+ eval:: AbstractEvaluationType = InplaceEvaluation (),
95+ stepsize:: Stepsize = ArmijoLinesearch (M))
9796 ConjugateGradientDescentOptimizer {typeof(eval), typeof(M), typeof(stepsize)} (M,
98- stepsize)
97+ stepsize)
9998end
10099
101100function call_manopt_optimizer (opt:: ConjugateGradientDescentOptimizer{Teval} ,
102- loss,
103- gradF,
104- x0,
105- stopping_criterion:: Union{Nothing, Manopt.StoppingCriterion} ) where {
106- Teval < :
107- AbstractEvaluationType
108- }
101+ loss,
102+ gradF,
103+ x0,
104+ stopping_criterion:: Union{Nothing, Manopt.StoppingCriterion} ) where {
105+ Teval < :
106+ AbstractEvaluationType
107+ }
109108 sckwarg = stopping_criterion_to_kwarg (stopping_criterion)
110109 opts = conjugate_gradient_descent (opt. M,
111- loss,
112- gradF,
113- x0;
114- return_state = true ,
115- evaluation = Teval (),
116- stepsize = opt. stepsize,
117- sckwarg... )
110+ loss,
111+ gradF,
112+ x0;
113+ return_state = true ,
114+ evaluation = Teval (),
115+ stepsize = opt. stepsize,
116+ sckwarg... )
118117 # we unwrap DebugOptions here
119118 minimizer = Manopt. get_solver_result (opts)
120119 return (; minimizer = minimizer, minimum = loss (opt. M, minimizer), options = opts),
121- :who_knows
120+ :who_knows
122121end
123122
124123# # particle swarm
125124
126125struct ParticleSwarmOptimizer{Teval <: AbstractEvaluationType ,
127- TM <: AbstractManifold , Tretr <: AbstractRetractionMethod ,
128- Tinvretr <: AbstractInverseRetractionMethod ,
129- Tvt <: AbstractVectorTransportMethod } < :
126+ TM <: AbstractManifold , Tretr <: AbstractRetractionMethod ,
127+ Tinvretr <: AbstractInverseRetractionMethod ,
128+ Tvt <: AbstractVectorTransportMethod } < :
130129 AbstractManoptOptimizer
131130 M:: TM
132131 retraction_method:: Tretr
@@ -136,50 +135,50 @@ struct ParticleSwarmOptimizer{Teval <: AbstractEvaluationType,
136135end
137136
138137function ParticleSwarmOptimizer (M:: AbstractManifold ;
139- eval:: AbstractEvaluationType = InplaceEvaluation (),
140- population_size:: Int = 100 ,
141- retraction_method:: AbstractRetractionMethod = default_retraction_method (M),
142- inverse_retraction_method:: AbstractInverseRetractionMethod = default_inverse_retraction_method (M),
143- vector_transport_method:: AbstractVectorTransportMethod = default_vector_transport_method (M))
138+ eval:: AbstractEvaluationType = InplaceEvaluation (),
139+ population_size:: Int = 100 ,
140+ retraction_method:: AbstractRetractionMethod = default_retraction_method (M),
141+ inverse_retraction_method:: AbstractInverseRetractionMethod = default_inverse_retraction_method (M),
142+ vector_transport_method:: AbstractVectorTransportMethod = default_vector_transport_method (M))
144143 ParticleSwarmOptimizer{typeof (eval), typeof (M), typeof (retraction_method),
145- typeof (inverse_retraction_method),
146- typeof (vector_transport_method)}(M,
147- retraction_method,
148- inverse_retraction_method,
149- vector_transport_method,
150- population_size)
144+ typeof (inverse_retraction_method),
145+ typeof (vector_transport_method)}(M,
146+ retraction_method,
147+ inverse_retraction_method,
148+ vector_transport_method,
149+ population_size)
151150end
152151
153152function call_manopt_optimizer (opt:: ParticleSwarmOptimizer{Teval} ,
154- loss,
155- gradF,
156- x0,
157- stopping_criterion:: Union{Nothing, Manopt.StoppingCriterion} ) where {
158- Teval < :
159- AbstractEvaluationType
160- }
153+ loss,
154+ gradF,
155+ x0,
156+ stopping_criterion:: Union{Nothing, Manopt.StoppingCriterion} ) where {
157+ Teval < :
158+ AbstractEvaluationType
159+ }
161160 sckwarg = stopping_criterion_to_kwarg (stopping_criterion)
162161 initial_population = vcat ([x0], [rand (opt. M) for _ in 1 : (opt. population_size - 1 )])
163162 opts = particle_swarm (opt. M,
164- loss;
165- x0 = initial_population,
166- n = opt. population_size,
167- return_state = true ,
168- retraction_method = opt. retraction_method,
169- inverse_retraction_method = opt. inverse_retraction_method,
170- vector_transport_method = opt. vector_transport_method,
171- sckwarg... )
163+ loss;
164+ x0 = initial_population,
165+ n = opt. population_size,
166+ return_state = true ,
167+ retraction_method = opt. retraction_method,
168+ inverse_retraction_method = opt. inverse_retraction_method,
169+ vector_transport_method = opt. vector_transport_method,
170+ sckwarg... )
172171 # we unwrap DebugOptions here
173172 minimizer = Manopt. get_solver_result (opts)
174173 return (; minimizer = minimizer, minimum = loss (opt. M, minimizer), options = opts),
175- :who_knows
174+ :who_knows
176175end
177176
178177# # quasi Newton
179178
180179struct QuasiNewtonOptimizer{Teval <: AbstractEvaluationType ,
181- TM <: AbstractManifold , Tretr <: AbstractRetractionMethod ,
182- Tvt <: AbstractVectorTransportMethod , TLS <: Stepsize } < :
180+ TM <: AbstractManifold , Tretr <: AbstractRetractionMethod ,
181+ Tvt <: AbstractVectorTransportMethod , TLS <: Stepsize } < :
183182 AbstractManoptOptimizer
184183 M:: TM
185184 retraction_method:: Tretr
@@ -188,43 +187,43 @@ struct QuasiNewtonOptimizer{Teval <: AbstractEvaluationType,
188187end
189188
190189function QuasiNewtonOptimizer (M:: AbstractManifold ;
191- eval:: AbstractEvaluationType = InplaceEvaluation (),
192- retraction_method:: AbstractRetractionMethod = default_retraction_method (M),
193- vector_transport_method:: AbstractVectorTransportMethod = default_vector_transport_method (M),
194- stepsize = WolfePowellLinesearch (M;
195- retraction_method = retraction_method,
196- vector_transport_method = vector_transport_method,
197- linesearch_stopsize = 1e-12 ))
190+ eval:: AbstractEvaluationType = InplaceEvaluation (),
191+ retraction_method:: AbstractRetractionMethod = default_retraction_method (M),
192+ vector_transport_method:: AbstractVectorTransportMethod = default_vector_transport_method (M),
193+ stepsize = WolfePowellLinesearch (M;
194+ retraction_method = retraction_method,
195+ vector_transport_method = vector_transport_method,
196+ linesearch_stopsize = 1e-12 ))
198197 QuasiNewtonOptimizer{typeof (eval), typeof (M), typeof (retraction_method),
199- typeof (vector_transport_method), typeof (stepsize)}(M,
200- retraction_method,
201- vector_transport_method,
202- stepsize)
198+ typeof (vector_transport_method), typeof (stepsize)}(M,
199+ retraction_method,
200+ vector_transport_method,
201+ stepsize)
203202end
204203
205204function call_manopt_optimizer (opt:: QuasiNewtonOptimizer{Teval} ,
206- loss,
207- gradF,
208- x0,
209- stopping_criterion:: Union{Nothing, Manopt.StoppingCriterion} ) where {
210- Teval < :
211- AbstractEvaluationType
212- }
205+ loss,
206+ gradF,
207+ x0,
208+ stopping_criterion:: Union{Nothing, Manopt.StoppingCriterion} ) where {
209+ Teval < :
210+ AbstractEvaluationType
211+ }
213212 sckwarg = stopping_criterion_to_kwarg (stopping_criterion)
214213 opts = quasi_Newton (opt. M,
215- loss,
216- gradF,
217- x0;
218- return_state = true ,
219- evaluation = Teval (),
220- retraction_method = opt. retraction_method,
221- vector_transport_method = opt. vector_transport_method,
222- stepsize = opt. stepsize,
223- sckwarg... )
214+ loss,
215+ gradF,
216+ x0;
217+ return_state = true ,
218+ evaluation = Teval (),
219+ retraction_method = opt. retraction_method,
220+ vector_transport_method = opt. vector_transport_method,
221+ stepsize = opt. stepsize,
222+ sckwarg... )
224223 # we unwrap DebugOptions here
225224 minimizer = Manopt. get_solver_result (opts)
226225 return (; minimizer = minimizer, minimum = loss (opt. M, minimizer), options = opts),
227- :who_knows
226+ :who_knows
228227end
229228
230229# # Optimization.jl stuff
@@ -255,15 +254,15 @@ end
255254# 3) add callbacks to Manopt.jl
256255
257256function SciMLBase. __solve (prob:: OptimizationProblem ,
258- opt:: AbstractManoptOptimizer ,
259- data = Optimization. DEFAULT_DATA;
260- callback = (args... ) -> (false ),
261- maxiters:: Union{Number, Nothing} = nothing ,
262- maxtime:: Union{Number, Nothing} = nothing ,
263- abstol:: Union{Number, Nothing} = nothing ,
264- reltol:: Union{Number, Nothing} = nothing ,
265- progress = false ,
266- kwargs... )
257+ opt:: AbstractManoptOptimizer ,
258+ data = Optimization. DEFAULT_DATA;
259+ callback = (args... ) -> (false ),
260+ maxiters:: Union{Number, Nothing} = nothing ,
261+ maxtime:: Union{Number, Nothing} = nothing ,
262+ abstol:: Union{Number, Nothing} = nothing ,
263+ reltol:: Union{Number, Nothing} = nothing ,
264+ progress = false ,
265+ kwargs... )
267266 local x, cur, state
268267
269268 manifold = haskey (prob. kwargs, :manifold ) ? prob. kwargs[:manifold ] : nothing
@@ -295,12 +294,12 @@ function SciMLBase.__solve(prob::OptimizationProblem,
295294 opt_res, opt_ret = call_manopt_optimizer (opt, _loss, gradF, prob. u0, stopping_criterion)
296295
297296 return SciMLBase. build_solution (SciMLBase. DefaultOptimizationCache (prob. f, prob. p),
298- opt,
299- opt_res. minimizer,
300- prob. sense === Optimization. MaxSense ?
301- - opt_res. minimum : opt_res. minimum;
302- original = opt_res. options,
303- retcode = opt_ret)
297+ opt,
298+ opt_res. minimizer,
299+ prob. sense === Optimization. MaxSense ?
300+ - opt_res. minimum : opt_res. minimum;
301+ original = opt_res. options,
302+ retcode = opt_ret)
304303end
305304
306- end # module OptimizationManopt
305+ end # module OptimizationManopt
0 commit comments