Skip to content

Commit 136606b

Browse files
Move code from existing PR
1 parent 038c7b6 commit 136606b

File tree

3 files changed

+417
-0
lines changed

3 files changed

+417
-0
lines changed
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
name = "OptimizationManopt"
2+
uuid = "e57b7fff-7ee7-4550-b4f0-90e9476e9fb6"
3+
authors = ["Mateusz Baran <[email protected]>"]
4+
version = "0.1.0"
5+
6+
[deps]
7+
Manifolds = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e"
8+
ManifoldsBase = "3362f125-f0bb-47a3-aa74-596ffd7ef2fb"
9+
Manopt = "0fc0a36d-df90-57f3-8f93-d78a9fc72bb5"
10+
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
Lines changed: 306 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,306 @@
1+
module OptimizationManopt
2+
3+
using Optimization, Manopt, ManifoldsBase
4+
5+
"""
6+
abstract type AbstractManoptOptimizer end
7+
8+
A Manopt solver without things specified by a call to `solve` (stopping criteria) and
9+
internal state.
10+
"""
11+
abstract type AbstractManoptOptimizer end
12+
13+
function stopping_criterion_to_kwarg(stopping_criterion::Nothing)
14+
return NamedTuple()
15+
end
16+
function stopping_criterion_to_kwarg(stopping_criterion::StoppingCriterion)
17+
return (; stopping_criterion = stopping_criterion)
18+
end
19+
20+
## gradient descent
21+
22+
struct GradientDescentOptimizer{
23+
Teval <: AbstractEvaluationType,
24+
TM <: AbstractManifold,
25+
TLS <: Linesearch
26+
} <: AbstractManoptOptimizer
27+
M::TM
28+
stepsize::TLS
29+
end
30+
31+
function GradientDescentOptimizer(M::AbstractManifold;
32+
eval::AbstractEvaluationType = MutatingEvaluation(),
33+
stepsize::Stepsize = ArmijoLinesearch(M))
34+
GradientDescentOptimizer{typeof(eval), typeof(M), typeof(stepsize)}(M, stepsize)
35+
end
36+
37+
function call_manopt_optimizer(opt::GradientDescentOptimizer{Teval},
38+
loss,
39+
gradF,
40+
x0,
41+
stopping_criterion::Union{Nothing, Manopt.StoppingCriterion}) where {
42+
Teval <:
43+
AbstractEvaluationType
44+
}
45+
sckwarg = stopping_criterion_to_kwarg(stopping_criterion)
46+
opts = gradient_descent(opt.M,
47+
loss,
48+
gradF,
49+
x0;
50+
return_options = true,
51+
evaluation = Teval(),
52+
stepsize = opt.stepsize,
53+
sckwarg...)
54+
# we unwrap DebugOptions here
55+
minimizer = Manopt.get_solver_result(opts)
56+
return (; minimizer = minimizer, minimum = loss(opt.M, minimizer), options = opts),
57+
:who_knows
58+
end
59+
60+
## Nelder-Mead
61+
62+
struct NelderMeadOptimizer{
63+
TM <: AbstractManifold,
64+
Tpop <: AbstractVector
65+
} <: AbstractManoptOptimizer
66+
M::TM
67+
initial_population::Tpop
68+
end
69+
70+
function NelderMeadOptimizer(M::AbstractManifold)
71+
initial_population = [rand(M) for _ in 1:(manifold_dimension(M) + 1)]
72+
return NelderMeadOptimizer{typeof(M), typeof(initial_population)}(M, initial_population)
73+
end
74+
75+
function call_manopt_optimizer(opt::NelderMeadOptimizer,
76+
loss,
77+
gradF,
78+
x0,
79+
stopping_criterion::Union{Nothing, Manopt.StoppingCriterion})
80+
sckwarg = stopping_criterion_to_kwarg(stopping_criterion)
81+
82+
opts = NelderMead(opt.M,
83+
loss,
84+
opt.initial_population;
85+
return_options = true,
86+
sckwarg...)
87+
minimizer = Manopt.get_solver_result(opts)
88+
return (; minimizer = minimizer, minimum = loss(opt.M, minimizer), options = opts),
89+
:who_knows
90+
end
91+
92+
## conjugate gradient descent
93+
94+
struct ConjugateGradientDescentOptimizer{Teval <: AbstractEvaluationType,
95+
TM <: AbstractManifold, TLS <: Stepsize} <:
96+
AbstractManoptOptimizer
97+
M::TM
98+
stepsize::TLS
99+
end
100+
101+
function ConjugateGradientDescentOptimizer(M::AbstractManifold;
102+
eval::AbstractEvaluationType = MutatingEvaluation(),
103+
stepsize::Stepsize = ArmijoLinesearch(M))
104+
ConjugateGradientDescentOptimizer{typeof(eval), typeof(M), typeof(stepsize)}(M,
105+
stepsize)
106+
end
107+
108+
function call_manopt_optimizer(opt::ConjugateGradientDescentOptimizer{Teval},
109+
loss,
110+
gradF,
111+
x0,
112+
stopping_criterion::Union{Nothing, Manopt.StoppingCriterion}) where {
113+
Teval <:
114+
AbstractEvaluationType
115+
}
116+
sckwarg = stopping_criterion_to_kwarg(stopping_criterion)
117+
opts = conjugate_gradient_descent(opt.M,
118+
loss,
119+
gradF,
120+
x0;
121+
return_options = true,
122+
evaluation = Teval(),
123+
stepsize = opt.stepsize,
124+
sckwarg...)
125+
# we unwrap DebugOptions here
126+
minimizer = Manopt.get_solver_result(opts)
127+
return (; minimizer = minimizer, minimum = loss(opt.M, minimizer), options = opts),
128+
:who_knows
129+
end
130+
131+
## particle swarm
132+
133+
struct ParticleSwarmOptimizer{Teval <: AbstractEvaluationType,
134+
TM <: AbstractManifold, Tretr <: AbstractRetractionMethod,
135+
Tinvretr <: AbstractInverseRetractionMethod,
136+
Tvt <: AbstractVectorTransportMethod} <:
137+
AbstractManoptOptimizer
138+
M::TM
139+
retraction_method::Tretr
140+
inverse_retraction_method::Tinvretr
141+
vector_transport_method::Tvt
142+
population_size::Int
143+
end
144+
145+
function ParticleSwarmOptimizer(M::AbstractManifold;
146+
eval::AbstractEvaluationType = MutatingEvaluation(),
147+
population_size::Int = 100,
148+
retraction_method::AbstractRetractionMethod = default_retraction_method(M),
149+
inverse_retraction_method::AbstractInverseRetractionMethod = default_inverse_retraction_method(M),
150+
vector_transport_method::AbstractVectorTransportMethod = default_vector_transport_method(M))
151+
ParticleSwarmOptimizer{typeof(eval), typeof(M), typeof(retraction_method),
152+
typeof(inverse_retraction_method),
153+
typeof(vector_transport_method)}(M,
154+
retraction_method,
155+
inverse_retraction_method,
156+
vector_transport_method,
157+
population_size)
158+
end
159+
160+
function call_manopt_optimizer(opt::ParticleSwarmOptimizer{Teval},
161+
loss,
162+
gradF,
163+
x0,
164+
stopping_criterion::Union{Nothing, Manopt.StoppingCriterion}) where {
165+
Teval <:
166+
AbstractEvaluationType
167+
}
168+
sckwarg = stopping_criterion_to_kwarg(stopping_criterion)
169+
initial_population = vcat([x0], [rand(opt.M) for _ in 1:(opt.population_size - 1)])
170+
opts = particle_swarm(opt.M,
171+
loss;
172+
x0 = initial_population,
173+
n = opt.population_size,
174+
return_options = true,
175+
retraction_method = opt.retraction_method,
176+
inverse_retraction_method = opt.inverse_retraction_method,
177+
vector_transport_method = opt.vector_transport_method,
178+
sckwarg...)
179+
# we unwrap DebugOptions here
180+
minimizer = Manopt.get_solver_result(opts)
181+
return (; minimizer = minimizer, minimum = loss(opt.M, minimizer), options = opts),
182+
:who_knows
183+
end
184+
185+
## quasi Newton
186+
187+
struct QuasiNewtonOptimizer{Teval <: AbstractEvaluationType,
188+
TM <: AbstractManifold, Tretr <: AbstractRetractionMethod,
189+
Tvt <: AbstractVectorTransportMethod, TLS <: Stepsize} <:
190+
AbstractManoptOptimizer
191+
M::TM
192+
retraction_method::Tretr
193+
vector_transport_method::Tvt
194+
stepsize::TLS
195+
end
196+
197+
function QuasiNewtonOptimizer(M::AbstractManifold;
198+
eval::AbstractEvaluationType = MutatingEvaluation(),
199+
retraction_method::AbstractRetractionMethod = default_retraction_method(M),
200+
vector_transport_method::AbstractVectorTransportMethod = default_vector_transport_method(M),
201+
stepsize = WolfePowellLinesearch(M;
202+
retraction_method = retraction_method,
203+
vector_transport_method = vector_transport_method,
204+
linesearch_stopsize = 1e-12))
205+
QuasiNewtonOptimizer{typeof(eval), typeof(M), typeof(retraction_method),
206+
typeof(vector_transport_method), typeof(stepsize)}(M,
207+
retraction_method,
208+
vector_transport_method,
209+
stepsize)
210+
end
211+
212+
function call_manopt_optimizer(opt::QuasiNewtonOptimizer{Teval},
213+
loss,
214+
gradF,
215+
x0,
216+
stopping_criterion::Union{Nothing, Manopt.StoppingCriterion}) where {
217+
Teval <:
218+
AbstractEvaluationType
219+
}
220+
sckwarg = stopping_criterion_to_kwarg(stopping_criterion)
221+
opts = quasi_Newton(opt.M,
222+
loss,
223+
gradF,
224+
x0;
225+
return_options = true,
226+
evaluation = Teval(),
227+
retraction_method = opt.retraction_method,
228+
vector_transport_method = opt.vector_transport_method,
229+
stepsize = opt.stepsize,
230+
sckwarg...)
231+
# we unwrap DebugOptions here
232+
minimizer = Manopt.get_solver_result(opts)
233+
return (; minimizer = minimizer, minimum = loss(opt.M, minimizer), options = opts),
234+
:who_knows
235+
end
236+
237+
## Optimization.jl stuff
238+
239+
function build_loss(f::OptimizationFunction, prob, cur)
240+
function (::AbstractManifold, θ)
241+
x = f.f(θ, prob.p, cur...)
242+
__x = first(x)
243+
return prob.sense === Optimization.MaxSense ? -__x : __x
244+
end
245+
end
246+
247+
function build_gradF(f::OptimizationFunction{true}, prob, cur)
248+
function (::AbstractManifold, G, θ)
249+
X = f.grad(G, θ, cur...)
250+
if prob.sense === Optimization.MaxSense
251+
return -X # TODO: check
252+
else
253+
return X
254+
end
255+
end
256+
end
257+
258+
# TODO:
259+
# 1) convert tolerances and other stopping criteria
260+
# 2) return convergence information
261+
# 3) add callbacks to Manopt.jl
262+
263+
function SciMLBase.__solve(prob::OptimizationProblem,
264+
opt::AbstractManoptOptimizer,
265+
data = Optimization.DEFAULT_DATA;
266+
callback = (args...) -> (false),
267+
maxiters::Union{Number, Nothing} = nothing,
268+
maxtime::Union{Number, Nothing} = nothing,
269+
abstol::Union{Number, Nothing} = nothing,
270+
reltol::Union{Number, Nothing} = nothing,
271+
progress = false,
272+
kwargs...)
273+
local x, cur, state
274+
275+
if data !== Optimization.DEFAULT_DATA
276+
maxiters = length(data)
277+
end
278+
279+
cur, state = iterate(data)
280+
281+
stopping_criterion = nothing
282+
if maxiters !== nothing
283+
stopping_criterion = StopAfterIteration(maxiters)
284+
end
285+
286+
maxiters = Optimization._check_and_convert_maxiters(maxiters)
287+
maxtime = Optimization._check_and_convert_maxtime(maxtime)
288+
289+
f = Optimization.instantiate_function(prob.f, prob.u0, prob.f.adtype, prob.p)
290+
291+
_loss = build_loss(f, prob, cur)
292+
293+
gradF = build_gradF(f, prob, cur)
294+
295+
opt_res, opt_ret = call_manopt_optimizer(opt, _loss, gradF, prob.u0, stopping_criterion)
296+
297+
return SciMLBase.build_solution(SciMLBase.DefaultOptimizationCache(prob.f, prob.p),
298+
opt,
299+
opt_res.minimizer,
300+
prob.sense === Optimization.MaxSense ?
301+
-opt_res.minimum : opt_res.minimum;
302+
original = opt_res.options,
303+
retcode = opt_ret)
304+
end
305+
306+
end # module OptimizationManopt

0 commit comments

Comments
 (0)