Skip to content

Commit a074422

Browse files
committed
Starts adapting and reworking to Manopt 0.5, Manifolds 0.10, ManifoldsBase 1.0
1 parent 302d6f1 commit a074422

File tree

2 files changed: +38 −133 lines

2 files changed

+38
-133
lines changed

lib/OptimizationManopt/Project.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1414

1515
[compat]
1616
LinearAlgebra = "1.10"
17-
ManifoldDiff = "0.3.10"
18-
Manifolds = "0.9.18"
19-
ManifoldsBase = "0.15.10"
20-
Manopt = "0.4.63"
17+
ManifoldDiff = "0.4"
18+
Manifolds = "0.10"
19+
ManifoldsBase = "1"
20+
Manopt = "0.5"
2121
Optimization = "4.4"
2222
Reexport = "1.2"
2323
julia = "1.10"

lib/OptimizationManopt/src/OptimizationManopt.jl

Lines changed: 34 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -65,20 +65,14 @@ function call_manopt_optimizer(
6565
loss,
6666
gradF,
6767
x0;
68-
stopping_criterion::Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet},
69-
evaluation::AbstractEvaluationType = Manopt.AllocatingEvaluation(),
70-
stepsize::Stepsize = ArmijoLinesearch(M),
7168
kwargs...)
72-
opts = gradient_descent(M,
69+
opts = Manopt.gradient_descent(M,
7370
loss,
7471
gradF,
7572
x0;
76-
return_state = true,
77-
evaluation,
78-
stepsize,
79-
stopping_criterion,
80-
kwargs...)
81-
# we unwrap DebugOptions here
73+
return_state = true, # return the (full, decorated) solver state
74+
kwargs...
75+
)
8276
minimizer = Manopt.get_solver_result(opts)
8377
return (; minimizer = minimizer, minimum = loss(M, minimizer), options = opts)
8478
end
@@ -90,13 +84,8 @@ function call_manopt_optimizer(M::ManifoldsBase.AbstractManifold, opt::NelderMea
9084
loss,
9185
gradF,
9286
x0;
93-
stopping_criterion::Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet},
94-
kwargs...)
95-
opts = NelderMead(M,
96-
loss;
97-
return_state = true,
98-
stopping_criterion,
9987
kwargs...)
88+
opts = NelderMead(M, loss; return_state = true, kwargs...)
10089
minimizer = Manopt.get_solver_result(opts)
10190
return (; minimizer = minimizer, minimum = loss(M, minimizer), options = opts)
10291
end
@@ -109,19 +98,14 @@ function call_manopt_optimizer(M::ManifoldsBase.AbstractManifold,
10998
loss,
11099
gradF,
111100
x0;
112-
stopping_criterion::Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet},
113-
evaluation::AbstractEvaluationType = InplaceEvaluation(),
114-
stepsize::Stepsize = ArmijoLinesearch(M),
115101
kwargs...)
116-
opts = conjugate_gradient_descent(M,
102+
opts = Manopt.conjugate_gradient_descent(M,
117103
loss,
118104
gradF,
119105
x0;
120106
return_state = true,
121-
evaluation,
122-
stepsize,
123-
stopping_criterion,
124-
kwargs...)
107+
kwargs...
108+
)
125109
# we unwrap DebugOptions here
126110
minimizer = Manopt.get_solver_result(opts)
127111
return (; minimizer = minimizer, minimum = loss(M, minimizer), options = opts)
@@ -135,25 +119,10 @@ function call_manopt_optimizer(M::ManifoldsBase.AbstractManifold,
135119
loss,
136120
gradF,
137121
x0;
138-
stopping_criterion::Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet},
139-
evaluation::AbstractEvaluationType = InplaceEvaluation(),
140122
population_size::Int = 100,
141-
retraction_method::AbstractRetractionMethod = default_retraction_method(M),
142-
inverse_retraction_method::AbstractInverseRetractionMethod = default_inverse_retraction_method(M),
143-
vector_transport_method::AbstractVectorTransportMethod = default_vector_transport_method(M),
144-
kwargs...)
145-
initial_population = vcat([x0], [rand(M) for _ in 1:(population_size - 1)])
146-
opts = particle_swarm(M,
147-
loss;
148-
x0 = initial_population,
149-
n = population_size,
150-
return_state = true,
151-
retraction_method,
152-
inverse_retraction_method,
153-
vector_transport_method,
154-
stopping_criterion,
155123
kwargs...)
156-
# we unwrap DebugOptions here
124+
swarm = [x0, [rand(M) for _ in 1:(population_size - 1)]...]
125+
opts = particle_swarm(M, loss, swarm; return_state = true, kwargs...)
157126
minimizer = Manopt.get_solver_result(opts)
158127
return (; minimizer = minimizer, minimum = loss(M, minimizer), options = opts)
159128
end
@@ -167,27 +136,9 @@ function call_manopt_optimizer(M::Manopt.AbstractManifold,
167136
loss,
168137
gradF,
169138
x0;
170-
stopping_criterion::Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet},
171-
evaluation::AbstractEvaluationType = InplaceEvaluation(),
172-
retraction_method::AbstractRetractionMethod = default_retraction_method(M),
173-
vector_transport_method::AbstractVectorTransportMethod = default_vector_transport_method(M),
174-
stepsize = WolfePowellLinesearch(M;
175-
retraction_method = retraction_method,
176-
vector_transport_method = vector_transport_method,
177-
linesearch_stopsize = 1e-12),
178139
kwargs...
179140
)
180-
opts = quasi_Newton(M,
181-
loss,
182-
gradF,
183-
x0;
184-
return_state = true,
185-
evaluation,
186-
retraction_method,
187-
vector_transport_method,
188-
stepsize,
189-
stopping_criterion,
190-
kwargs...)
141+
opts = quasi_Newton(M, loss, gradF, x0; return_state = true, kwargs...)
191142
# we unwrap DebugOptions here
192143
minimizer = Manopt.get_solver_result(opts)
193144
return (; minimizer = minimizer, minimum = loss(M, minimizer), options = opts)
@@ -200,18 +151,8 @@ function call_manopt_optimizer(M::ManifoldsBase.AbstractManifold,
200151
loss,
201152
gradF,
202153
x0;
203-
stopping_criterion::Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet},
204-
evaluation::AbstractEvaluationType = InplaceEvaluation(),
205-
retraction_method::AbstractRetractionMethod = default_retraction_method(M),
206-
vector_transport_method::AbstractVectorTransportMethod = default_vector_transport_method(M),
207-
basis = Manopt.DefaultOrthonormalBasis(),
208-
kwargs...)
209-
opt = cma_es(M,
210-
loss,
211-
x0;
212-
return_state = true,
213-
stopping_criterion,
214154
kwargs...)
155+
opt = cma_es(M, loss, x0; return_state = true, kwargs...)
215156
# we unwrap DebugOptions here
216157
minimizer = Manopt.get_solver_result(opt)
217158
return (; minimizer = minimizer, minimum = loss(M, minimizer), options = opt)
@@ -224,21 +165,8 @@ function call_manopt_optimizer(M::ManifoldsBase.AbstractManifold,
224165
loss,
225166
gradF,
226167
x0;
227-
stopping_criterion::Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet},
228-
evaluation::AbstractEvaluationType = InplaceEvaluation(),
229-
retraction_method::AbstractRetractionMethod = default_retraction_method(M),
230-
vector_transport_method::AbstractVectorTransportMethod = default_vector_transport_method(M),
231-
kwargs...)
232-
opt = convex_bundle_method!(M,
233-
loss,
234-
gradF,
235-
x0;
236-
return_state = true,
237-
evaluation,
238-
retraction_method,
239-
vector_transport_method,
240-
stopping_criterion,
241168
kwargs...)
169+
opt = convex_bundle_method(M, loss, gradF, x0; return_state = true, kwargs...)
242170
# we unwrap DebugOptions here
243171
minimizer = Manopt.get_solver_result(opt)
244172
return (; minimizer = minimizer, minimum = loss(M, minimizer), options = opt)
@@ -252,21 +180,13 @@ function call_manopt_optimizer(M::ManifoldsBase.AbstractManifold,
252180
gradF,
253181
x0;
254182
hessF = nothing,
255-
stopping_criterion::Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet},
256-
evaluation::AbstractEvaluationType = InplaceEvaluation(),
257-
retraction_method::AbstractRetractionMethod = default_retraction_method(M),
258-
kwargs...)
259-
opt = adaptive_regularization_with_cubics(M,
260-
loss,
261-
gradF,
262-
hessF,
263-
x0;
264-
return_state = true,
265-
evaluation,
266-
retraction_method,
267-
stopping_criterion,
268183
kwargs...)
269-
# we unwrap DebugOptions here
184+
185+
opt = if isnothing(hessF)
186+
adaptive_regularization_with_cubics(M, loss, gradF, x0; return_state = true, kwargs...)
187+
else
188+
adaptive_regularization_with_cubics(M, loss, gradF, hessF, x0; return_state = true, kwargs...)
189+
end
270190
minimizer = Manopt.get_solver_result(opt)
271191
return (; minimizer = minimizer, minimum = loss(M, minimizer), options = opt)
272192
end
@@ -279,20 +199,12 @@ function call_manopt_optimizer(M::ManifoldsBase.AbstractManifold,
279199
gradF,
280200
x0;
281201
hessF = nothing,
282-
stopping_criterion::Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet},
283-
evaluation::AbstractEvaluationType = InplaceEvaluation(),
284-
retraction_method::AbstractRetractionMethod = default_retraction_method(M),
285-
kwargs...)
286-
opt = trust_regions(M,
287-
loss,
288-
gradF,
289-
hessF,
290-
x0;
291-
return_state = true,
292-
evaluation,
293-
retraction = retraction_method,
294-
stopping_criterion,
295202
kwargs...)
203+
opt = if isnothing(hessF)
204+
trust_regions(M, loss, gradF, x0; return_state = true, kwargs...)
205+
else
206+
trust_regions(M, loss, gradF, hessF, x0; return_state = true, kwargs...)
207+
end
296208
# we unwrap DebugOptions here
297209
minimizer = Manopt.get_solver_result(opt)
298210
return (; minimizer = minimizer, minimum = loss(M, minimizer), options = opt)
@@ -305,21 +217,8 @@ function call_manopt_optimizer(M::ManifoldsBase.AbstractManifold,
305217
loss,
306218
gradF,
307219
x0;
308-
stopping_criterion::Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet},
309-
evaluation::AbstractEvaluationType = InplaceEvaluation(),
310-
retraction_method::AbstractRetractionMethod = default_retraction_method(M),
311-
stepsize::Stepsize = DecreasingStepsize(; length = 2.0, shift = 2),
312-
kwargs...)
313-
opt = Frank_Wolfe_method(M,
314-
loss,
315-
gradF,
316-
x0;
317-
return_state = true,
318-
evaluation,
319-
retraction_method,
320-
stopping_criterion,
321-
stepsize,
322220
kwargs...)
221+
opt = Frank_Wolfe_method(M, loss, gradF, x0; return_state = true, kwargs...)
323222
# we unwrap DebugOptions here
324223
minimizer = Manopt.get_solver_result(opt)
325224
return (; minimizer = minimizer, minimum = loss(M, minimizer), options = opt)
@@ -332,20 +231,22 @@ function SciMLBase.requiresgradient(opt::Union{
332231
AdaptiveRegularizationCubicOptimizer, TrustRegionsOptimizer})
333232
true
334233
end
234+
# TODO: WHY? they both still accept not passing it
335235
function SciMLBase.requireshessian(opt::Union{
336236
AdaptiveRegularizationCubicOptimizer, TrustRegionsOptimizer})
337237
true
338238
end
339239

