Skip to content
Merged
Show file tree
Hide file tree
Changes from 33 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
1d322a1
Create OptimizationODE.jl
ParasPuneetSingh May 22, 2025
a81ae86
Create runtests.jl
ParasPuneetSingh May 22, 2025
ab80dcc
Create Project.toml
ParasPuneetSingh May 22, 2025
7558b76
Update CI.yml
ChrisRackauckas May 24, 2025
967ce77
Update OptimizationODE.jl
ParasPuneetSingh May 26, 2025
38733ac
Update runtests.jl
ParasPuneetSingh May 26, 2025
e2b310c
Update OptimizationODE.jl
ParasPuneetSingh May 27, 2025
6b79113
Update runtests.jl
ParasPuneetSingh May 27, 2025
1c7a004
Update Project.toml
ParasPuneetSingh May 27, 2025
792d6cf
Update OptimizationODE.jl
ParasPuneetSingh May 28, 2025
dffe5f5
Update runtests.jl
ParasPuneetSingh May 28, 2025
a2d406e
Update OptimizationODE.jl
ParasPuneetSingh May 30, 2025
cb0668a
Update runtests.jl
ParasPuneetSingh May 30, 2025
c7a06c4
Merge branch 'SciML:master' into master
ParasPuneetSingh May 30, 2025
8e451e1
Update OptimizationODE.jl
ParasPuneetSingh May 30, 2025
9e03221
Update runtests.jl
ParasPuneetSingh May 30, 2025
f9e6a78
Update lib/OptimizationODE/src/OptimizationODE.jl
ChrisRackauckas May 30, 2025
e53a9e9
Update lib/OptimizationODE/Project.toml
ChrisRackauckas May 30, 2025
18b0614
Update lib/OptimizationODE/Project.toml
ChrisRackauckas May 30, 2025
df75819
Update lib/OptimizationODE/src/OptimizationODE.jl
ChrisRackauckas May 30, 2025
6aa89af
Update lib/OptimizationODE/src/OptimizationODE.jl
ChrisRackauckas May 30, 2025
5c20429
Update lib/OptimizationODE/src/OptimizationODE.jl
ChrisRackauckas May 30, 2025
962832b
Update lib/OptimizationODE/src/OptimizationODE.jl
ChrisRackauckas May 30, 2025
b437cd8
Update lib/OptimizationODE/test/runtests.jl
ChrisRackauckas May 30, 2025
bcd6f75
Update lib/OptimizationODE/test/runtests.jl
ChrisRackauckas May 30, 2025
a10cb04
Update lib/OptimizationODE/test/runtests.jl
ChrisRackauckas May 30, 2025
6164774
Update lib/OptimizationODE/src/OptimizationODE.jl
ChrisRackauckas May 30, 2025
b8a7fe4
Update lib/OptimizationODE/src/OptimizationODE.jl
ChrisRackauckas May 30, 2025
8c2a16e
Update lib/OptimizationODE/test/runtests.jl
ChrisRackauckas May 30, 2025
72a4b62
Update lib/OptimizationODE/Project.toml
ChrisRackauckas May 30, 2025
f977f0d
Update Project.toml
ChrisRackauckas May 30, 2025
2267a46
Update lib/OptimizationODE/test/runtests.jl
ChrisRackauckas May 31, 2025
45d09aa
Update lib/OptimizationODE/src/OptimizationODE.jl
ChrisRackauckas May 31, 2025
238ba5d
Update lib/OptimizationODE/src/OptimizationODE.jl
ChrisRackauckas May 31, 2025
43c7ea7
Update Project.toml
ChrisRackauckas May 31, 2025
e68e640
Update lib/OptimizationODE/src/OptimizationODE.jl
ChrisRackauckas May 31, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ jobs:
- OptimizationMultistartOptimization
- OptimizationNLopt
- OptimizationNOMAD
- OptimizationODE
- OptimizationOptimJL
- OptimizationOptimisers
- OptimizationPRIMA
Expand Down
24 changes: 24 additions & 0 deletions lib/OptimizationODE/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
name = "OptimizationODE"
uuid = "dfa73e59-e644-4d8a-bf84-188d7ecb34e4"
authors = ["Paras Puneet Singh <[email protected]>"]
version = "0.1.0"

[deps]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
Optimization = "4"
Reexport = "1"
julia = "1.10"

[extras]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["ADTypes", "Test"]
106 changes: 106 additions & 0 deletions lib/OptimizationODE/src/OptimizationODE.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
module OptimizationODE

using Reexport
@reexport using Optimization, SciMLBase
using OrdinaryDiffEq

export ODEOptimizer, ODEGradientDescent, RKChebyshevDescent, RKAccelerated, HighOrderDescent

struct ODEOptimizer{T, T2}
solver::T
dt::T2
end
ODEOptimizer(solver ; dt=nothing) = ODEOptimizer(solver, dt)

# Solver Constructors (users call these)
ODEGradientDescent(; dt) = ODEOptimizer(Euler(); dt)
RKChebyshevDescent() = ODEOptimizer(ROCK2())
RKAccelerated() = ODEOptimizer(Tsit5())
HighOrderDescent() = ODEOptimizer(Vern7())


