@@ -3,61 +3,106 @@ module OptimizationODE
33using Reexport
44@reexport using Optimization, Optimization. SciMLBase
55using DifferentialEquations
6+ using Optimization. LinearAlgebra
67
78export ODEOptimizer, ODEGradientDescent, RKChebyshevDescent, RKAccelerated, PRKChebyshevDescent
89
9- abstract type AbstractODEOptimizer end
10-
11- struct ODEOptimizer{T} <: AbstractODEOptimizer
12- alg:: T
10+ struct ODEOptimizer{T}
11+ solver:: Type{T}
1312end
1413
15-
16- const ODEGradientDescent = ODEOptimizer (Euler () )
17- const RKChebyshevDescent = ODEOptimizer (ROCK2 () )
18- const RKAccelerated = ODEOptimizer (Tsit5 () )
19- const PRKChebyshevDescent = ODEOptimizer (ROCK4 () )
14+ # Solver Constructors (users call these)
15+ ODEGradientDescent () = ODEOptimizer (Euler)
16+ RKChebyshevDescent () = ODEOptimizer (ROCK2)
17+ RKAccelerated () = ODEOptimizer (Tsit5)
18+ PRKChebyshevDescent () = ODEOptimizer (Vern7 )
2019
2120
21+ SciMLBase. requiresbounds (:: ODEOptimizer ) = false
22+ SciMLBase. allowsbounds (:: ODEOptimizer ) = false
23+ SciMLBase. allowscallback (:: ODEOptimizer ) = true
2224SciMLBase. supports_opt_cache_interface (:: ODEOptimizer ) = true
23- SciMLBase. requiresgradient (:: ODEOptimizer ) = true
24-
25- function Optimization. __map_optimizer_args (cache:: OptimizationCache , opt:: ODEOptimizer ;
26- dt:: Real = 0.01 ,
27- maxiters:: Integer = 100 ,
28- callback = nothing ,
29- progress = false ,
30- kwargs...
31- )
32- cache. meta[:dt ] = dt
33- cache. meta[:maxiters ]= maxiters
34- cache. meta[:callback ]= callback
35- cache. meta[:progress ]= progress
36- return nothing
25+ SciMLBase. requiresgradient (:: ODEOptimizer ) = true
26+ SciMLBase. requireshessian (:: ODEOptimizer ) = false
27+ SciMLBase. requiresconsjac (:: ODEOptimizer ) = false
28+ SciMLBase. requiresconshess (:: ODEOptimizer ) = false
29+
30+
31+ function SciMLBase. __init (prob:: OptimizationProblem , opt:: ODEOptimizer , data= Optimization. DEFAULT_DATA;
32+ η= 0.1 , dt= nothing , tmax= 100.0 , callback= Optimization. DEFAULT_CALLBACK, progress= false ,
33+ maxiters= nothing , kwargs... )
34+
35+ return OptimizationCache (prob, opt, data;
36+ η= η, dt= dt, tmax= tmax, callback= callback, progress= progress,
37+ maxiters= maxiters, kwargs... )
3738end
3839
3940function SciMLBase. __solve (
40- cache:: OptimizationCache{F,RC,LB,UB,LC,UC,S,<:ODEOptimizer,D,P,C}
41- ) where {F,RC,LB,UB,LC,UC,S,D,P,C}
42- dt = cache. solver_args[:dt ]
43- maxiters = cache. solver_args[:maxiters ]
44- tspan = (0.0 , maxiters * dt)
45-
46- alg = cache. opt. alg
47-
48- prob = SteadyStateProblem (
49- (du, u, p, t) -> begin
50- cache. f. grad (du, u, cache. p)
51- du .*= - 1
52- end ,
53- cache. u0,
54- cache. p
55- )
41+ cache:: OptimizationCache{F,RC,LB,UB,LC,UC,S,O,D,P,C}
42+ ) where {F,RC,LB,UB,LC,UC,S,O<: ODEOptimizer ,D,P,C}
43+
44+ η = get (cache. solver_args, :η , 0.1 )
45+ dt = get (cache. solver_args, :dt , nothing )
46+ tmax = get (cache. solver_args, :tmax , 100.0 )
47+ maxit = get (cache. solver_args, :maxiters , 1000 )
48+
49+ u0 = copy (cache. u0)
50+ p = cache. p
5651
57- sol = solve (prob, DynamicSS (alg); dt= dt)
52+ if cache. f. grad === nothing
53+ error (" ODEOptimizer requires a gradient. Please provide a function with `grad` defined." )
54+ end
5855
59- return SciMLBase. build_solution (cache, cache. opt, sol. u,
60- sol. resid; original = sol, retcode = sol. retcode)
56+ function f! (du, u, p, t)
57+ cache. f. grad (du, u, p)
58+ @. du = - η * du
59+ return nothing
60+ end
61+
62+ ss_prob = SteadyStateProblem (f!, u0, p)
63+
64+ algorithm = DynamicSS (cache. opt. solver ())
65+
66+ cb = cache. callback
67+ if cb != Optimization. DEFAULT_CALLBACK || get (cache. solver_args,:progress ,false ) === true
68+ function condition (u, t, integrator)
69+ true
70+ end
71+ function affect! (integrator)
72+ u_now = integrator. u
73+ state = Optimization. OptimizationState (u= u_now, objective= cache. f (u_now, p))
74+ Optimization. callback_function (cb, state)
75+ end
76+ cb_struct = DiscreteCallback (condition, affect!)
77+ callback = CallbackSet (cb_struct)
78+ else
79+ callback = nothing
80+ end
81+
82+ solve_kwargs = Dict {Symbol, Any} (:callback => callback)
83+ if ! isnothing (maxit)
84+ solve_kwargs[:maxiters ] = maxit
85+ end
86+ if dt != = nothing
87+ solve_kwargs[:dt ] = dt
88+ end
89+
90+ sol = solve (ss_prob, algorithm; solve_kwargs... )
91+ has_destats = hasproperty (sol, :destats )
92+ has_t = hasproperty (sol, :t ) && ! isempty (sol. t)
93+
94+ stats = Optimization. OptimizationStats (
95+ iterations = has_destats ? get (sol. destats, :iters , 10 ) : (has_t ? length (sol. t) - 1 : 10 ),
96+ time = has_t ? sol. t[end ] : 0.0 ,
97+ fevals = has_destats ? get (sol. destats, :f_calls , 0 ) : 0 ,
98+ gevals = has_destats ? get (sol. destats, :iters , 0 ) : 0 ,
99+ hevals = 0
100+ )
101+
102+ SciMLBase. build_solution (cache, cache. opt, sol. u, cache. f (sol. u, p);
103+ retcode = ReturnCode. Success,
104+ stats = stats
105+ )
61106end
62107
63108end
0 commit comments