340240
function build_loss(f::OptimizationFunction, prob, cb)
341-
function (::AbstractManifold, θ)
241+
return function (::AbstractManifold, θ)
342242
x = f.f(θ, prob.p)
343243
cb(x, θ)
344244
__x = first(x)
345245
return prob.sense === Optimization.MaxSense ? -__x : __x
346246
end
347247
end
348248

249+
#TODO: What does the “true” mean here?
349250
function build_gradF(f::OptimizationFunction{true})
350251
function g(M::AbstractManifold, G, θ)
351252
f.grad(G, θ)
@@ -356,6 +257,7 @@ function build_gradF(f::OptimizationFunction{true})
356257
f.grad(G, θ)
357258
return riemannian_gradient(M, θ, G)
358259
end
260+
return g
359261
end
360262

361263
function build_hessF(f::OptimizationFunction{true})
@@ -373,6 +275,7 @@ function build_hessF(f::OptimizationFunction{true})
373275
f.grad(G, θ)
374276
return riemannian_Hessian(M, θ, G, H, X)
375277
end
278+
return h
376279
end
377280

378281
function SciMLBase.__solve(cache::OptimizationCache{
@@ -395,8 +298,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
395298
LC,
396299
UC,
397300
S,
398-
O <:
399-
AbstractManoptOptimizer,
301+
O <: AbstractManoptOptimizer,
400302
D,
401303
P,
402304
C
@@ -418,6 +320,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
418320
u = θ,
419321
p = cache.p,
420322
objective = x[1])
323+
#TODO: What is this callback for?
421324
cb_call = cache.callback(opt_state, x...)
422325
if !(cb_call isa Bool)
423326
error("The callback should return a boolean `halt` for whether to stop the optimization process.")
@@ -448,10 +351,12 @@ function SciMLBase.__solve(cache::OptimizationCache{
448351
stopping_criterion = Manopt.StopAfterIteration(500)
449352
end
450353

354+
# TODO: With the new keyword warnings we can not just always pass down hessF!
451355
opt_res = call_manopt_optimizer(manifold, cache.opt, _loss, gradF, cache.u0;
452356
solver_kwarg..., stopping_criterion = stopping_criterion, hessF)
453357

454358
asc = get_stopping_criterion(opt_res.options)
359+
# TODO: Switch to `has_converged` once that was released.
455360
opt_ret = Manopt.indicates_convergence(asc) ? ReturnCode.Success : ReturnCode.Failure
456361

457362
return SciMLBase.build_solution(cache,

Comments (0)