
Commit bcbc649

format
1 parent feb9d0c commit bcbc649

2 files changed: +137 -136 lines changed

lib/OptimizationManopt/src/OptimizationManopt.jl

Lines changed: 126 additions & 127 deletions
@@ -20,113 +20,112 @@ end
## gradient descent

struct GradientDescentOptimizer{
    Teval <: AbstractEvaluationType,
    TM <: AbstractManifold,
    TLS <: Linesearch
} <: AbstractManoptOptimizer
    M::TM
    stepsize::TLS
end

function GradientDescentOptimizer(M::AbstractManifold;
    eval::AbstractEvaluationType = Manopt.AllocatingEvaluation(),
    stepsize::Stepsize = ArmijoLinesearch(M))
    GradientDescentOptimizer{typeof(eval), typeof(M), typeof(stepsize)}(M, stepsize)
end

function call_manopt_optimizer(opt::GradientDescentOptimizer{Teval},
    loss,
    gradF,
    x0,
    stopping_criterion::Union{Nothing, Manopt.StoppingCriterion}) where {
    Teval <: AbstractEvaluationType
}
    sckwarg = stopping_criterion_to_kwarg(stopping_criterion)
    opts = gradient_descent(opt.M,
        loss,
        gradF,
        x0;
        return_state = true,
        evaluation = Teval(),
        stepsize = opt.stepsize,
        sckwarg...)
    # we unwrap DebugOptions here
    minimizer = Manopt.get_solver_result(opts)
    return (; minimizer = minimizer, minimum = loss(opt.M, minimizer), options = opts),
        :who_knows
end
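A minimal usage sketch for the wrapper above (not part of this commit): it calls call_manopt_optimizer directly to minimize a Rayleigh quotient on the unit sphere, assuming Manifolds.jl is available and that stopping_criterion_to_kwarg(nothing), defined earlier in this file, expands to no extra keywords.

using Manopt, Manifolds

M = Sphere(2)                           # unit sphere in R^3
A = [2.0 0.0 0.0; 0.0 1.0 0.0; 0.0 0.0 0.5]
rayleigh(M, x) = x' * A * x             # Manopt-style loss: f(M, p)
gradF(M, x) = project(M, x, 2 * A * x)  # Euclidean gradient projected onto T_x M

opt = GradientDescentOptimizer(M)       # eval defaults to AllocatingEvaluation()
x0 = [1.0, 1.0, 1.0] / sqrt(3)          # a non-critical start point on M
res, ret = call_manopt_optimizer(opt, rayleigh, gradF, x0, nothing)
# res.minimizer ≈ ±[0, 0, 1] and res.minimum ≈ 0.5, the smallest eigenvalue of A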

## Nelder-Mead

struct NelderMeadOptimizer{
    TM <: AbstractManifold,
} <: AbstractManoptOptimizer
    M::TM
end

function call_manopt_optimizer(opt::NelderMeadOptimizer,
    loss,
    gradF,
    x0,
    stopping_criterion::Union{Nothing, Manopt.StoppingCriterion})
    sckwarg = stopping_criterion_to_kwarg(stopping_criterion)

    opts = NelderMead(opt.M,
        loss;
        return_state = true,
        sckwarg...)
    minimizer = Manopt.get_solver_result(opts)
    return (; minimizer = minimizer, minimum = loss(opt.M, minimizer), options = opts),
        :who_knows
end
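Since this wrapper forwards only loss to NelderMead, the gradF and x0 arguments are accepted but never used; Manopt draws its own initial simplex on the manifold. A derivative-free sketch, reusing M and rayleigh from the gradient descent example above:

opt = NelderMeadOptimizer(M)
# gradF is unused here, so `nothing` is fine; x0 is likewise ignored by this wrapper
res, ret = call_manopt_optimizer(opt, rayleigh, nothing, rand(M), nothing)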

## conjugate gradient descent

struct ConjugateGradientDescentOptimizer{Teval <: AbstractEvaluationType,
    TM <: AbstractManifold, TLS <: Stepsize} <:
       AbstractManoptOptimizer
    M::TM
    stepsize::TLS
end

function ConjugateGradientDescentOptimizer(M::AbstractManifold;
    eval::AbstractEvaluationType = InplaceEvaluation(),
    stepsize::Stepsize = ArmijoLinesearch(M))
    ConjugateGradientDescentOptimizer{typeof(eval), typeof(M), typeof(stepsize)}(M,
        stepsize)
end

function call_manopt_optimizer(opt::ConjugateGradientDescentOptimizer{Teval},
    loss,
    gradF,
    x0,
    stopping_criterion::Union{Nothing, Manopt.StoppingCriterion}) where {
    Teval <: AbstractEvaluationType
}
    sckwarg = stopping_criterion_to_kwarg(stopping_criterion)
    opts = conjugate_gradient_descent(opt.M,
        loss,
        gradF,
        x0;
        return_state = true,
        evaluation = Teval(),
        stepsize = opt.stepsize,
        sckwarg...)
    # we unwrap DebugOptions here
    minimizer = Manopt.get_solver_result(opts)
    return (; minimizer = minimizer, minimum = loss(opt.M, minimizer), options = opts),
        :who_knows
end
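This wrapper defaults to InplaceEvaluation, so a gradient passed to it should follow Manopt's mutating signature grad!(M, X, p). A sketch under that assumption, reusing M, A, and rayleigh from above:

opt = ConjugateGradientDescentOptimizer(M)   # eval defaults to InplaceEvaluation()
grad!(M, X, x) = (X .= project(M, x, 2 * A * x); X)
res, ret = call_manopt_optimizer(opt, rayleigh, grad!, [1.0, 1.0, 1.0] / sqrt(3), nothing)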

## particle swarm

struct ParticleSwarmOptimizer{Teval <: AbstractEvaluationType,
    TM <: AbstractManifold, Tretr <: AbstractRetractionMethod,
    Tinvretr <: AbstractInverseRetractionMethod,
    Tvt <: AbstractVectorTransportMethod} <:
       AbstractManoptOptimizer
    M::TM
    retraction_method::Tretr
@@ -136,50 +135,50 @@ struct ParticleSwarmOptimizer{Teval <: AbstractEvaluationType,
end

function ParticleSwarmOptimizer(M::AbstractManifold;
    eval::AbstractEvaluationType = InplaceEvaluation(),
    population_size::Int = 100,
    retraction_method::AbstractRetractionMethod = default_retraction_method(M),
    inverse_retraction_method::AbstractInverseRetractionMethod = default_inverse_retraction_method(M),
    vector_transport_method::AbstractVectorTransportMethod = default_vector_transport_method(M))
    ParticleSwarmOptimizer{typeof(eval), typeof(M), typeof(retraction_method),
        typeof(inverse_retraction_method),
        typeof(vector_transport_method)}(M,
        retraction_method,
        inverse_retraction_method,
        vector_transport_method,
        population_size)
end

function call_manopt_optimizer(opt::ParticleSwarmOptimizer{Teval},
    loss,
    gradF,
    x0,
    stopping_criterion::Union{Nothing, Manopt.StoppingCriterion}) where {
    Teval <: AbstractEvaluationType
}
    sckwarg = stopping_criterion_to_kwarg(stopping_criterion)
    initial_population = vcat([x0], [rand(opt.M) for _ in 1:(opt.population_size - 1)])
    opts = particle_swarm(opt.M,
        loss;
        x0 = initial_population,
        n = opt.population_size,
        return_state = true,
        retraction_method = opt.retraction_method,
        inverse_retraction_method = opt.inverse_retraction_method,
        vector_transport_method = opt.vector_transport_method,
        sckwarg...)
    # we unwrap DebugOptions here
    minimizer = Manopt.get_solver_result(opts)
    return (; minimizer = minimizer, minimum = loss(opt.M, minimizer), options = opts),
        :who_knows
end
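Note how the wrapper above seeds the swarm: the caller's x0 is kept as the first particle and the remaining population_size - 1 particles are drawn with rand(opt.M). A gradient-free sketch, reusing M and rayleigh:

opt = ParticleSwarmOptimizer(M; population_size = 20)
# gradF is unused by particle swarm; one of the 20 particles starts at x0
res, ret = call_manopt_optimizer(opt, rayleigh, nothing, rand(M), nothing)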

## quasi Newton

struct QuasiNewtonOptimizer{Teval <: AbstractEvaluationType,
    TM <: AbstractManifold, Tretr <: AbstractRetractionMethod,
    Tvt <: AbstractVectorTransportMethod, TLS <: Stepsize} <:
       AbstractManoptOptimizer
    M::TM
    retraction_method::Tretr
@@ -188,43 +187,43 @@ struct QuasiNewtonOptimizer{Teval <: AbstractEvaluationType,
end

function QuasiNewtonOptimizer(M::AbstractManifold;
    eval::AbstractEvaluationType = InplaceEvaluation(),
    retraction_method::AbstractRetractionMethod = default_retraction_method(M),
    vector_transport_method::AbstractVectorTransportMethod = default_vector_transport_method(M),
    stepsize = WolfePowellLinesearch(M;
        retraction_method = retraction_method,
        vector_transport_method = vector_transport_method,
        linesearch_stopsize = 1e-12))
    QuasiNewtonOptimizer{typeof(eval), typeof(M), typeof(retraction_method),
        typeof(vector_transport_method), typeof(stepsize)}(M,
        retraction_method,
        vector_transport_method,
        stepsize)
end

function call_manopt_optimizer(opt::QuasiNewtonOptimizer{Teval},
    loss,
    gradF,
    x0,
    stopping_criterion::Union{Nothing, Manopt.StoppingCriterion}) where {
    Teval <: AbstractEvaluationType
}
    sckwarg = stopping_criterion_to_kwarg(stopping_criterion)
    opts = quasi_Newton(opt.M,
        loss,
        gradF,
        x0;
        return_state = true,
        evaluation = Teval(),
        retraction_method = opt.retraction_method,
        vector_transport_method = opt.vector_transport_method,
        stepsize = opt.stepsize,
        sckwarg...)
    # we unwrap DebugOptions here
    minimizer = Manopt.get_solver_result(opts)
    return (; minimizer = minimizer, minimum = loss(opt.M, minimizer), options = opts),
        :who_knows
end
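By default this wrapper builds a WolfePowellLinesearch from the manifold's default retraction and vector transport, and eval again defaults to InplaceEvaluation, so the mutating gradient signature applies. A sketch reusing M, A, and rayleigh:

opt = QuasiNewtonOptimizer(M)    # stepsize defaults to WolfePowellLinesearch(M; ...)
grad!(M, X, x) = (X .= project(M, x, 2 * A * x); X)
res, ret = call_manopt_optimizer(opt, rayleigh, grad!, [1.0, 1.0, 1.0] / sqrt(3), nothing)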

## Optimization.jl stuff
@@ -255,15 +254,15 @@ end
# 3) add callbacks to Manopt.jl

function SciMLBase.__solve(prob::OptimizationProblem,
    opt::AbstractManoptOptimizer,
    data = Optimization.DEFAULT_DATA;
    callback = (args...) -> (false),
    maxiters::Union{Number, Nothing} = nothing,
    maxtime::Union{Number, Nothing} = nothing,
    abstol::Union{Number, Nothing} = nothing,
    reltol::Union{Number, Nothing} = nothing,
    progress = false,
    kwargs...)
    local x, cur, state

    manifold = haskey(prob.kwargs, :manifold) ? prob.kwargs[:manifold] : nothing
@@ -295,12 +294,12 @@ function SciMLBase.__solve(prob::OptimizationProblem,
    opt_res, opt_ret = call_manopt_optimizer(opt, _loss, gradF, prob.u0, stopping_criterion)

    return SciMLBase.build_solution(SciMLBase.DefaultOptimizationCache(prob.f, prob.p),
        opt,
        opt_res.minimizer,
        prob.sense === Optimization.MaxSense ?
        -opt_res.minimum : opt_res.minimum;
        original = opt_res.options,
        retcode = opt_ret)
end

end # module OptimizationManopt
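End to end, a problem reaches these wrappers through Optimization.jl's solve; the __solve method above picks the manifold out of prob.kwargs[:manifold]. A hedged sketch (how the AD-generated Euclidean gradient is adapted for Manopt happens in parts of the file this diff does not show):

using Optimization, OptimizationManopt, Manifolds

f(x, p) = x' * p * x                    # Optimization.jl-style objective f(u, p)
optf = OptimizationFunction(f, Optimization.AutoForwardDiff())
A = [2.0 0.0 0.0; 0.0 1.0 0.0; 0.0 0.0 0.5]
M = Sphere(2)
prob = OptimizationProblem(optf, [1.0, 1.0, 1.0] / sqrt(3), A; manifold = M)
sol = solve(prob, OptimizationManopt.GradientDescentOptimizer(M))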
