11module OptimizationODE
22
33using Reexport
4- @reexport using Optimization, Optimization . SciMLBase
4+ @reexport using Optimization, SciMLBase
55using DifferentialEquations
66using Optimization. LinearAlgebra
77
8- export ODEOptimizer, ODEGradientDescent, RKChebyshevDescent, RKAccelerated, PRKChebyshevDescent
8+ export ODEOptimizer, ODEGradientDescent, RKChebyshevDescent, RKAccelerated, HighOrderDescent
99
1010struct ODEOptimizer{T}
11- solver:: Type{T}
11+ solver:: T
1212end
1313
1414# Solver Constructors (users call these)
15- ODEGradientDescent () = ODEOptimizer (Euler)
16- RKChebyshevDescent () = ODEOptimizer (ROCK2)
17- RKAccelerated () = ODEOptimizer (Tsit5)
18- PRKChebyshevDescent () = ODEOptimizer (Vern7)
15+ ODEGradientDescent () = ODEOptimizer (Euler () )
16+ RKChebyshevDescent () = ODEOptimizer (ROCK2 () )
17+ RKAccelerated () = ODEOptimizer (Tsit5 () )
18+ HighOrderDescent () = ODEOptimizer (Vern7 () )
1919
2020
2121SciMLBase. requiresbounds (:: ODEOptimizer ) = false
@@ -29,21 +29,19 @@ SciMLBase.requiresconshess(::ODEOptimizer) = false
2929
3030
3131function SciMLBase. __init (prob:: OptimizationProblem , opt:: ODEOptimizer , data= Optimization. DEFAULT_DATA;
32- η = 0.1 , dt= nothing , tmax = 100.0 , callback= Optimization. DEFAULT_CALLBACK, progress= false ,
32+ dt= nothing , callback= Optimization. DEFAULT_CALLBACK, progress= false ,
3333 maxiters= nothing , kwargs... )
3434
3535 return OptimizationCache (prob, opt, data;
36- η = η, dt= dt, tmax = tmax , callback= callback, progress= progress,
36+ dt= dt, callback= callback, progress= progress,
3737 maxiters= maxiters, kwargs... )
3838end
3939
4040function SciMLBase. __solve (
4141 cache:: OptimizationCache{F,RC,LB,UB,LC,UC,S,O,D,P,C}
4242 ) where {F,RC,LB,UB,LC,UC,S,O<: ODEOptimizer ,D,P,C}
4343
44- η = get (cache. solver_args, :η , 0.1 )
4544 dt = get (cache. solver_args, :dt , nothing )
46- tmax = get (cache. solver_args, :tmax , 100.0 )
4745 maxit = get (cache. solver_args, :maxiters , 1000 )
4846
4947 u0 = copy (cache. u0)
@@ -55,13 +53,13 @@ function SciMLBase.__solve(
5553
5654 function f! (du, u, p, t)
5755 cache. f. grad (du, u, p)
58- @. du = - η * du
56+ @. du = - du
5957 return nothing
6058 end
6159
6260 ss_prob = SteadyStateProblem (f!, u0, p)
6361
64- algorithm = DynamicSS (cache. opt. solver () )
62+ algorithm = DynamicSS (cache. opt. solver)
6563
6664 cb = cache. callback
6765 if cb != Optimization. DEFAULT_CALLBACK || get (cache. solver_args,:progress ,false ) === true
@@ -70,7 +68,7 @@ function SciMLBase.__solve(
7068 end
7169 function affect! (integrator)
7270 u_now = integrator. u
73- state = Optimization. OptimizationState (u= u_now, objective= cache . f (u_now, p ))
71+ state = Optimization. OptimizationState (u= u_now, objective= integrator (integrator . t, Val{ 1 } ))
7472 Optimization. callback_function (cb, state)
7573 end
7674 cb_struct = DiscreteCallback (condition, affect!)
0 commit comments