SciMLBase.requiresbounds(::ODEOptimizer) = false
SciMLBase.allowsbounds(::ODEOptimizer) = false
SciMLBase.allowscallback(::ODEOptimizer) = true
SciMLBase.supports_opt_cache_interface(::ODEOptimizer) = true
SciMLBase.requiresgradient(::ODEOptimizer) = true
SciMLBase.requireshessian(::ODEOptimizer) = false
SciMLBase.requiresconsjac(::ODEOptimizer) = false
SciMLBase.requiresconshess(::ODEOptimizer) = false


function SciMLBase.__init(prob::OptimizationProblem, opt::ODEOptimizer, data=Optimization.DEFAULT_DATA;
callback=Optimization.DEFAULT_CALLBACK, progress=false,
maxiters=nothing, kwargs...)

return OptimizationCache(prob, opt, data; callback=callback, progress=progress,
maxiters=maxiters, kwargs...)
end

function SciMLBase.__solve(
cache::OptimizationCache{F,RC,LB,UB,LC,UC,S,O,D,P,C}
) where {F,RC,LB,UB,LC,UC,S,O<:ODEOptimizer,D,P,C}

dt = cache.opt.dt
maxit = get(cache.solver_args, :maxiters, 1000)

u0 = copy(cache.u0)
p = cache.p

if cache.f.grad === nothing
error("ODEOptimizer requires a gradient. Please provide a function with `grad` defined.")
end

function f!(du, u, p, t)
cache.f.grad(du, u, p)
@. du = -du
return nothing
end

ss_prob = SteadyStateProblem(f!, u0, p)

algorithm = DynamicSS(cache.opt.solver)

cb = cache.callback
if cb != Optimization.DEFAULT_CALLBACK || get(cache.solver_args,:progress,false) === true
function condition(u, t, integrator)
true
end
function affect!(integrator)
u_now = integrator.u
state = Optimization.OptimizationState(u=u_now, objective=cache.f(integrator.u, integrator.p))
Optimization.callback_function(cb, state)
end
cb_struct = DiscreteCallback(condition, affect!)
callback = CallbackSet(cb_struct)
else
callback = nothing
end

solve_kwargs = Dict{Symbol, Any}(:callback => callback)
if !isnothing(maxit)
solve_kwargs[:maxiters] = maxit
end
if dt !== nothing
solve_kwargs[:dt] = dt
end

sol = solve(ss_prob, algorithm; solve_kwargs...)
has_destats = hasproperty(sol, :destats)
has_t = hasproperty(sol, :t) && !isempty(sol.t)

stats = Optimization.OptimizationStats(
iterations = has_destats ? get(sol.destats, :iters, 10) : (has_t ? length(sol.t) - 1 : 10),
time = has_t ? sol.t[end] : 0.0,
fevals = has_destats ? get(sol.destats, :f_calls, 0) : 0,
gevals = has_destats ? get(sol.destats, :iters, 0) : 0,
hevals = 0
)

SciMLBase.build_solution(cache, cache.opt, sol.u, cache.f(sol.u, p);
retcode = ReturnCode.Success,
stats = stats
)
end

end
44 changes: 44 additions & 0 deletions lib/OptimizationODE/test/runtests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
using Test
using OptimizationODE, SciMLBase, ADTypes

@testset "OptimizationODE Tests" begin

function f(x, p)
return sum(abs2, x)
end

function g!(g, x, p)
@. g = 2 * x
end

x0 = [2.0, -3.0]
p = [5.0]

f_autodiff = OptimizationFunction(f, ADTypes.AutoForwardDiff())
prob_auto = OptimizationProblem(f_autodiff, x0, p)

for opt in (ODEGradientDescent(dt=0.01), RKChebyshevDescent(), RKAccelerated(), HighOrderDescent())
sol = solve(prob_auto, opt; maxiters=50_000)
@test sol.u ≈ [0.0, 0.0] atol=1e-2
@test sol.objective ≈ 0.0 atol=1e-2
@test sol.retcode == ReturnCode.Success
end

f_manual = OptimizationFunction(f, SciMLBase.NoAD(); grad=g!)
prob_manual = OptimizationProblem(f_manual, x0)

for opt in (ODEGradientDescent(dt=0.01), RKChebyshevDescent(), RKAccelerated(), HighOrderDescent())
sol = solve(prob_manual, opt; maxiters=50_000)
@test sol.u ≈ [0.0, 0.0] atol=1e-2
@test sol.objective ≈ 0.0 atol=1e-2
@test sol.retcode == ReturnCode.Success
end

f_fail = OptimizationFunction(f, SciMLBase.NoAD())
prob_fail = OptimizationProblem(f_fail, x0)

for opt in (ODEGradientDescent(dt=0.001), RKChebyshevDescent(), RKAccelerated(), HighOrderDescent())
@test_throws ErrorException solve(prob_fail, opt; maxiters=20_000)
end

end
Loading