Skip to content

Commit f9e94f7

Browse files
Merge pull request #1009 from kellertuer/kellertuer/properManopt
Improve the OptimizationManopt.jl interface
2 parents 2f96f18 + 2e46ca1 commit f9e94f7

File tree

5 files changed

+189
-339
lines changed

5 files changed

+189
-339
lines changed

docs/Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@ IterTools = "1"
5959
Juniper = "0.9"
6060
Lux = "1"
6161
MLUtils = "0.4.4"
62-
Manifolds = "0.9"
63-
Manopt = "0.4"
62+
Manifolds = "0.10"
63+
Manopt = "0.5"
6464
ModelingToolkit = "10.23"
6565
NLPModels = "0.21"
6666
NLPModelsTest = "0.10"
@@ -73,7 +73,7 @@ OptimizationEvolutionary = "0.4"
7373
OptimizationGCMAES = "0.3"
7474
OptimizationIpopt = "0.2"
7575
OptimizationMOI = "0.5"
76-
OptimizationManopt = "0.0.5"
76+
OptimizationManopt = "0.1.0"
7777
OptimizationMetaheuristics = "0.3"
7878
OptimizationNLPModels = "0.0.2"
7979
OptimizationNLopt = "0.3"

docs/src/optimization_packages/manopt.md

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
# Manopt.jl
22

