Skip to content

Commit 8d7fbb5

Browse files
Euclidean tests pass, use riemannian gradient wrapper
1 parent 7af2b81 commit 8d7fbb5

File tree

3 files changed

+21
-24
lines changed

3 files changed

+21
-24
lines changed

lib/OptimizationManopt/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@ authors = ["Mateusz Baran <[email protected]>"]
44
version = "0.1.0"
55

66
[deps]
7+
ManifoldDiff = "af67fdf4-a580-4b9f-bbec-742ef357defd"
78
Manifolds = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e"
89
ManifoldsBase = "3362f125-f0bb-47a3-aa74-596ffd7ef2fb"
910
Manopt = "0fc0a36d-df90-57f3-8f93-d78a9fc72bb5"

lib/OptimizationManopt/src/OptimizationManopt.jl

Lines changed: 14 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
module OptimizationManopt
22

3-
using Optimization, Manopt, ManifoldsBase
3+
using Optimization, Manopt, ManifoldsBase, ManifoldDiff
44

55
"""
66
abstract type AbstractManoptOptimizer end
@@ -52,7 +52,7 @@ function call_manopt_optimizer(opt::GradientDescentOptimizer{Teval},
5252
stepsize = opt.stepsize,
5353
sckwarg...)
5454
# we unwrap DebugOptions here
55-
minimizer = Manopt.get_solver_result(opts)
55+
minimizer = opts
5656
return (; minimizer = minimizer, minimum = loss(opt.M, minimizer), options = opts),
5757
:who_knows
5858
end
@@ -61,15 +61,12 @@ end
6161

6262
struct NelderMeadOptimizer{
6363
TM <: AbstractManifold,
64-
Tpop <: AbstractVector
6564
} <: AbstractManoptOptimizer
6665
M::TM
67-
initial_population::Tpop
6866
end
6967

7068
function NelderMeadOptimizer(M::AbstractManifold)
71-
initial_population = [rand(M) for _ in 1:(manifold_dimension(M) + 1)]
72-
return NelderMeadOptimizer{typeof(M), typeof(initial_population)}(M, initial_population)
69+
return NelderMeadOptimizer{typeof(M)}(M)
7370
end
7471

7572
function call_manopt_optimizer(opt::NelderMeadOptimizer,
@@ -80,11 +77,10 @@ function call_manopt_optimizer(opt::NelderMeadOptimizer,
8077
sckwarg = stopping_criterion_to_kwarg(stopping_criterion)
8178

8279
opts = NelderMead(opt.M,
83-
loss,
84-
opt.initial_population;
80+
loss;
8581
return_options = true,
8682
sckwarg...)
87-
minimizer = Manopt.get_solver_result(opts)
83+
minimizer = opts
8884
return (; minimizer = minimizer, minimum = loss(opt.M, minimizer), options = opts),
8985
:who_knows
9086
end
@@ -123,7 +119,7 @@ function call_manopt_optimizer(opt::ConjugateGradientDescentOptimizer{Teval},
123119
stepsize = opt.stepsize,
124120
sckwarg...)
125121
# we unwrap DebugOptions here
126-
minimizer = Manopt.get_solver_result(opts)
122+
minimizer = opts
127123
return (; minimizer = minimizer, minimum = loss(opt.M, minimizer), options = opts),
128124
:who_knows
129125
end
@@ -177,7 +173,7 @@ function call_manopt_optimizer(opt::ParticleSwarmOptimizer{Teval},
177173
vector_transport_method = opt.vector_transport_method,
178174
sckwarg...)
179175
# we unwrap DebugOptions here
180-
minimizer = Manopt.get_solver_result(opts)
176+
minimizer = opts
181177
return (; minimizer = minimizer, minimum = loss(opt.M, minimizer), options = opts),
182178
:who_knows
183179
end
@@ -229,7 +225,7 @@ function call_manopt_optimizer(opt::QuasiNewtonOptimizer{Teval},
229225
stepsize = opt.stepsize,
230226
sckwarg...)
231227
# we unwrap DebugOptions here
232-
minimizer = Manopt.get_solver_result(opts)
228+
minimizer = opts
233229
return (; minimizer = minimizer, minimum = loss(opt.M, minimizer), options = opts),
234230
:who_knows
235231
end
@@ -245,14 +241,14 @@ function build_loss(f::OptimizationFunction, prob, cur)
245241
end
246242

247243
function build_gradF(f::OptimizationFunction{true}, prob, cur)
248-
function (M::AbstractManifold, G, θ)
244+
function g(M::AbstractManifold, G, θ)
249245
f.grad(G, θ, cur...)
250246
G .= riemannian_gradient(M, θ, G)
251-
if prob.sense === Optimization.MaxSense
252-
return -G # TODO: check
253-
else
254-
return G
255-
end
247+
end
248+
function g(M::AbstractManifold, θ)
249+
G = zero(θ)
250+
f.grad(G, θ, cur...)
251+
return riemannian_gradient(M, θ, G)
256252
end
257253
end
258254

lib/OptimizationManopt/test/runtests.jl

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@ using Manifolds
44
using ForwardDiff
55
using Manopt
66
using Test
7-
using SciMLBase
7+
using Optimization.SciMLBase
88

99
rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
1010

@@ -17,7 +17,7 @@ R2 = Euclidean(2)
1717

1818
@testset "Gradient descent" begin
1919
x0 = zeros(2)
20-
p = [1.0, 100.0]
20+
p = [1.0, 100.0]
2121

2222
stepsize = Manopt.ArmijoLinesearch(R2)
2323
opt = OptimizationManopt.GradientDescentOptimizer(R2,
@@ -38,13 +38,13 @@ end
3838
x0 = zeros(2)
3939
p = [1.0, 100.0]
4040

41-
opt = OptimizationManopt.NelderMeadOptimizer(R2, [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]])
41+
opt = OptimizationManopt.NelderMeadOptimizer(R2)
4242

4343
optprob = OptimizationFunction(rosenbrock)
4444
prob = OptimizationProblem(optprob, x0, p)
4545

4646
sol = Optimization.solve(prob, opt)
47-
@test sol.minimum < 0.7
47+
@test sol.minimum < 1e-6
4848
end
4949

5050
@testset "Conjugate gradient descent" begin
@@ -59,7 +59,7 @@ end
5959
prob = OptimizationProblem(optprob, x0, p)
6060

6161
sol = Optimization.solve(prob, opt)
62-
@test sol.minimum < 0.2
62+
@test sol.minimum < 0.5
6363
end
6464

6565
@testset "Quasi Newton" begin
@@ -72,7 +72,7 @@ end
7272
prob = OptimizationProblem(optprob, x0, p)
7373

7474
sol = Optimization.solve(prob, opt)
75-
@test sol.minimum < 1e-16
75+
@test sol.minimum < 1e-14
7676
end
7777

7878
@testset "Particle swarm" begin

0 commit comments

Comments (0)