module OptimizationManopt

using Optimization, Manopt, ManifoldsBase
using Optimization.SciMLBase

5+ """
6+ abstract type AbstractManoptOptimizer end
7+
8+ A Manopt solver without things specified by a call to `solve` (stopping criteria) and
9+ internal state.
10+ """
11+ abstract type AbstractManoptOptimizer end
12+
function stopping_criterion_to_kwarg(stopping_criterion::Nothing)
    return NamedTuple()
end
function stopping_criterion_to_kwarg(stopping_criterion::StoppingCriterion)
    return (; stopping_criterion = stopping_criterion)
end
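
# `nothing` maps to an empty `NamedTuple`, so splatting `sckwarg...` into a solver
# call adds no keyword argument and Manopt's default stopping criterion applies;
# an explicit criterion is forwarded unchanged, for example:
#
#     stopping_criterion_to_kwarg(StopAfterIteration(100))
#     # -> (; stopping_criterion = StopAfterIteration(100))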

## gradient descent

struct GradientDescentOptimizer{
    Teval <: AbstractEvaluationType,
    TM <: AbstractManifold,
    TLS <: Stepsize
} <: AbstractManoptOptimizer
    M::TM
    stepsize::TLS
end

function GradientDescentOptimizer(M::AbstractManifold;
                                  eval::AbstractEvaluationType = MutatingEvaluation(),
                                  stepsize::Stepsize = ArmijoLinesearch(M))
    GradientDescentOptimizer{typeof(eval), typeof(M), typeof(stepsize)}(M, stepsize)
end
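
# A minimal construction sketch (this assumes `Manifolds.jl`, which provides
# `Sphere`, is loaded by the caller; it is not a dependency of this module):
#
#     using Manifolds
#     M = Sphere(2)
#     opt = GradientDescentOptimizer(M; stepsize = ArmijoLinesearch(M))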

function call_manopt_optimizer(opt::GradientDescentOptimizer{Teval},
                               loss,
                               gradF,
                               x0,
                               stopping_criterion::Union{Nothing, Manopt.StoppingCriterion}) where {Teval <: AbstractEvaluationType}
    sckwarg = stopping_criterion_to_kwarg(stopping_criterion)
    opts = gradient_descent(opt.M,
                            loss,
                            gradF,
                            x0;
                            return_options = true,
                            evaluation = Teval(),
                            stepsize = opt.stepsize,
                            sckwarg...)
    # we unwrap DebugOptions here
    minimizer = Manopt.get_solver_result(opts)
    return (; minimizer = minimizer, minimum = loss(opt.M, minimizer), options = opts),
           :who_knows
end

## Nelder-Mead

struct NelderMeadOptimizer{
    TM <: AbstractManifold,
    Tpop <: AbstractVector
} <: AbstractManoptOptimizer
    M::TM
    initial_population::Tpop
end

function NelderMeadOptimizer(M::AbstractManifold)
    initial_population = [rand(M) for _ in 1:(manifold_dimension(M) + 1)]
    return NelderMeadOptimizer{typeof(M), typeof(initial_population)}(M, initial_population)
end
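
# The default simplex consists of `manifold_dimension(M) + 1` random points,
# mirroring the Euclidean Nelder-Mead simplex of n + 1 vertices in n dimensions.
# A custom population can be passed through the inner constructor; a sketch,
# assuming `pts` is a vector of points on `M`:
#
#     NelderMeadOptimizer{typeof(M), typeof(pts)}(M, pts)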

function call_manopt_optimizer(opt::NelderMeadOptimizer,
                               loss,
                               gradF,
                               x0,
                               stopping_criterion::Union{Nothing, Manopt.StoppingCriterion})
    sckwarg = stopping_criterion_to_kwarg(stopping_criterion)

    opts = NelderMead(opt.M,
                      loss,
                      opt.initial_population;
                      return_options = true,
                      sckwarg...)
    minimizer = Manopt.get_solver_result(opts)
    return (; minimizer = minimizer, minimum = loss(opt.M, minimizer), options = opts),
           :who_knows
end

## conjugate gradient descent

struct ConjugateGradientDescentOptimizer{Teval <: AbstractEvaluationType,
                                         TM <: AbstractManifold, TLS <: Stepsize} <:
       AbstractManoptOptimizer
    M::TM
    stepsize::TLS
end

function ConjugateGradientDescentOptimizer(M::AbstractManifold;
                                           eval::AbstractEvaluationType = MutatingEvaluation(),
                                           stepsize::Stepsize = ArmijoLinesearch(M))
    ConjugateGradientDescentOptimizer{typeof(eval), typeof(M), typeof(stepsize)}(M,
                                                                                 stepsize)
end
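
# Construction mirrors `GradientDescentOptimizer`; any Manopt `Stepsize` should
# be accepted, e.g. (a sketch, assuming `Manifolds.Sphere` is available):
#
#     ConjugateGradientDescentOptimizer(Sphere(2); stepsize = ConstantStepsize(0.1))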

function call_manopt_optimizer(opt::ConjugateGradientDescentOptimizer{Teval},
                               loss,
                               gradF,
                               x0,
                               stopping_criterion::Union{Nothing, Manopt.StoppingCriterion}) where {Teval <: AbstractEvaluationType}
    sckwarg = stopping_criterion_to_kwarg(stopping_criterion)
    opts = conjugate_gradient_descent(opt.M,
                                      loss,
                                      gradF,
                                      x0;
                                      return_options = true,
                                      evaluation = Teval(),
                                      stepsize = opt.stepsize,
                                      sckwarg...)
    # we unwrap DebugOptions here
    minimizer = Manopt.get_solver_result(opts)
    return (; minimizer = minimizer, minimum = loss(opt.M, minimizer), options = opts),
           :who_knows
end

## particle swarm

struct ParticleSwarmOptimizer{Teval <: AbstractEvaluationType,
                              TM <: AbstractManifold, Tretr <: AbstractRetractionMethod,
                              Tinvretr <: AbstractInverseRetractionMethod,
                              Tvt <: AbstractVectorTransportMethod} <:
       AbstractManoptOptimizer
    M::TM
    retraction_method::Tretr
    inverse_retraction_method::Tinvretr
    vector_transport_method::Tvt
    population_size::Int
end

function ParticleSwarmOptimizer(M::AbstractManifold;
                                eval::AbstractEvaluationType = MutatingEvaluation(),
                                population_size::Int = 100,
                                retraction_method::AbstractRetractionMethod = default_retraction_method(M),
                                inverse_retraction_method::AbstractInverseRetractionMethod = default_inverse_retraction_method(M),
                                vector_transport_method::AbstractVectorTransportMethod = default_vector_transport_method(M))
    ParticleSwarmOptimizer{typeof(eval), typeof(M), typeof(retraction_method),
                           typeof(inverse_retraction_method),
                           typeof(vector_transport_method)}(M,
                                                            retraction_method,
                                                            inverse_retraction_method,
                                                            vector_transport_method,
                                                            population_size)
end
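
# A construction sketch (again assuming `Manifolds.Sphere`); the retraction,
# inverse retraction, and vector transport all fall back to the manifold's own
# defaults, so typically only the population size needs overriding:
#
#     ParticleSwarmOptimizer(Sphere(2); population_size = 50)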

function call_manopt_optimizer(opt::ParticleSwarmOptimizer{Teval},
                               loss,
                               gradF,
                               x0,
                               stopping_criterion::Union{Nothing, Manopt.StoppingCriterion}) where {Teval <: AbstractEvaluationType}
    sckwarg = stopping_criterion_to_kwarg(stopping_criterion)
    initial_population = vcat([x0], [rand(opt.M) for _ in 1:(opt.population_size - 1)])
    opts = particle_swarm(opt.M,
                          loss;
                          x0 = initial_population,
                          n = opt.population_size,
                          return_options = true,
                          retraction_method = opt.retraction_method,
                          inverse_retraction_method = opt.inverse_retraction_method,
                          vector_transport_method = opt.vector_transport_method,
                          sckwarg...)
    # we unwrap DebugOptions here
    minimizer = Manopt.get_solver_result(opts)
    return (; minimizer = minimizer, minimum = loss(opt.M, minimizer), options = opts),
           :who_knows
end

## quasi Newton

struct QuasiNewtonOptimizer{Teval <: AbstractEvaluationType,
                            TM <: AbstractManifold, Tretr <: AbstractRetractionMethod,
                            Tvt <: AbstractVectorTransportMethod, TLS <: Stepsize} <:
       AbstractManoptOptimizer
    M::TM
    retraction_method::Tretr
    vector_transport_method::Tvt
    stepsize::TLS
end

function QuasiNewtonOptimizer(M::AbstractManifold;
                              eval::AbstractEvaluationType = MutatingEvaluation(),
                              retraction_method::AbstractRetractionMethod = default_retraction_method(M),
                              vector_transport_method::AbstractVectorTransportMethod = default_vector_transport_method(M),
                              stepsize = WolfePowellLinesearch(M;
                                                               retraction_method = retraction_method,
                                                               vector_transport_method = vector_transport_method,
                                                               linesearch_stopsize = 1e-12))
    QuasiNewtonOptimizer{typeof(eval), typeof(M), typeof(retraction_method),
                         typeof(vector_transport_method), typeof(stepsize)}(M,
                                                                            retraction_method,
                                                                            vector_transport_method,
                                                                            stepsize)
end
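
# A construction sketch (assuming `Manifolds.Sphere`); the default
# `WolfePowellLinesearch` satisfies the Wolfe conditions that quasi-Newton
# updates assume, so override `stepsize` with care:
#
#     QuasiNewtonOptimizer(Sphere(2))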

function call_manopt_optimizer(opt::QuasiNewtonOptimizer{Teval},
                               loss,
                               gradF,
                               x0,
                               stopping_criterion::Union{Nothing, Manopt.StoppingCriterion}) where {Teval <: AbstractEvaluationType}
    sckwarg = stopping_criterion_to_kwarg(stopping_criterion)
    opts = quasi_Newton(opt.M,
                        loss,
                        gradF,
                        x0;
                        return_options = true,
                        evaluation = Teval(),
                        retraction_method = opt.retraction_method,
                        vector_transport_method = opt.vector_transport_method,
                        stepsize = opt.stepsize,
                        sckwarg...)
    # we unwrap DebugOptions here
    minimizer = Manopt.get_solver_result(opts)
    return (; minimizer = minimizer, minimum = loss(opt.M, minimizer), options = opts),
           :who_knows
end

## Optimization.jl stuff

function build_loss(f::OptimizationFunction, prob, cur)
    function (::AbstractManifold, θ)
        x = f.f(θ, prob.p, cur...)
        __x = first(x)
        return prob.sense === Optimization.MaxSense ? -__x : __x
    end
end
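
# Manopt minimizes, so a `MaxSense` problem is handled by negating the objective
# here and negating the reported optimum again in `__solve` below; e.g. maximizing
# f(θ) = -sum(abs2, θ) runs internally as minimizing sum(abs2, θ).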

function build_gradF(f::OptimizationFunction{true}, prob, cur)
    function (::AbstractManifold, G, θ)
        X = f.grad(G, θ, cur...)
        if prob.sense === Optimization.MaxSense
            return -X # TODO: check
        else
            return X
        end
    end
end
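
# Note the returned closure takes `(M, G, θ)` and is expected to write the
# gradient into `G`, matching the in-place `OptimizationFunction{true}`
# convention; with the default `MutatingEvaluation()` Manopt calls it in exactly
# this mutating form.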

# TODO:
# 1) convert tolerances and other stopping criteria
# 2) return convergence information
# 3) add callbacks to Manopt.jl

function SciMLBase.__solve(prob::OptimizationProblem,
                           opt::AbstractManoptOptimizer,
                           data = Optimization.DEFAULT_DATA;
                           callback = (args...) -> (false),
                           maxiters::Union{Number, Nothing} = nothing,
                           maxtime::Union{Number, Nothing} = nothing,
                           abstol::Union{Number, Nothing} = nothing,
                           reltol::Union{Number, Nothing} = nothing,
                           progress = false,
                           kwargs...)
    local x, cur, state

    if data !== Optimization.DEFAULT_DATA
        maxiters = length(data)
    end

    cur, state = iterate(data)

    stopping_criterion = nothing
    if maxiters !== nothing
        stopping_criterion = StopAfterIteration(maxiters)
    end

    maxiters = Optimization._check_and_convert_maxiters(maxiters)
    maxtime = Optimization._check_and_convert_maxtime(maxtime)

    f = Optimization.instantiate_function(prob.f, prob.u0, prob.f.adtype, prob.p)

    _loss = build_loss(f, prob, cur)

    gradF = build_gradF(f, prob, cur)

    opt_res, opt_ret = call_manopt_optimizer(opt, _loss, gradF, prob.u0, stopping_criterion)

    return SciMLBase.build_solution(SciMLBase.DefaultOptimizationCache(prob.f, prob.p),
                                    opt,
                                    opt_res.minimizer,
                                    prob.sense === Optimization.MaxSense ?
                                    -opt_res.minimum : opt_res.minimum;
                                    original = opt_res.options,
                                    retcode = opt_ret)
end
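
# An end-to-end sketch of the intended usage (assumes `Manifolds.jl` and an AD
# backend such as ForwardDiff are available to the caller; neither is a
# dependency of this module):
#
#     using Optimization, Manifolds
#     M = Sphere(2)
#     f(θ, p) = sum(abs2, θ .- p)
#     optf = OptimizationFunction(f, Optimization.AutoForwardDiff())
#     prob = OptimizationProblem(optf, [0.0, 0.0, 1.0], [1.0, 0.0, 0.0])
#     sol = solve(prob, GradientDescentOptimizer(M); maxiters = 100)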

end # module OptimizationManopt