Skip to content

Commit a2d406e

Browse files
Update OptimizationODE.jl
Add callback, progress, struct accepts type instead of insatance, maxiters passed.
1 parent dffe5f5 commit a2d406e

File tree

1 file changed

+87
-42
lines changed

1 file changed

+87
-42
lines changed

lib/OptimizationODE/src/OptimizationODE.jl

Lines changed: 87 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -3,61 +3,106 @@ module OptimizationODE
33
using Reexport
44
@reexport using Optimization, Optimization.SciMLBase
55
using DifferentialEquations
6+
using Optimization.LinearAlgebra
67

78
export ODEOptimizer, ODEGradientDescent, RKChebyshevDescent, RKAccelerated, PRKChebyshevDescent
89

9-
abstract type AbstractODEOptimizer end
10-
11-
struct ODEOptimizer{T} <: AbstractODEOptimizer
12-
alg::T
10+
struct ODEOptimizer{T}
11+
solver::Type{T}
1312
end
1413

15-
16-
const ODEGradientDescent = ODEOptimizer(Euler())
17-
const RKChebyshevDescent = ODEOptimizer(ROCK2())
18-
const RKAccelerated = ODEOptimizer(Tsit5())
19-
const PRKChebyshevDescent = ODEOptimizer(ROCK4())
14+
# Solver Constructors (users call these)
15+
ODEGradientDescent() = ODEOptimizer(Euler)
16+
RKChebyshevDescent() = ODEOptimizer(ROCK2)
17+
RKAccelerated() = ODEOptimizer(Tsit5)
18+
PRKChebyshevDescent() = ODEOptimizer(Vern7)
2019

2120

21+
SciMLBase.requiresbounds(::ODEOptimizer) = false
22+
SciMLBase.allowsbounds(::ODEOptimizer) = false
23+
SciMLBase.allowscallback(::ODEOptimizer) = true
2224
SciMLBase.supports_opt_cache_interface(::ODEOptimizer) = true
23-
SciMLBase.requiresgradient(::ODEOptimizer) = true
24-
25-
function Optimization.__map_optimizer_args(cache::OptimizationCache, opt::ODEOptimizer;
26-
dt::Real = 0.01,
27-
maxiters::Integer = 100,
28-
callback = nothing,
29-
progress = false,
30-
kwargs...
31-
)
32-
cache.meta[:dt] = dt
33-
cache.meta[:maxiters]= maxiters
34-
cache.meta[:callback]= callback
35-
cache.meta[:progress]= progress
36-
return nothing
25+
SciMLBase.requiresgradient(::ODEOptimizer) = true
26+
SciMLBase.requireshessian(::ODEOptimizer) = false
27+
SciMLBase.requiresconsjac(::ODEOptimizer) = false
28+
SciMLBase.requiresconshess(::ODEOptimizer) = false
29+
30+
31+
function SciMLBase.__init(prob::OptimizationProblem, opt::ODEOptimizer, data=Optimization.DEFAULT_DATA;
32+
η=0.1, dt=nothing, tmax=100.0, callback=Optimization.DEFAULT_CALLBACK, progress=false,
33+
maxiters=nothing, kwargs...)
34+
35+
return OptimizationCache(prob, opt, data;
36+
η=η, dt=dt, tmax=tmax, callback=callback, progress=progress,
37+
maxiters=maxiters, kwargs...)
3738
end
3839

3940
function SciMLBase.__solve(
40-
cache::OptimizationCache{F,RC,LB,UB,LC,UC,S,<:ODEOptimizer,D,P,C}
41-
) where {F,RC,LB,UB,LC,UC,S,D,P,C}
42-
dt = cache.solver_args[:dt]
43-
maxiters = cache.solver_args[:maxiters]
44-
tspan = (0.0, maxiters * dt)
45-
46-
alg = cache.opt.alg
47-
48-
prob = SteadyStateProblem(
49-
(du, u, p, t) -> begin
50-
cache.f.grad(du, u, cache.p)
51-
du .*= -1
52-
end,
53-
cache.u0,
54-
cache.p
55-
)
41+
cache::OptimizationCache{F,RC,LB,UB,LC,UC,S,O,D,P,C}
42+
) where {F,RC,LB,UB,LC,UC,S,O<:ODEOptimizer,D,P,C}
43+
44+
η = get(cache.solver_args, , 0.1)
45+
dt = get(cache.solver_args, :dt, nothing)
46+
tmax = get(cache.solver_args, :tmax, 100.0)
47+
maxit = get(cache.solver_args, :maxiters, 1000)
48+
49+
u0 = copy(cache.u0)
50+
p = cache.p
5651

57-
sol = solve(prob, DynamicSS(alg); dt=dt)
52+
if cache.f.grad === nothing
53+
error("ODEOptimizer requires a gradient. Please provide a function with `grad` defined.")
54+
end
5855

59-
return SciMLBase.build_solution(cache, cache.opt, sol.u,
60-
sol.resid; original = sol, retcode = sol.retcode)
56+
function f!(du, u, p, t)
57+
cache.f.grad(du, u, p)
58+
@. du = -η * du
59+
return nothing
60+
end
61+
62+
ss_prob = SteadyStateProblem(f!, u0, p)
63+
64+
algorithm = DynamicSS(cache.opt.solver())
65+
66+
cb = cache.callback
67+
if cb != Optimization.DEFAULT_CALLBACK || get(cache.solver_args,:progress,false) === true
68+
function condition(u, t, integrator)
69+
true
70+
end
71+
function affect!(integrator)
72+
u_now = integrator.u
73+
state = Optimization.OptimizationState(u=u_now, objective=cache.f(u_now, p))
74+
Optimization.callback_function(cb, state)
75+
end
76+
cb_struct = DiscreteCallback(condition, affect!)
77+
callback = CallbackSet(cb_struct)
78+
else
79+
callback = nothing
80+
end
81+
82+
solve_kwargs = Dict{Symbol, Any}(:callback => callback)
83+
if !isnothing(maxit)
84+
solve_kwargs[:maxiters] = maxit
85+
end
86+
if dt !== nothing
87+
solve_kwargs[:dt] = dt
88+
end
89+
90+
sol = solve(ss_prob, algorithm; solve_kwargs...)
91+
has_destats = hasproperty(sol, :destats)
92+
has_t = hasproperty(sol, :t) && !isempty(sol.t)
93+
94+
stats = Optimization.OptimizationStats(
95+
iterations = has_destats ? get(sol.destats, :iters, 10) : (has_t ? length(sol.t) - 1 : 10),
96+
time = has_t ? sol.t[end] : 0.0,
97+
fevals = has_destats ? get(sol.destats, :f_calls, 0) : 0,
98+
gevals = has_destats ? get(sol.destats, :iters, 0) : 0,
99+
hevals = 0
100+
)
101+
102+
SciMLBase.build_solution(cache, cache.opt, sol.u, cache.f(sol.u, p);
103+
retcode = ReturnCode.Success,
104+
stats = stats
105+
)
61106
end
62107

63108
end

0 commit comments

Comments
 (0)