
Commit daa0b2e

Update stepsize kwargs to use the default constructor
1 parent 88e06da commit daa0b2e

2 files changed: +7 -6 lines changed

lib/OptimizationManopt/src/OptimizationManopt.jl
Lines changed: 3 additions & 3 deletions

@@ -67,7 +67,7 @@ function call_manopt_optimizer(
         x0;
         stopping_criterion::Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet},
         evaluation::AbstractEvaluationType = Manopt.AllocatingEvaluation(),
-        stepsize::Stepsize = ArmijoLinesearchStepsize(M),
+        stepsize::Stepsize = default_stepsize(M, GradientDescentState),
         kwargs...)
     opts = gradient_descent(M,
         loss,
@@ -111,7 +111,7 @@ function call_manopt_optimizer(M::ManifoldsBase.AbstractManifold,
         x0;
         stopping_criterion::Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet},
         evaluation::AbstractEvaluationType = InplaceEvaluation(),
-        stepsize::Stepsize = ArmijoLinesearch(M),
+        stepsize::Stepsize = default_stepsize(M, ConjugateGradientDescentState),
         kwargs...)
     opts = conjugate_gradient_descent(M,
         loss,
@@ -308,7 +308,7 @@ function call_manopt_optimizer(M::ManifoldsBase.AbstractManifold,
         stopping_criterion::Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet},
         evaluation::AbstractEvaluationType = InplaceEvaluation(),
         retraction_method::AbstractRetractionMethod = default_retraction_method(M),
-        stepsize::Stepsize = DecreasingLength(; length = 2.0, shift = 2),
+        stepsize::Stepsize = default_stepsize(M, FrankWolfeState),
         kwargs...)
     opt = Frank_Wolfe_method(M,
         loss,
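
For context, a minimal sketch of what the new defaults amount to. This assumes `default_stepsize` and the solver-state types are available from Manopt (the diff uses them unqualified) and uses `Euclidean(2)` from Manifolds.jl purely as an example manifold:

    using Manopt, Manifolds

    M = Euclidean(2)

    # Ask Manopt for the stepsize it would choose for each solver state,
    # instead of hard-coding ArmijoLinesearch(M) or DecreasingLength(...).
    gd_step = default_stepsize(M, GradientDescentState)
    cg_step = default_stepsize(M, ConjugateGradientDescentState)
    fw_step = default_stepsize(M, FrankWolfeState)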

lib/OptimizationManopt/test/runtests.jl
Lines changed: 4 additions & 3 deletions

@@ -33,7 +33,9 @@ end
     x0 = zeros(2)
     p = [1.0, 100.0]

-    stepsize = Manopt.ArmijoLinesearch()
+    stepsize = default_stepsize(
+        R2, GradientDescentState
+    )
     opt = OptimizationManopt.GradientDescentOptimizer()

     optprob_forwarddiff = OptimizationFunction(rosenbrock, Optimization.AutoEnzyme())
@@ -65,13 +67,12 @@ end
     x0 = zeros(2)
     p = [1.0, 100.0]

-    stepsize = Manopt.ArmijoLinesearch(R2)
     opt = OptimizationManopt.ConjugateGradientDescentOptimizer()

     optprob = OptimizationFunction(rosenbrock, Optimization.AutoForwardDiff())
     prob = OptimizationProblem(optprob, x0, p; manifold = R2)

-    sol = Optimization.solve(prob, opt, stepsize = stepsize)
+    sol = Optimization.solve(prob, opt)
     @test sol.minimum < 0.5
 end
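
Callers that still want a specific stepsize can keep passing the keyword. A hypothetical override, assuming the same test setup (rosenbrock, x0, p, R2, prob, opt) is in scope:

    # Hypothetical explicit override: the stepsize kwarg is still forwarded to
    # call_manopt_optimizer, mirroring what the removed test line used to do.
    sol = Optimization.solve(prob, opt; stepsize = Manopt.ArmijoLinesearch(R2))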
