Skip to content

Commit 7af2b81

Browse files
wip
1 parent 136606b commit 7af2b81

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

lib/OptimizationManopt/src/OptimizationManopt.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ struct GradientDescentOptimizer{
2929
end
3030

3131
function GradientDescentOptimizer(M::AbstractManifold;
32-
eval::AbstractEvaluationType = MutatingEvaluation(),
32+
eval::AbstractEvaluationType = Manopt.AllocatingEvaluation(),
3333
stepsize::Stepsize = ArmijoLinesearch(M))
3434
GradientDescentOptimizer{typeof(eval), typeof(M), typeof(stepsize)}(M, stepsize)
3535
end
@@ -99,7 +99,7 @@ struct ConjugateGradientDescentOptimizer{Teval <: AbstractEvaluationType,
9999
end
100100

101101
function ConjugateGradientDescentOptimizer(M::AbstractManifold;
102-
eval::AbstractEvaluationType = MutatingEvaluation(),
102+
eval::AbstractEvaluationType = InplaceEvaluation(),
103103
stepsize::Stepsize = ArmijoLinesearch(M))
104104
ConjugateGradientDescentOptimizer{typeof(eval), typeof(M), typeof(stepsize)}(M,
105105
stepsize)
@@ -143,7 +143,7 @@ struct ParticleSwarmOptimizer{Teval <: AbstractEvaluationType,
143143
end
144144

145145
function ParticleSwarmOptimizer(M::AbstractManifold;
146-
eval::AbstractEvaluationType = MutatingEvaluation(),
146+
eval::AbstractEvaluationType = InplaceEvaluation(),
147147
population_size::Int = 100,
148148
retraction_method::AbstractRetractionMethod = default_retraction_method(M),
149149
inverse_retraction_method::AbstractInverseRetractionMethod = default_inverse_retraction_method(M),
@@ -195,7 +195,7 @@ struct QuasiNewtonOptimizer{Teval <: AbstractEvaluationType,
195195
end
196196

197197
function QuasiNewtonOptimizer(M::AbstractManifold;
198-
eval::AbstractEvaluationType = MutatingEvaluation(),
198+
eval::AbstractEvaluationType = InplaceEvaluation(),
199199
retraction_method::AbstractRetractionMethod = default_retraction_method(M),
200200
vector_transport_method::AbstractVectorTransportMethod = default_vector_transport_method(M),
201201
stepsize = WolfePowellLinesearch(M;
@@ -245,12 +245,13 @@ function build_loss(f::OptimizationFunction, prob, cur)
245245
end
246246

247247
function build_gradF(f::OptimizationFunction{true}, prob, cur)
248-
function (::AbstractManifold, G, θ)
249-
X = f.grad(G, θ, cur...)
248+
function (M::AbstractManifold, G, θ)
249+
f.grad(G, θ, cur...)
250+
G .= riemannian_gradient(M, θ, G)
250251
if prob.sense === Optimization.MaxSense
251-
return -X # TODO: check
252+
return -G # TODO: check
252253
else
253-
return X
254+
return G
254255
end
255256
end
256257
end

0 commit comments

Comments
 (0)