Skip to content

Commit 8e451e1

Browse files
Update OptimizationODE.jl
Redundant parameters removed, PRKChecbyshevDescent renamed as HighOrderDescent(uses Vern7).
1 parent c7a06c4 commit 8e451e1

File tree

1 file changed

+12
-14
lines changed

1 file changed

+12
-14
lines changed

lib/OptimizationODE/src/OptimizationODE.jl

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,21 @@
11
module OptimizationODE
22

33
using Reexport
4-
@reexport using Optimization, Optimization.SciMLBase
4+
@reexport using Optimization, SciMLBase
55
using DifferentialEquations
66
using Optimization.LinearAlgebra
77

8-
export ODEOptimizer, ODEGradientDescent, RKChebyshevDescent, RKAccelerated, PRKChebyshevDescent
8+
export ODEOptimizer, ODEGradientDescent, RKChebyshevDescent, RKAccelerated, HighOrderDescent
99

1010
struct ODEOptimizer{T}
11-
solver::Type{T}
11+
solver::T
1212
end
1313

1414
# Solver Constructors (users call these)
15-
ODEGradientDescent() = ODEOptimizer(Euler)
16-
RKChebyshevDescent() = ODEOptimizer(ROCK2)
17-
RKAccelerated() = ODEOptimizer(Tsit5)
18-
PRKChebyshevDescent() = ODEOptimizer(Vern7)
15+
ODEGradientDescent() = ODEOptimizer(Euler())
16+
RKChebyshevDescent() = ODEOptimizer(ROCK2())
17+
RKAccelerated() = ODEOptimizer(Tsit5())
18+
HighOrderDescent() = ODEOptimizer(Vern7())
1919

2020

2121
SciMLBase.requiresbounds(::ODEOptimizer) = false
@@ -29,21 +29,19 @@ SciMLBase.requiresconshess(::ODEOptimizer) = false
2929

3030

3131
function SciMLBase.__init(prob::OptimizationProblem, opt::ODEOptimizer, data=Optimization.DEFAULT_DATA;
32-
η=0.1, dt=nothing, tmax=100.0, callback=Optimization.DEFAULT_CALLBACK, progress=false,
32+
dt=nothing, callback=Optimization.DEFAULT_CALLBACK, progress=false,
3333
maxiters=nothing, kwargs...)
3434

3535
return OptimizationCache(prob, opt, data;
36-
η=η, dt=dt, tmax=tmax, callback=callback, progress=progress,
36+
dt=dt, callback=callback, progress=progress,
3737
maxiters=maxiters, kwargs...)
3838
end
3939

4040
function SciMLBase.__solve(
4141
cache::OptimizationCache{F,RC,LB,UB,LC,UC,S,O,D,P,C}
4242
) where {F,RC,LB,UB,LC,UC,S,O<:ODEOptimizer,D,P,C}
4343

44-
η = get(cache.solver_args, , 0.1)
4544
dt = get(cache.solver_args, :dt, nothing)
46-
tmax = get(cache.solver_args, :tmax, 100.0)
4745
maxit = get(cache.solver_args, :maxiters, 1000)
4846

4947
u0 = copy(cache.u0)
@@ -55,13 +53,13 @@ function SciMLBase.__solve(
5553

5654
function f!(du, u, p, t)
5755
cache.f.grad(du, u, p)
58-
@. du = -η * du
56+
@. du = -du
5957
return nothing
6058
end
6159

6260
ss_prob = SteadyStateProblem(f!, u0, p)
6361

64-
algorithm = DynamicSS(cache.opt.solver())
62+
algorithm = DynamicSS(cache.opt.solver)
6563

6664
cb = cache.callback
6765
if cb != Optimization.DEFAULT_CALLBACK || get(cache.solver_args,:progress,false) === true
@@ -70,7 +68,7 @@ function SciMLBase.__solve(
7068
end
7169
function affect!(integrator)
7270
u_now = integrator.u
73-
state = Optimization.OptimizationState(u=u_now, objective=cache.f(u_now, p))
71+
state = Optimization.OptimizationState(u=u_now, objective=integrator(integrator.t, Val{1}))
7472
Optimization.callback_function(cb, state)
7573
end
7674
cb_struct = DiscreteCallback(condition, affect!)

0 commit comments

Comments
 (0)