Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,10 @@ IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
OptimizationLBFGSB = "22f7324a-a79d-40f2-bebe-3af60c77bd15"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
OptimizationLBFGSB = "22f7324a-a79d-40f2-bebe-3af60c77bd15"
OptimizationMOI = "fd9f6733-72f4-499f-8506-86b2bdd0dea1"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
Expand All @@ -103,10 +104,6 @@ Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"

[targets]
test = ["Aqua", "BenchmarkTools", "Boltz", "ComponentArrays", "DiffEqFlux", "Enzyme", "FiniteDiff", "Flux", "ForwardDiff",
"Ipopt", "IterTools", "Lux", "MLUtils", "ModelingToolkit", "Optim", "OptimizationLBFGSB", "OptimizationMOI", "OptimizationOptimJL", "OptimizationOptimisers",
"OrdinaryDiffEqTsit5", "Pkg", "Random", "ReverseDiff", "SafeTestsets", "SciMLSensitivity", "SparseArrays",
"Symbolics", "Test", "Tracker", "Zygote", "Mooncake"]
test = ["Aqua", "BenchmarkTools", "Boltz", "ComponentArrays", "DiffEqFlux", "Enzyme", "FiniteDiff", "Flux", "ForwardDiff", "Ipopt", "IterTools", "Lux", "MLUtils", "ModelingToolkit", "Optim", "OptimizationLBFGSB", "OptimizationMOI", "OptimizationOptimJL", "OptimizationOptimisers", "OrdinaryDiffEqTsit5", "Pkg", "Random", "ReverseDiff", "SafeTestsets", "SciMLSensitivity", "SparseArrays", "Symbolics", "Test", "Tracker", "Zygote", "Mooncake"]
3 changes: 3 additions & 0 deletions lib/OptimizationBase/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
name = "OptimizationBase"
uuid = "bca83a33-5cc9-4baa-983d-23429ab6bcbb"
version = "4.0.1"
authors = ["Vaibhav Dixit <[email protected]> and contributors"]
version = "4.0.2"

Expand All @@ -16,6 +17,7 @@ SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"

[weakdeps]
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Expand Down Expand Up @@ -59,6 +61,7 @@ SciMLBase = "2.122.1"
SparseConnectivityTracer = "0.6, 1"
SparseMatrixColorings = "0.4"
SymbolicAnalysis = "0.3"
SymbolicIndexingInterface = "0.3.46"
Zygote = "0.6.67, 0.7"
julia = "1.10"

Expand Down
7 changes: 6 additions & 1 deletion lib/OptimizationBase/src/OptimizationBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@ import SciMLBase: solve, init, solve!, __init, __solve,
allowsconstraints, requiresconstraints,
allowscallback, requiresgradient,
requireshessian, requiresconsjac,
requiresconshess
requiresconshess, wrap_sol, has_kwargs,
get_root_indp, get_updated_symbolic_problem,
get_concrete_p, get_concrete_u0, promote_u0,
KeywordArgError, checkkwargs

using SymbolicIndexingInterface: SymbolicIndexingInterface