3-
[Manopt.jl](https://github.com/JuliaManifolds/Manopt.jl) is a package with implementations of a variety of optimization solvers on manifolds supported by
4-
[Manifolds](https://github.com/JuliaManifolds/Manifolds.jl).
3+
[Manopt.jl](https://github.com/JuliaManifolds/Manopt.jl) is a package providing solvers
4+
for optimization problems defined on Riemannian manifolds.
5+
The implementation is based on the [ManifoldsBase.jl](https://github.com/JuliaManifolds/ManifoldsBase.jl) interface and can hence be used for all manifolds defined in
6+
[Manifolds](https://github.com/JuliaManifolds/Manifolds.jl) or any other manifold implemented using the interface.
57

68
## Installation: OptimizationManopt.jl
79

@@ -29,7 +31,7 @@ The common kwargs `maxiters`, `maxtime` and `abstol` are supported by all the op
2931
function or `OptimizationProblem`.
3032

3133
!!! note
32-
34+
3335
The `OptimizationProblem` has to be passed the manifold as the `manifold` keyword argument.
3436

3537
## Examples

lib/OptimizationManopt/Project.toml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "OptimizationManopt"
22
uuid = "e57b7fff-7ee7-4550-b4f0-90e9476e9fb6"
3-
authors = ["Mateusz Baran <[email protected]>"]
4-
version = "0.0.5"
3+
authors = ["Mateusz Baran <[email protected]>", "Ronny Bergmann <[email protected]>"]
4+
version = "0.1.0"
55

66
[deps]
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -14,10 +14,10 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1414

1515
[compat]
1616
LinearAlgebra = "1.10"
17-
ManifoldDiff = "0.3.10"
18-
Manifolds = "0.9.18"
19-
ManifoldsBase = "0.15.10"
20-
Manopt = "0.4.63"
17+
ManifoldDiff = "0.4"
18+
Manifolds = "0.10"
19+
ManifoldsBase = "1"
20+
Manopt = "0.5"
2121
Optimization = "4.4"
2222
Reexport = "1.2"
2323
julia = "1.10"

lib/OptimizationManopt/src/OptimizationManopt.jl

Lines changed: 44 additions & 139 deletions
Original file line numberDiff line numberDiff line change
@@ -70,20 +70,15 @@ function call_manopt_optimizer(
7070
loss,
7171
gradF,
7272
x0;
73-
stopping_criterion::Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet},
74-
evaluation::AbstractEvaluationType = Manopt.AllocatingEvaluation(),
75-
stepsize::Stepsize = ArmijoLinesearch(M),
73+
hessF=nothing, # ignore that keyword for this solver
7674
kwargs...)
77-
opts = gradient_descent(M,
75+
opts = Manopt.gradient_descent(M,
7876
loss,
7977
gradF,
8078
x0;
81-
return_state = true,
82-
evaluation,
83-
stepsize,
84-
stopping_criterion,
85-
kwargs...)
86-
# we unwrap DebugOptions here
79+
return_state = true, # return the (full, decorated) solver state
80+
kwargs...
81+
)
8782
minimizer = Manopt.get_solver_result(opts)
8883
return (; minimizer = minimizer, minimum = loss(M, minimizer), options = opts)
8984
end
@@ -95,13 +90,9 @@ function call_manopt_optimizer(M::ManifoldsBase.AbstractManifold, opt::NelderMea
9590
loss,
9691
gradF,
9792
x0;
98-
stopping_criterion::Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet},
99-
kwargs...)
100-
opts = NelderMead(M,
101-
loss;
102-
return_state = true,
103-
stopping_criterion,
93+
hessF=nothing, # ignore that keyword for this solver
10494
kwargs...)
95+
opts = NelderMead(M, loss; return_state = true, kwargs...)
10596
minimizer = Manopt.get_solver_result(opts)
10697
return (; minimizer = minimizer, minimum = loss(M, minimizer), options = opts)
10798
end
@@ -114,20 +105,15 @@ function call_manopt_optimizer(M::ManifoldsBase.AbstractManifold,
114105
loss,
115106
gradF,
116107
x0;
117-
stopping_criterion::Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet},
118-
evaluation::AbstractEvaluationType = InplaceEvaluation(),
119-
stepsize::Stepsize = ArmijoLinesearch(M),
108+
hessF=nothing, # ignore that keyword for this solver
120109
kwargs...)
121-
opts = conjugate_gradient_descent(M,
110+
opts = Manopt.conjugate_gradient_descent(M,
122111
loss,
123112
gradF,
124113
x0;
125114
return_state = true,
126-
evaluation,
127-
stepsize,
128-
stopping_criterion,
129-
kwargs...)
130-
# we unwrap DebugOptions here
115+
kwargs...
116+
)
131117
minimizer = Manopt.get_solver_result(opts)
132118
return (; minimizer = minimizer, minimum = loss(M, minimizer), options = opts)
133119
end
@@ -140,25 +126,11 @@ function call_manopt_optimizer(M::ManifoldsBase.AbstractManifold,
140126
loss,
141127
gradF,
142128
x0;
143-
stopping_criterion::Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet},
144-
evaluation::AbstractEvaluationType = InplaceEvaluation(),
129+
hessF=nothing, # ignore that keyword for this solver
145130
population_size::Int = 100,
146-
retraction_method::AbstractRetractionMethod = default_retraction_method(M),
147-
inverse_retraction_method::AbstractInverseRetractionMethod = default_inverse_retraction_method(M),
148-
vector_transport_method::AbstractVectorTransportMethod = default_vector_transport_method(M),
149131
kwargs...)
150-
initial_population = vcat([x0], [rand(M) for _ in 1:(population_size - 1)])
151-
opts = particle_swarm(M,
152-
loss;
153-
x0 = initial_population,
154-
n = population_size,
155-
return_state = true,
156-
retraction_method,
157-
inverse_retraction_method,
158-
vector_transport_method,
159-
stopping_criterion,
160-
kwargs...)
161-
# we unwrap DebugOptions here
132+
swarm = [x0, [rand(M) for _ in 1:(population_size - 1)]...]
133+
opts = particle_swarm(M, loss, swarm; return_state = true, kwargs...)
162134
minimizer = Manopt.get_solver_result(opts)
163135
return (; minimizer = minimizer, minimum = loss(M, minimizer), options = opts)
164136
end
@@ -172,28 +144,10 @@ function call_manopt_optimizer(M::Manopt.AbstractManifold,
172144
loss,
173145
gradF,
174146
x0;
175-
stopping_criterion::Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet},
176-
evaluation::AbstractEvaluationType = InplaceEvaluation(),
177-
retraction_method::AbstractRetractionMethod = default_retraction_method(M),
178-
vector_transport_method::AbstractVectorTransportMethod = default_vector_transport_method(M),
179-
stepsize = WolfePowellLinesearch(M;
180-
retraction_method = retraction_method,
181-
vector_transport_method = vector_transport_method,
182-
linesearch_stopsize = 1e-12),
147+
hessF=nothing, # ignore that keyword for this solver
183148
kwargs...
184149
)
185-
opts = quasi_Newton(M,
186-
loss,
187-
gradF,
188-
x0;
189-
return_state = true,
190-
evaluation,
191-
retraction_method,
192-
vector_transport_method,
193-
stepsize,
194-
stopping_criterion,
195-
kwargs...)
196-
# we unwrap DebugOptions here
150+
opts = quasi_Newton(M, loss, gradF, x0; return_state = true, kwargs...)
197151
minimizer = Manopt.get_solver_result(opts)
198152
return (; minimizer = minimizer, minimum = loss(M, minimizer), options = opts)
199153
end
@@ -205,19 +159,9 @@ function call_manopt_optimizer(M::ManifoldsBase.AbstractManifold,
205159
loss,
206160
gradF,
207161
x0;
208-
stopping_criterion::Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet},
209-
evaluation::AbstractEvaluationType = InplaceEvaluation(),
210-
retraction_method::AbstractRetractionMethod = default_retraction_method(M),
211-
vector_transport_method::AbstractVectorTransportMethod = default_vector_transport_method(M),
212-
basis = Manopt.DefaultOrthonormalBasis(),
213-
kwargs...)
214-
opt = cma_es(M,
215-
loss,
216-
x0;
217-
return_state = true,
218-
stopping_criterion,
162+
hessF=nothing, # ignore that keyword for this solver
219163
kwargs...)
220-
# we unwrap DebugOptions here
164+
opt = cma_es(M, loss, x0; return_state = true, kwargs...)
221165
minimizer = Manopt.get_solver_result(opt)
222166
return (; minimizer = minimizer, minimum = loss(M, minimizer), options = opt)
223167
end
@@ -229,22 +173,9 @@ function call_manopt_optimizer(M::ManifoldsBase.AbstractManifold,
229173
loss,
230174
gradF,
231175
x0;
232-
stopping_criterion::Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet},
233-
evaluation::AbstractEvaluationType = InplaceEvaluation(),
234-
retraction_method::AbstractRetractionMethod = default_retraction_method(M),
235-
vector_transport_method::AbstractVectorTransportMethod = default_vector_transport_method(M),
236-
kwargs...)
237-
opt = convex_bundle_method!(M,
238-
loss,
239-
gradF,
240-
x0;
241-
return_state = true,
242-
evaluation,
243-
retraction_method,
244-
vector_transport_method,
245-
stopping_criterion,
176+
hessF=nothing, # ignore that keyword for this solver
246177
kwargs...)
247-
# we unwrap DebugOptions here
178+
opt = convex_bundle_method(M, loss, gradF, x0; return_state = true, kwargs...)
248179
minimizer = Manopt.get_solver_result(opt)
249180
return (; minimizer = minimizer, minimum = loss(M, minimizer), options = opt)
250181
end
@@ -257,21 +188,13 @@ function call_manopt_optimizer(M::ManifoldsBase.AbstractManifold,
257188
gradF,
258189
x0;
259190
hessF = nothing,
260-
stopping_criterion::Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet},
261-
evaluation::AbstractEvaluationType = InplaceEvaluation(),
262-
retraction_method::AbstractRetractionMethod = default_retraction_method(M),
263-
kwargs...)
264-
opt = adaptive_regularization_with_cubics(M,
265-
loss,
266-
gradF,
267-
hessF,
268-
x0;
269-
return_state = true,
270-
evaluation,
271-
retraction_method,
272-
stopping_criterion,
273191
kwargs...)
274-
# we unwrap DebugOptions here
192+
193+
opt = if isnothing(hessF)
194+
adaptive_regularization_with_cubics(M, loss, gradF, x0; return_state = true, kwargs...)
195+
else
196+
adaptive_regularization_with_cubics(M, loss, gradF, hessF, x0; return_state = true, kwargs...)
197+
end
275198
minimizer = Manopt.get_solver_result(opt)
276199
return (; minimizer = minimizer, minimum = loss(M, minimizer), options = opt)
277200
end
@@ -284,21 +207,12 @@ function call_manopt_optimizer(M::ManifoldsBase.AbstractManifold,
284207
gradF,
285208
x0;
286209
hessF = nothing,
287-
stopping_criterion::Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet},
288-
evaluation::AbstractEvaluationType = InplaceEvaluation(),
289-
retraction_method::AbstractRetractionMethod = default_retraction_method(M),
290-
kwargs...)
291-
opt = trust_regions(M,
292-
loss,
293-
gradF,
294-
hessF,
295-
x0;
296-
return_state = true,
297-
evaluation,
298-
retraction = retraction_method,
299-
stopping_criterion,
300210
kwargs...)
301-
# we unwrap DebugOptions here
211+
opt = if isnothing(hessF)
212+
trust_regions(M, loss, gradF, x0; return_state = true, kwargs...)
213+
else
214+
trust_regions(M, loss, gradF, hessF, x0; return_state = true, kwargs...)
215+
end
302216
minimizer = Manopt.get_solver_result(opt)
303217
return (; minimizer = minimizer, minimum = loss(M, minimizer), options = opt)
304218
end
@@ -310,22 +224,9 @@ function call_manopt_optimizer(M::ManifoldsBase.AbstractManifold,
310224
loss,
311225
gradF,
312226
x0;
313-
stopping_criterion::Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet},
314-
evaluation::AbstractEvaluationType = InplaceEvaluation(),
315-
retraction_method::AbstractRetractionMethod = default_retraction_method(M),
316-
stepsize::Stepsize = DecreasingStepsize(; length = 2.0, shift = 2),
317-
kwargs...)
318-
opt = Frank_Wolfe_method(M,
319-
loss,
320-
gradF,
321-
x0;
322-
return_state = true,
323-
evaluation,
324-
retraction_method,
325-
stopping_criterion,
326-
stepsize,
227+
hessF=nothing, # ignore that keyword for this solver
327228
kwargs...)
328-
# we unwrap DebugOptions here
229+
opt = Frank_Wolfe_method(M, loss, gradF, x0; return_state = true, kwargs...)
329230
minimizer = Manopt.get_solver_result(opt)
330231
return (; minimizer = minimizer, minimum = loss(M, minimizer), options = opt)
331232
end
@@ -339,11 +240,14 @@ function SciMLBase.requiresgradient(opt::Union{
339240
end
340241
function SciMLBase.requireshessian(opt::Union{
341242
AdaptiveRegularizationCubicOptimizer, TrustRegionsOptimizer})
342-
true
243+
false
343244
end
344245

345246
function build_loss(f::OptimizationFunction, prob, cb)
346-
function (::AbstractManifold, θ)
247+
# TODO: I do not understand this. Why is the manifold not used?
248+
# Either this is a Euclidean cost, then we should probably still call `embed`,
249+
# or it is not, then we need M.
250+
return function (::AbstractManifold, θ)
347251
x = f.f(θ, prob.p)
348252
cb(x, θ)
349253
__x = first(x)
@@ -361,6 +265,7 @@ function build_gradF(f::OptimizationFunction{true})
361265
f.grad(G, θ)
362266
return riemannian_gradient(M, θ, G)
363267
end
268+
return g
364269
end
365270

366271
function build_hessF(f::OptimizationFunction{true})
@@ -372,12 +277,13 @@ function build_hessF(f::OptimizationFunction{true})
372277
riemannian_Hessian!(M, H1, θ, G, H, X)
373278
end
374279
function h(M::AbstractManifold, θ, X)
375-
H = zeros(eltype(θ), length(θ), length(θ))
376-
f.hess(H, θ)
280+
H = zeros(eltype(θ), length(θ))
281+
f.hv(H, θ, X)
377282
G = zeros(eltype(θ), length(θ))
378283
f.grad(G, θ)
379284
return riemannian_Hessian(M, θ, G, H, X)
380285
end
286+
return h
381287
end
382288

383289
function SciMLBase.__solve(cache::OptimizationCache{
@@ -400,8 +306,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
400306
LC,
401307
UC,
402308
S,
403-
O <:
404-
AbstractManoptOptimizer,
309+
O <: AbstractManoptOptimizer,
405310
D,
406311
P,
407312
C
@@ -457,7 +362,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
457362
solver_kwarg..., stopping_criterion = stopping_criterion, hessF)
458363

459364
asc = get_stopping_criterion(opt_res.options)
460-
opt_ret = Manopt.indicates_convergence(asc) ? ReturnCode.Success : ReturnCode.Failure
365+
opt_ret = Manopt.has_converged(asc) ? ReturnCode.Success : ReturnCode.Failure
461366

462367
return SciMLBase.build_solution(cache,
463368
cache.opt,

0 commit comments

Comments
 (0)