8 changes: 4 additions & 4 deletions lib/OptimizationManopt/Project.toml
@@ -14,10 +14,10 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"

[compat]
LinearAlgebra = "1.10"
ManifoldDiff = "0.3.10"
Manifolds = "0.9.18"
ManifoldsBase = "0.15.10"
Manopt = "0.4.63"
ManifoldDiff = "0.4"
Manifolds = "0.10"
ManifoldsBase = "1"
Manopt = "0.5"
Optimization = "4.4"
Reexport = "1.2"
julia = "1.10"
163 changes: 34 additions & 129 deletions lib/OptimizationManopt/src/OptimizationManopt.jl
@@ -65,20 +65,14 @@ function call_manopt_optimizer(
loss,
gradF,
x0;
stopping_criterion::Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet},
evaluation::AbstractEvaluationType = Manopt.AllocatingEvaluation(),
stepsize::Stepsize = ArmijoLinesearch(M),
kwargs...)
opts = gradient_descent(M,
opts = Manopt.gradient_descent(M,
loss,
gradF,
x0;
return_state = true,
evaluation,
stepsize,
stopping_criterion,
kwargs...)
# we unwrap DebugOptions here
return_state = true, # return the (full, decorated) solver state
kwargs...
)
minimizer = Manopt.get_solver_result(opts)
return (; minimizer = minimizer, minimum = loss(M, minimizer), options = opts)
end
@@ -90,13 +84,8 @@ function call_manopt_optimizer(M::ManifoldsBase.AbstractManifold, opt::NelderMea
loss,
gradF,
x0;
stopping_criterion::Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet},
kwargs...)
opts = NelderMead(M,
loss;
return_state = true,
stopping_criterion,
kwargs...)
opts = NelderMead(M, loss; return_state = true, kwargs...)
minimizer = Manopt.get_solver_result(opts)
return (; minimizer = minimizer, minimum = loss(M, minimizer), options = opts)
end
@@ -109,19 +98,14 @@ function call_manopt_optimizer(M::ManifoldsBase.AbstractManifold,
loss,
gradF,
x0;
stopping_criterion::Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet},
evaluation::AbstractEvaluationType = InplaceEvaluation(),
stepsize::Stepsize = ArmijoLinesearch(M),
kwargs...)
opts = conjugate_gradient_descent(M,
opts = Manopt.conjugate_gradient_descent(M,
loss,
gradF,
x0;
return_state = true,
evaluation,
stepsize,
stopping_criterion,
kwargs...)
kwargs...
)
# we unwrap DebugOptions here
minimizer = Manopt.get_solver_result(opts)
return (; minimizer = minimizer, minimum = loss(M, minimizer), options = opts)
@@ -135,25 +119,10 @@ function call_manopt_optimizer(M::ManifoldsBase.AbstractManifold,
loss,
gradF,
x0;
stopping_criterion::Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet},
evaluation::AbstractEvaluationType = InplaceEvaluation(),
population_size::Int = 100,
retraction_method::AbstractRetractionMethod = default_retraction_method(M),
inverse_retraction_method::AbstractInverseRetractionMethod = default_inverse_retraction_method(M),
vector_transport_method::AbstractVectorTransportMethod = default_vector_transport_method(M),
kwargs...)
initial_population = vcat([x0], [rand(M) for _ in 1:(population_size - 1)])
opts = particle_swarm(M,
loss;
x0 = initial_population,
n = population_size,
return_state = true,
retraction_method,
inverse_retraction_method,
vector_transport_method,
stopping_criterion,
kwargs...)
# we unwrap DebugOptions here
swarm = [x0, [rand(M) for _ in 1:(population_size - 1)]...]
opts = particle_swarm(M, loss, swarm; return_state = true, kwargs...)
minimizer = Manopt.get_solver_result(opts)
return (; minimizer = minimizer, minimum = loss(M, minimizer), options = opts)
end
@@ -167,27 +136,9 @@ function call_manopt_optimizer(M::Manopt.AbstractManifold,
loss,
gradF,
x0;
stopping_criterion::Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet},
evaluation::AbstractEvaluationType = InplaceEvaluation(),
retraction_method::AbstractRetractionMethod = default_retraction_method(M),
vector_transport_method::AbstractVectorTransportMethod = default_vector_transport_method(M),
stepsize = WolfePowellLinesearch(M;
retraction_method = retraction_method,
vector_transport_method = vector_transport_method,
linesearch_stopsize = 1e-12),
kwargs...
)
opts = quasi_Newton(M,
loss,
gradF,
x0;
return_state = true,
evaluation,
retraction_method,
vector_transport_method,
stepsize,
stopping_criterion,
kwargs...)
opts = quasi_Newton(M, loss, gradF, x0; return_state = true, kwargs...)
# we unwrap DebugOptions here
minimizer = Manopt.get_solver_result(opts)
return (; minimizer = minimizer, minimum = loss(M, minimizer), options = opts)
@@ -200,18 +151,8 @@ function call_manopt_optimizer(M::ManifoldsBase.AbstractManifold,
loss,
gradF,
x0;
stopping_criterion::Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet},
evaluation::AbstractEvaluationType = InplaceEvaluation(),
retraction_method::AbstractRetractionMethod = default_retraction_method(M),
vector_transport_method::AbstractVectorTransportMethod = default_vector_transport_method(M),
basis = Manopt.DefaultOrthonormalBasis(),
kwargs...)
opt = cma_es(M,
loss,
x0;
return_state = true,
stopping_criterion,
kwargs...)
opt = cma_es(M, loss, x0; return_state = true, kwargs...)
# we unwrap DebugOptions here
minimizer = Manopt.get_solver_result(opt)
return (; minimizer = minimizer, minimum = loss(M, minimizer), options = opt)
@@ -224,21 +165,8 @@ function call_manopt_optimizer(M::ManifoldsBase.AbstractManifold,
loss,
gradF,
x0;
stopping_criterion::Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet},
evaluation::AbstractEvaluationType = InplaceEvaluation(),
retraction_method::AbstractRetractionMethod = default_retraction_method(M),
vector_transport_method::AbstractVectorTransportMethod = default_vector_transport_method(M),
kwargs...)
opt = convex_bundle_method!(M,
loss,
gradF,
x0;
return_state = true,
evaluation,
retraction_method,
vector_transport_method,
stopping_criterion,
kwargs...)
opt = convex_bundle_method(M, loss, gradF, x0; return_state = true, kwargs...)
# we unwrap DebugOptions here
minimizer = Manopt.get_solver_result(opt)
return (; minimizer = minimizer, minimum = loss(M, minimizer), options = opt)
@@ -252,21 +180,13 @@ function call_manopt_optimizer(M::ManifoldsBase.AbstractManifold,
gradF,
x0;
hessF = nothing,
stopping_criterion::Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet},
evaluation::AbstractEvaluationType = InplaceEvaluation(),
retraction_method::AbstractRetractionMethod = default_retraction_method(M),
kwargs...)
opt = adaptive_regularization_with_cubics(M,
loss,
gradF,
hessF,
x0;
return_state = true,
evaluation,
retraction_method,
stopping_criterion,
kwargs...)
# we unwrap DebugOptions here

opt = if isnothing(hessF)
adaptive_regularization_with_cubics(M, loss, gradF, x0; return_state = true, kwargs...)
else
adaptive_regularization_with_cubics(M, loss, gradF, hessF, x0; return_state = true, kwargs...)
end
minimizer = Manopt.get_solver_result(opt)
return (; minimizer = minimizer, minimum = loss(M, minimizer), options = opt)
end
Expand All @@ -279,20 +199,12 @@ function call_manopt_optimizer(M::ManifoldsBase.AbstractManifold,
gradF,
x0;
hessF = nothing,
stopping_criterion::Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet},
evaluation::AbstractEvaluationType = InplaceEvaluation(),
retraction_method::AbstractRetractionMethod = default_retraction_method(M),
kwargs...)
opt = trust_regions(M,
loss,
gradF,
hessF,
x0;
return_state = true,
evaluation,
retraction = retraction_method,
stopping_criterion,
kwargs...)
opt = if isnothing(hessF)
trust_regions(M, loss, gradF, x0; return_state = true, kwargs...)
else
trust_regions(M, loss, gradF, hessF, x0; return_state = true, kwargs...)
end
# we unwrap DebugOptions here
minimizer = Manopt.get_solver_result(opt)
return (; minimizer = minimizer, minimum = loss(M, minimizer), options = opt)
@@ -305,21 +217,8 @@ function call_manopt_optimizer(M::ManifoldsBase.AbstractManifold,
loss,
gradF,
x0;
stopping_criterion::Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet},
evaluation::AbstractEvaluationType = InplaceEvaluation(),
retraction_method::AbstractRetractionMethod = default_retraction_method(M),
stepsize::Stepsize = DecreasingStepsize(; length = 2.0, shift = 2),
kwargs...)
opt = Frank_Wolfe_method(M,
loss,
gradF,
x0;
return_state = true,
evaluation,
retraction_method,
stopping_criterion,
stepsize,
kwargs...)
opt = Frank_Wolfe_method(M, loss, gradF, x0; return_state = true, kwargs...)
# we unwrap DebugOptions here
minimizer = Manopt.get_solver_result(opt)
return (; minimizer = minimizer, minimum = loss(M, minimizer), options = opt)
@@ -332,20 +231,22 @@ function SciMLBase.requiresgradient(opt::Union{
AdaptiveRegularizationCubicOptimizer, TrustRegionsOptimizer})
true
end
# TODO: Why? They both still accept not passing it.
function SciMLBase.requireshessian(opt::Union{
AdaptiveRegularizationCubicOptimizer, TrustRegionsOptimizer})
true
end

Contributor Author:

How is this function defined and what is it for?

The current definition here is not correct: both ARC and TR can perform their own (actually quite good) approximation of the Hessian, similar to what QN does. So they do not need a Hessian, but the exact one of course performs a bit better than the approximate one.

Member:

It's a trait for checking whether the Hessian function is required in order to use the solver. For example, if your solver uses prob.f.hess then this should be true, so that you can fail early if a second-order AD method is not given.

If it's not required then this should be false. When true, it turns on an error message along the lines of "prob.f.hess is not defined and therefore you cannot use this method" (not exactly, but at a high level that's what it's for: higher-level error messages and reporting).
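
A minimal sketch of how that trait could be wired into a solve-time check (hypothetical helper; the name check_hessian_requirement and the exact error text are illustrative, not the actual SciMLBase implementation):

# Illustrative only: fail early when a solver's requireshessian trait is true
# but the user did not provide a Hessian function.
function check_hessian_requirement(prob, opt)
    if SciMLBase.requireshessian(opt) && isnothing(prob.f.hess)
        error("prob.f.hess is not defined, but $(typeof(opt)) requires a Hessian; " *
              "provide `hess` in the OptimizationFunction or use a second-order AD backend.")
    end
    return nothing
end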

function build_loss(f::OptimizationFunction, prob, cb)
function (::AbstractManifold, θ)
return function (::AbstractManifold, θ)
x = f.f(θ, prob.p)
cb(x, θ)
__x = first(x)
return prob.sense === Optimization.MaxSense ? -__x : __x
end
end

# TODO: What does the `true` mean here?
function build_gradF(f::OptimizationFunction{true})
function g(M::AbstractManifold, G, θ)
f.grad(G, θ)
@@ -356,6 +257,7 @@ function build_gradF(f::OptimizationFunction{true})
f.grad(G, θ)
return riemannian_gradient(M, θ, G)
end
return g
end

function build_hessF(f::OptimizationFunction{true})
@@ -373,6 +275,7 @@ function build_hessF(f::OptimizationFunction{true})
f.grad(G, θ)
return riemannian_Hessian(M, θ, G, H, X)
end
return h
end

function SciMLBase.__solve(cache::OptimizationCache{
@@ -395,8 +298,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
LC,
UC,
S,
O <:
AbstractManoptOptimizer,
O <: AbstractManoptOptimizer,
D,
P,
C
@@ -418,6 +320,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
u = θ,
p = cache.p,
objective = x[1])
#TODO: What is this callback for?
cb_call = cache.callback(opt_state, x...)
if !(cb_call isa Bool)
error("The callback should return a boolean `halt` for whether to stop the optimization process.")
@@ -448,10 +351,12 @@
stopping_criterion = Manopt.StopAfterIteration(500)
end

# TODO: With the new keyword warnings we cannot just always pass down hessF!
opt_res = call_manopt_optimizer(manifold, cache.opt, _loss, gradF, cache.u0;
solver_kwarg..., stopping_criterion = stopping_criterion, hessF)

asc = get_stopping_criterion(opt_res.options)
# TODO: Switch to `has_converged` once that was released.
opt_ret = Manopt.indicates_convergence(asc) ? ReturnCode.Success : ReturnCode.Failure

return SciMLBase.build_solution(cache,