Skip to content

Commit 7d79171

Browse files
Pass in manifold as kwarg to OptimizationProblem and throw error if not passes or doesn't match solver
1 parent 3bf8327 commit 7d79171

File tree

2 files changed

+30
-10
lines changed

2 files changed

+30
-10
lines changed

lib/OptimizationManopt/src/OptimizationManopt.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,6 @@ struct NelderMeadOptimizer{
6565
M::TM
6666
end
6767

68-
function NelderMeadOptimizer(M::AbstractManifold)
69-
return NelderMeadOptimizer{typeof(M)}(M)
70-
end
7168

7269
function call_manopt_optimizer(opt::NelderMeadOptimizer,
7370
loss,
@@ -269,6 +266,12 @@ function SciMLBase.__solve(prob::OptimizationProblem,
269266
kwargs...)
270267
local x, cur, state
271268

269+
manifold = haskey(prob.kwargs, :manifold) ? prob.kwargs[:manifold] : nothing
270+
271+
if manifold === nothing || manifold !== opt.M
272+
throw(ArgumentError("Either manifold not specified in the problem `OptimizationProblem(f, x, p; manifold = SymmetricPositiveDefinite(5))` or it doesn't match the manifold specified in the optimizer `$(opt.M)`"))
273+
end
274+
272275
if data !== Optimization.DEFAULT_DATA
273276
maxiters = length(data)
274277
end

lib/OptimizationManopt/test/runtests.jl

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ end
1616

1717
R2 = Euclidean(2)
1818

19-
@testset "Gradient descent" begin
19+
@testset "Error on no or mismatching manifolds" begin
2020
x0 = zeros(2)
2121
p = [1.0, 100.0]
2222

@@ -26,11 +26,28 @@ R2 = Euclidean(2)
2626

2727
optprob_forwarddiff = OptimizationFunction(rosenbrock, Optimization.AutoForwardDiff())
2828
prob_forwarddiff = OptimizationProblem(optprob_forwarddiff, x0, p)
29+
@test_throws ArgumentError("Either manifold not specified in the problem `OptimizationProblem(f, x, p; manifold = SymmetricPositiveDefinite(5))` or it doesn't match the manifold specified in the optimizer `$(opt.M)`") Optimization.solve(prob_forwarddiff, opt)
30+
31+
optprob = OptimizationFunction(rosenbrock, Optimization.AutoForwardDiff())
32+
prob = OptimizationProblem(optprob, x0, p; manifold = SymmetricPositiveDefinite(5))
33+
@test_throws ArgumentError("Either manifold not specified in the problem `OptimizationProblem(f, x, p; manifold = SymmetricPositiveDefinite(5))` or it doesn't match the manifold specified in the optimizer `$(opt.M)`") Optimization.solve(prob, opt)
34+
end
35+
36+
@testset "Gradient descent" begin
37+
x0 = zeros(2)
38+
p = [1.0, 100.0]
39+
40+
stepsize = Manopt.ArmijoLinesearch(R2)
41+
opt = OptimizationManopt.GradientDescentOptimizer(R2,
42+
stepsize = stepsize)
43+
44+
optprob_forwarddiff = OptimizationFunction(rosenbrock, Optimization.AutoForwardDiff())
45+
prob_forwarddiff = OptimizationProblem(optprob_forwarddiff, x0, p; manifold = R2)
2946
sol = Optimization.solve(prob_forwarddiff, opt)
3047
@test sol.minimum < 0.2
3148

3249
optprob_grad = OptimizationFunction(rosenbrock; grad = rosenbrock_grad!)
33-
prob_grad = OptimizationProblem(optprob_grad, x0, p)
50+
prob_grad = OptimizationProblem(optprob_grad, x0, p; manifold = R2)
3451
sol = Optimization.solve(prob_grad, opt)
3552
@test sol.minimum < 0.2
3653
end
@@ -42,7 +59,7 @@ end
4259
opt = OptimizationManopt.NelderMeadOptimizer(R2)
4360

4461
optprob = OptimizationFunction(rosenbrock)
45-
prob = OptimizationProblem(optprob, x0, p)
62+
prob = OptimizationProblem(optprob, x0, p; manifold = R2)
4663

4764
sol = Optimization.solve(prob, opt)
4865
@test sol.minimum < 1e-6
@@ -57,7 +74,7 @@ end
5774
stepsize = stepsize)
5875

5976
optprob = OptimizationFunction(rosenbrock, Optimization.AutoForwardDiff())
60-
prob = OptimizationProblem(optprob, x0, p)
77+
prob = OptimizationProblem(optprob, x0, p; manifold = R2)
6178

6279
sol = Optimization.solve(prob, opt)
6380
@test sol.minimum < 0.5
@@ -70,7 +87,7 @@ end
7087
opt = OptimizationManopt.QuasiNewtonOptimizer(R2)
7188

7289
optprob = OptimizationFunction(rosenbrock, Optimization.AutoForwardDiff())
73-
prob = OptimizationProblem(optprob, x0, p)
90+
prob = OptimizationProblem(optprob, x0, p; manifold = R2)
7491

7592
sol = Optimization.solve(prob, opt)
7693
@test sol.minimum < 1e-14
@@ -83,7 +100,7 @@ end
83100
opt = OptimizationManopt.ParticleSwarmOptimizer(R2)
84101

85102
optprob = OptimizationFunction(rosenbrock)
86-
prob = OptimizationProblem(optprob, x0, p)
103+
prob = OptimizationProblem(optprob, x0, p; manifold = R2)
87104

88105
sol = Optimization.solve(prob, opt)
89106
@test sol.minimum < 0.1
@@ -112,7 +129,7 @@ end
112129
f(x, p = nothing) = sum(distance(M, x, data2[i])^2 for i in 1:m)
113130

114131
optf = OptimizationFunction(f, Optimization.AutoForwardDiff())
115-
prob = OptimizationProblem(optf, data2[1])
132+
prob = OptimizationProblem(optf, data2[1]; manifold = M)
116133

117134
opt = OptimizationManopt.GradientDescentOptimizer(M)
118135
@time sol = Optimization.solve(prob, opt)

0 commit comments

Comments
 (0)