export ObjSense, MaxSense, MinSense
export allowsbounds, requiresbounds, allowsconstraints, requiresconstraints,
Expand Down
133 changes: 123 additions & 10 deletions lib/OptimizationBase/src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,18 +90,33 @@ from NLopt for an example. The common local optimizer arguments are:
- `local_reltol`: relative tolerance in changes of the objective value
- `local_options`: `NamedTuple` of keyword arguments for local optimizer
"""
function solve(prob::SciMLBase.OptimizationProblem, alg, args...;
kwargs...)::SciMLBase.AbstractOptimizationSolution
if SciMLBase.has_init(alg)
solve!(init(prob, alg, args...; kwargs...))
function solve(prob::SciMLBase.OptimizationProblem, args...; sensealg = nothing,
u0 = nothing, p = nothing, wrap = Val(true), kwargs...)::SciMLBase.AbstractOptimizationSolution
if sensealg === nothing && haskey(prob.kwargs, :sensealg)
sensealg = prob.kwargs[:sensealg]
end

u0 = u0 !== nothing ? u0 : prob.u0
p = p !== nothing ? p : prob.p

if wrap isa Val{true}
wrap_sol(solve_up(prob,
sensealg,
u0,
p,
args...;
originator = SciMLBase.ChainRulesOriginator(),
kwargs...))
else
if prob.u0 !== nothing && !isconcretetype(eltype(prob.u0))
throw(SciMLBase.NonConcreteEltypeError(eltype(prob.u0)))
end
_check_opt_alg(prob, alg; kwargs...)
__solve(prob, alg, args...; kwargs...)
solve_up(prob,
sensealg,
u0,
p,
args...;
originator = SciMLBase.ChainRulesOrginator(),
kwargs...)
end
end
end

function solve(
prob::SciMLBase.EnsembleProblem{T}, args...; kwargs...) where {T <:
Expand Down Expand Up @@ -216,3 +231,101 @@ end
function __solve(prob::SciMLBase.OptimizationProblem, alg, args...; kwargs...)
throw(OptimizerMissingError(alg))
end

function solve_up(prob::SciMLBase.OptimizationProblem, sensealg, u0, p, args...; originator = SciMLBase.ChainRulesOriginator(),
kwargs...)
alg = extract_opt_alg(args, kwargs, has_kwargs(prob) ? prob.kwargs : kwargs)
_prob = get_concrete_problem(prob; u0 = u0, p = p, kwargs...)
if length(args) > 1
solve_call(_prob, alg, Base.tail(args)..., kwargs...)
else
solve_call(_prob, alg; kwargs...)
end
end

function solve_call(_prob, alg, args...; merge_callbacks = true, kwargshandle = nothing,
kwargs...)
kwargshandle = kwargshandle === nothing ? KeywordArgError : kwargshandle
kwargshandle = has_kwargs(_prob) && haskey(_prob.kwargs, :kwargshandle) ?
_prob.kwargs[:kwargshandle] : kwargshandle

if has_kwargs(_prob)
kwargs = isempty(_prob.kwargs) ? kwargs : merge(values(_prob.kwargs), kwargs)
end

#checkkwargs(kwargshandle; kwargs...)

if SciMLBase.has_init(alg)
solve!(init(_prob, alg, args...; kwargs...))
else
if _prob.u0 !== nothing && !isconcretetype(eltype(_prob.u0))
throw(SciMLBase.NonConcreteEltypeError(eltype(_prob.u0)))
end
_check_opt_alg(_prob, alg; kwargs...)
__solve(_prob, alg, args...; kwargs...)
end
end

function get_concrete_problem(prob::OptimizationProblem; kwargs...)
oldprob = prob
prob = get_updated_symbolic_problem(get_root_indp(prob), prob; kwargs...)
if prob !== oldprob
kwargs = (;kwargs..., u0 = SymbolicIndexingInterface.state_values(prob), p = SymbolicIndexingInterface.parameter_values(prob))
end
p = get_concrete_p(prob, kwargs)
u0 = get_concrete_u0(prob, false, nothing, kwargs)
u0 = promote_u0(u0, p, nothing)
remake(prob; u0 = u0, p = p)

end


@inline function extract_opt_alg(solve_args, solve_kwargs, prob_kwargs)
if isempty(solve_args) || isnothing(first(solve_args))
if haskey(solve_kwargs, :alg)
solve_kwargs[:alg]
elseif haskey(prob_kwargs, :alg)
prob_kwargs[:alg]
else
nothing
end
else
first(solve_args)
end
end


function _solve_forward(prob, sensealg, u0, p, originator, args...; merge_callbacks = true,
kwargs...)
alg = extract_opt_alg(args, kwargs, prob.kwargs)
_prob = get_concrete_problem(prob; u0 = u0, p = p, kwargs...)

if has_kwargs(_prob)
kwargs = isempty(_porb.kwargs) ? kwargs : merge(values(_prob.kwargs), kwargs)
end

if length(args) > 1
_concrete_solve_forward(_prob, alg, sensealg, u0, p, originator,
Base.tail(args)...; kwargs...)
else
_concrete_solve_forward(_prob, alg, sensealg, u0, p, originator; kwargs...)
end
end

function _solve_adjoint(_prob, sensealg, u0, p, originator, args...; merge_callbacks = true,
kwargs...)
alg = extract_alg(args, kwargs, prob.kwargs)

_prob = get_concrete_problem(prob; u0 = u0, p = p, kwargs...)

if has_kwargs(_prob)
kwargs = isempty(_prob.kwargs) ? kwargs : merge(values(_prob.kwargs), kwargs)
end

if length(args) > 1
_concrete_solve_adjoint(_prob, alg, sensealg, u0, p, originator,
Base.tail(args)...; kwargs...)
else
_concrete_solve_adjoint(_prob, alg, sensealg, u0, p, originator; kwargs...)
end
end
10 changes: 10 additions & 0 deletions lib/OptimizationMultistartOptimization/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,13 @@ using Test, ReverseDiff
OptimizationNLopt.Opt(:LD_LBFGS, 2))
@test 10 * sol.objective < l1
end

rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
x0 = zeros(2)
_p = [1.0, 100.0]
l1 = rosenbrock(x0, _p)
f = OptimizationFunction(rosenbrock, OptimizationBase.AutoForwardDiff())
prob = OptimizationBase.OptimizationProblem(f, x0, _p, lb = [-1.0, -1.0], ub = [1.5, 1.5])
sol = solve(prob, OptimizationMultistartOptimization.TikTak(100),
OptimizationNLopt.Opt(:LD_LBFGS, 2))
@test 10 * sol.objective < l1
Loading