Commit 3bf8327

Pass return_state and add spd test
1 parent 8d7fbb5 commit 3bf8327

2 files changed: 31 additions, 11 deletions

lib/OptimizationManopt/src/OptimizationManopt.jl

Lines changed: 10 additions & 10 deletions
@@ -47,12 +47,12 @@ function call_manopt_optimizer(opt::GradientDescentOptimizer{Teval},
         loss,
         gradF,
         x0;
-        return_options = true,
+        return_state = true,
         evaluation = Teval(),
         stepsize = opt.stepsize,
         sckwarg...)
     # we unwrap DebugOptions here
-    minimizer = opts
+    minimizer = Manopt.get_solver_result(opts)
     return (; minimizer = minimizer, minimum = loss(opt.M, minimizer), options = opts),
     :who_knows
 end
@@ -78,9 +78,9 @@ function call_manopt_optimizer(opt::NelderMeadOptimizer,

     opts = NelderMead(opt.M,
         loss;
-        return_options = true,
+        return_state = true,
         sckwarg...)
-    minimizer = opts
+    minimizer = Manopt.get_solver_result(opts)
     return (; minimizer = minimizer, minimum = loss(opt.M, minimizer), options = opts),
     :who_knows
 end
@@ -114,12 +114,12 @@ function call_manopt_optimizer(opt::ConjugateGradientDescentOptimizer{Teval},
         loss,
         gradF,
         x0;
-        return_options = true,
+        return_state = true,
         evaluation = Teval(),
         stepsize = opt.stepsize,
         sckwarg...)
     # we unwrap DebugOptions here
-    minimizer = opts
+    minimizer = Manopt.get_solver_result(opts)
     return (; minimizer = minimizer, minimum = loss(opt.M, minimizer), options = opts),
     :who_knows
 end
@@ -167,13 +167,13 @@ function call_manopt_optimizer(opt::ParticleSwarmOptimizer{Teval},
         loss;
         x0 = initial_population,
         n = opt.population_size,
-        return_options = true,
+        return_state = true,
         retraction_method = opt.retraction_method,
         inverse_retraction_method = opt.inverse_retraction_method,
         vector_transport_method = opt.vector_transport_method,
         sckwarg...)
     # we unwrap DebugOptions here
-    minimizer = opts
+    minimizer = Manopt.get_solver_result(opts)
     return (; minimizer = minimizer, minimum = loss(opt.M, minimizer), options = opts),
     :who_knows
 end
@@ -218,14 +218,14 @@ function call_manopt_optimizer(opt::QuasiNewtonOptimizer{Teval},
         loss,
         gradF,
         x0;
-        return_options = true,
+        return_state = true,
         evaluation = Teval(),
         retraction_method = opt.retraction_method,
         vector_transport_method = opt.vector_transport_method,
         stepsize = opt.stepsize,
         sckwarg...)
     # we unwrap DebugOptions here
-    minimizer = opts
+    minimizer = Manopt.get_solver_result(opts)
     return (; minimizer = minimizer, minimum = loss(opt.M, minimizer), options = opts),
     :who_knows
 end
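
Context for the change above: in recent Manopt.jl releases the keyword that returns the full solver object is return_state (replacing the earlier return_options seen on the deleted lines), and that state has to be unwrapped with Manopt.get_solver_result to obtain the actual minimizer. A minimal sketch of the same pattern outside the wrapper, assuming only Manopt.jl and Manifolds.jl; the sphere example and the names f, grad_f, p0 are illustrative and not from this repository:

    using Manifolds, Manopt

    M = Sphere(2)                                    # unit sphere in R^3
    f(M, p) = p[3]                                   # minimize the third coordinate
    grad_f(M, p) = project(M, p, [0.0, 0.0, 1.0])    # Riemannian gradient: project the Euclidean gradient onto T_p M

    p0 = [0.0, 1.0, 0.0]
    state = gradient_descent(M, f, grad_f, p0; return_state = true)

    # `state` is the solver state, not a point on M; unwrap it the same way
    # the wrapper functions above now do.
    p_min = Manopt.get_solver_result(state)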

lib/OptimizationManopt/test/runtests.jl

Lines changed: 21 additions & 1 deletion
@@ -1,10 +1,11 @@
 using OptimizationManopt
 using Optimization
 using Manifolds
-using ForwardDiff
+using ForwardDiff, Zygote, Enzyme
 using Manopt
 using Test
 using Optimization.SciMLBase
+using LinearAlgebra

 rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2

@@ -98,4 +99,23 @@ end
     optprob_cons = OptimizationFunction(rosenbrock; grad = rosenbrock_grad!, cons = cons)
     prob_cons = OptimizationProblem(optprob_cons, x0, p)
     @test_throws SciMLBase.IncompatibleOptimizerError Optimization.solve(prob_cons, opt)
+end
+
+@testset "SPD Manifold" begin
+    M = SymmetricPositiveDefinite(5)
+    m = 100
+    σ = 0.005
+    q = Matrix{Float64}(I, 5, 5) .+ 2.0
+    data2 = [exp(M, q, σ * rand(M; vector_at = q)) for i in 1:m];
+
+    f(M, x, p = nothing) = sum(distance(M, x, data2[i])^2 for i in 1:m)
+    f(x, p = nothing) = sum(distance(M, x, data2[i])^2 for i in 1:m)
+
+    optf = OptimizationFunction(f, Optimization.AutoForwardDiff())
+    prob = OptimizationProblem(optf, data2[1])
+
+    opt = OptimizationManopt.GradientDescentOptimizer(M)
+    @time sol = Optimization.solve(prob, opt)
+
+    @test sol.u ≈ q
+end
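
The new test fits the Riemannian center of mass of noisy SPD matrices: minimizing the sum of squared geodesic distances is exactly the Karcher-mean problem, so the solver should land near the base point q used to generate the data. A small sanity-check sketch along the same lines, assuming only Manifolds.jl; the tolerance 0.1 and the variable name karcher are illustrative:

    using LinearAlgebra, Manifolds, Statistics

    M = SymmetricPositiveDefinite(5)
    q = Matrix{Float64}(I, 5, 5) .+ 2.0
    data2 = [exp(M, q, 0.005 * rand(M; vector_at = q)) for _ in 1:100]

    # The Karcher mean minimizes sum_i distance(M, x, data2[i])^2 — the same
    # loss the test hands to Optimization.jl — so it should sit close to q.
    karcher = mean(M, data2)
    @assert distance(M, karcher, q) < 0.1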
