@@ -2,58 +2,189 @@ module OptimizationODE
22
33using Reexport
44@reexport using Optimization, SciMLBase
5- using OrdinaryDiffEq, SteadyStateDiffEq
5+ using LinearAlgebra, ForwardDiff
6+
7+ using NonlinearSolve
8+ using OrdinaryDiffEq, DifferentialEquations, SteadyStateDiffEq, Sundials
69
710export ODEOptimizer, ODEGradientDescent, RKChebyshevDescent, RKAccelerated, HighOrderDescent
11+ export DAEOptimizer, DAEMassMatrix, DAEIndexing
12+
13+ struct ODEOptimizer{T}
14+ solver:: T
15+ end
16+
17+ ODEGradientDescent () = ODEOptimizer (Euler ())
18+ RKChebyshevDescent () = ODEOptimizer (ROCK2 ())
19+ RKAccelerated () = ODEOptimizer (Tsit5 ())
20+ HighOrderDescent () = ODEOptimizer (Vern7 ())
821
9- struct ODEOptimizer{T, T2 }
22+ struct DAEOptimizer{T }
1023 solver:: T
11- dt:: T2
1224end
13- ODEOptimizer (solver ; dt= nothing ) = ODEOptimizer (solver, dt)
1425
15- # Solver Constructors (users call these)
16- ODEGradientDescent (; dt) = ODEOptimizer (Euler (); dt)
17- RKChebyshevDescent () = ODEOptimizer (ROCK2 ())
18- RKAccelerated () = ODEOptimizer (Tsit5 ())
19- HighOrderDescent () = ODEOptimizer (Vern7 ())
26+ DAEMassMatrix () = DAEOptimizer (Rodas5 ())
27+ DAEIndexing () = DAEOptimizer (IDA ())
2028
2129
22- SciMLBase. requiresbounds (:: ODEOptimizer ) = false
23- SciMLBase. allowsbounds (:: ODEOptimizer ) = false
24- SciMLBase. allowscallback (:: ODEOptimizer ) = true
30+ SciMLBase. requiresbounds (:: ODEOptimizer ) = false
31+ SciMLBase. allowsbounds (:: ODEOptimizer ) = false
32+ SciMLBase. allowscallback (:: ODEOptimizer ) = true
2533SciMLBase. supports_opt_cache_interface (:: ODEOptimizer ) = true
26- SciMLBase. requiresgradient (:: ODEOptimizer ) = true
27- SciMLBase. requireshessian (:: ODEOptimizer ) = false
28- SciMLBase. requiresconsjac (:: ODEOptimizer ) = false
29- SciMLBase. requiresconshess (:: ODEOptimizer ) = false
34+ SciMLBase. requiresgradient (:: ODEOptimizer ) = true
35+ SciMLBase. requireshessian (:: ODEOptimizer ) = false
36+ SciMLBase. requiresconsjac (:: ODEOptimizer ) = false
37+ SciMLBase. requiresconshess (:: ODEOptimizer ) = false
38+
39+
40+ SciMLBase. requiresbounds (:: DAEOptimizer ) = false
41+ SciMLBase. allowsbounds (:: DAEOptimizer ) = false
42+ SciMLBase. allowsconstraints (:: DAEOptimizer ) = true
43+ SciMLBase. allowscallback (:: DAEOptimizer ) = true
44+ SciMLBase. supports_opt_cache_interface (:: DAEOptimizer ) = true
45+ SciMLBase. requiresgradient (:: DAEOptimizer ) = true
46+ SciMLBase. requireshessian (:: DAEOptimizer ) = false
47+ SciMLBase. requiresconsjac (:: DAEOptimizer ) = true
48+ SciMLBase. requiresconshess (:: DAEOptimizer ) = false
3049
3150
3251function SciMLBase. __init (prob:: OptimizationProblem , opt:: ODEOptimizer ;
33- callback= Optimization. DEFAULT_CALLBACK, progress= false ,
52+ callback= Optimization. DEFAULT_CALLBACK, progress= false , dt = nothing ,
3453 maxiters= nothing , kwargs... )
35-
36- return OptimizationCache (prob, opt; callback= callback, progress= progress,
54+ return OptimizationCache (prob, opt; callback= callback, progress= progress, dt= dt,
3755 maxiters= maxiters, kwargs... )
3856end
3957
40- function SciMLBase. __solve (
41- cache:: OptimizationCache{F,RC,LB,UB,LC,UC,S,O,D,P,C}
42- ) where {F,RC,LB,UB,LC,UC,S,O<: ODEOptimizer ,D,P,C}
58+ function SciMLBase. __init (prob:: OptimizationProblem , opt:: DAEOptimizer ;
59+ callback= Optimization. DEFAULT_CALLBACK, progress= false , dt= nothing ,
60+ maxiters= nothing , differential_vars= nothing , kwargs... )
61+ return OptimizationCache (prob, opt; callback= callback, progress= progress, dt= dt,
62+ maxiters= maxiters, differential_vars= differential_vars, kwargs... )
63+ end
64+
65+
66+ function solve_constrained_root (cache, u0, p)
67+ n = length (u0)
68+ cons_vals = cache. f. cons (u0, p)
69+ m = length (cons_vals)
70+ function resid! (res, u)
71+ temp = similar (u)
72+ f_mass! (temp, u, p, 0.0 )
73+ res .= temp
74+ end
75+ u0_ext = vcat (u0, zeros (m))
76+ prob_nl = NonlinearProblem (resid!, u0_ext, p)
77+ sol_nl = solve (prob_nl, Newton (); tol = 1e-8 , maxiters = 100000 ,
78+ callback = cache. callback, progress = get (cache. solver_args, :progress , false ))
79+ u_ext = sol_nl. u
80+ return u_ext[1 : n], sol_nl. retcode
81+ end
82+
83+
84+ function get_solver_type (opt:: DAEOptimizer )
85+ if opt. solver isa Union{Rodas5, RadauIIA5, ImplicitEuler, Trapezoid}
86+ return :mass_matrix
87+ else
88+ return :indexing
89+ end
90+ end
4391
44- dt = cache. opt. dt
45- maxit = get (cache. solver_args, :maxiters , 1000 )
92+ function handle_parameters (p)
93+ if p isa SciMLBase. NullParameters
94+ return Float64[]
95+ else
96+ return p
97+ end
98+ end
99+
100+ function setup_progress_callback (cache, solve_kwargs)
101+ if get (cache. solver_args, :progress , false )
102+ condition = (u, t, integrator) -> true
103+ affect! = (integrator) -> begin
104+ u_opt = integrator. u isa AbstractArray ? integrator. u : integrator. u. u
105+ cache. solver_args[:callback ](u_opt, integrator. p, integrator. t)
106+ end
107+ cb = DiscreteCallback (condition, affect!)
108+ solve_kwargs[:callback ] = cb
109+ end
110+ return solve_kwargs
111+ end
46112
113+ function finite_difference_jacobian (f, x; ϵ = 1e-8 )
114+ n = length (x)
115+ fx = f (x)
116+ if fx === nothing
117+ return zeros (eltype (x), 0 , n)
118+ elseif isa (fx, Number)
119+ J = zeros (eltype (fx), 1 , n)
120+ for j in 1 : n
121+ xj = copy (x)
122+ xj[j] += ϵ
123+ diff = f (xj)
124+ if diff === nothing
125+ diffval = zero (eltype (fx))
126+ else
127+ diffval = diff - fx
128+ end
129+ J[1 , j] = diffval / ϵ
130+ end
131+ return J
132+ else
133+ m = length (fx)
134+ J = zeros (eltype (fx), m, n)
135+ for j in 1 : n
136+ xj = copy (x)
137+ xj[j] += ϵ
138+ fxj = f (xj)
139+ if fxj === nothing
140+ @inbounds for i in 1 : m
141+ J[i, j] = - fx[i] / ϵ
142+ end
143+ else
144+ @inbounds for i in 1 : m
145+ J[i, j] = (fxj[i] - fx[i]) / ϵ
146+ end
147+ end
148+ end
149+ return J
150+ end
151+ end
152+
153+ function SciMLBase. __solve (
154+ cache:: OptimizationCache{F,RC,LB,UB,LC,UC,S,O,D,P,C}
155+ ) where {F,RC,LB,UB,LC,UC,S,O<: Union{ODEOptimizer,DAEOptimizer} ,D,P,C}
156+
157+ dt = get (cache. solver_args, :dt , nothing )
158+ maxit = get (cache. solver_args, :maxiters , nothing )
159+ differential_vars = get (cache. solver_args, :differential_vars , nothing )
47160 u0 = copy (cache. u0)
48- p = cache. p
161+ p = handle_parameters ( cache. p) # Properly handle NullParameters
49162
163+ if cache. opt isa ODEOptimizer
164+ return solve_ode (cache, dt, maxit, u0, p)
165+ else
166+ solver_method = get_solver_type (cache. opt)
167+ if solver_method == :mass_matrix
168+ return solve_dae_mass_matrix (cache, dt, maxit, u0, p)
169+ else
170+ return solve_dae_indexing (cache, dt, maxit, u0, p, differential_vars)
171+ end
172+ end
173+ end
174+
175+ function solve_ode (cache, dt, maxit, u0, p)
50176 if cache. f. grad === nothing
51177 error (" ODEOptimizer requires a gradient. Please provide a function with `grad` defined." )
52178 end
53179
54180 function f! (du, u, p, t)
55- cache. f. grad (du, u, p)
56- @. du = - du
181+ grad_vec = similar (u)
182+ if isempty (p)
183+ cache. f. grad (grad_vec, u)
184+ else
185+ cache. f. grad (grad_vec, u, p)
186+ end
187+ @. du = - grad_vec
57188 return nothing
58189 end
59190
@@ -62,14 +193,11 @@ function SciMLBase.__solve(
62193 algorithm = DynamicSS (cache. opt. solver)
63194
64195 cb = cache. callback
65- if cb != Optimization. DEFAULT_CALLBACK || get (cache. solver_args,:progress ,false ) === true
66- function condition (u, t, integrator)
67- true
68- end
196+ if cb != Optimization. DEFAULT_CALLBACK || get (cache. solver_args,:progress ,false )
197+ function condition (u, t, integrator) true end
69198 function affect! (integrator)
70199 u_now = integrator. u
71- state = Optimization. OptimizationState (u= u_now, objective= cache. f (integrator. u, integrator. p))
72- Optimization. callback_function (cb, state)
200+ cache. callback (u_now, integrator. p, integrator. t)
73201 end
74202 cb_struct = DiscreteCallback (condition, affect!)
75203 callback = CallbackSet (cb_struct)
@@ -86,21 +214,154 @@ function SciMLBase.__solve(
86214 end
87215
88216 sol = solve (ss_prob, algorithm; solve_kwargs... )
89- has_destats = hasproperty (sol, :destats )
90- has_t = hasproperty (sol, :t ) && ! isempty (sol. t)
217+ has_destats = hasproperty (sol, :destats )
218+ has_t = hasproperty (sol, :t ) && ! isempty (sol. t)
91219
92- stats = Optimization. OptimizationStats (
93- iterations = has_destats ? get (sol. destats, :iters , 10 ) : (has_t ? length (sol. t) - 1 : 10 ),
94- time = has_t ? sol. t[end ] : 0.0 ,
95- fevals = has_destats ? get (sol. destats, :f_calls , 0 ) : 0 ,
96- gevals = has_destats ? get (sol. destats, :iters , 0 ) : 0 ,
97- hevals = 0
98- )
220+ stats = Optimization. OptimizationStats (
221+ iterations = has_destats ? get (sol. destats, :iters , 10 ) : (has_t ? length (sol. t) - 1 : 10 ),
222+ time = has_t ? sol. t[end ] : 0.0 ,
223+ fevals = has_destats ? get (sol. destats, :f_calls , 0 ) : 0 ,
224+ gevals = has_destats ? get (sol. destats, :iters , 0 ) : 0 ,
225+ hevals = 0
226+ )
99227
100228 SciMLBase. build_solution (cache, cache. opt, sol. u, cache. f (sol. u, p);
101229 retcode = ReturnCode. Success,
102230 stats = stats
103231 )
104232end
105233
234+ function solve_dae_mass_matrix (cache, dt, maxit, u0, p)
235+ if cache. f. cons === nothing
236+ return solve_ode (cache, dt, maxit, u0, p)
237+ end
238+ x= u0
239+ cons_vals = cache. f. cons (x, p)
240+ n = length (u0)
241+ m = length (cons_vals)
242+ u0_extended = vcat (u0, zeros (m))
243+ M = zeros (n + m, n + m)
244+ M[1 : n, 1 : n] = I (n)
245+
246+ function f_mass! (du, u, p_, t)
247+ x = @view u[1 : n]
248+ λ = @view u[n+ 1 : end ]
249+ grad_f = similar (x)
250+ if cache. f. grad != = nothing
251+ cache. f. grad (grad_f, x, p_)
252+ else
253+ grad_f .= ForwardDiff. gradient (z -> cache. f. f (z, p_), x)
254+ end
255+ J = Matrix {eltype(x)} (undef, m, n)
256+ if cache. f. cons_j != = nothing
257+ cache. f. cons_j (J, x)
258+ else
259+ J .= finite_difference_jacobian (z -> cache. f. cons (z, p_), x)
260+ end
261+ @. du[1 : n] = - grad_f - (J' * λ)
262+ consv = cache. f. cons (x, p_)
263+ if consv === nothing
264+ fill! (du[n+ 1 : end ], zero (eltype (x)))
265+ else
266+ if isa (consv, Number)
267+ @assert m == 1
268+ du[n+ 1 ] = consv
269+ else
270+ @assert length (consv) == m
271+ @. du[n+ 1 : end ] = consv
272+ end
273+ end
274+ return nothing
275+ end
276+
277+ if m == 0
278+ optf = ODEFunction (f_mass!, mass_matrix = I (n))
279+ prob = ODEProblem (optf, u0, (0.0 , 1.0 ), p)
280+ return solve (prob, HighOrderDescent (); dt= dt, maxiters= maxit)
281+ end
282+
283+ ss_prob = SteadyStateProblem (ODEFunction (f_mass!, mass_matrix = M), u0_extended, p)
284+
285+ solve_kwargs = setup_progress_callback (cache, Dict ())
286+ if maxit != = nothing ; solve_kwargs[:maxiters ] = maxit; end
287+ if dt != = nothing ; solve_kwargs[:dt ] = dt; end
288+
289+ sol = solve (ss_prob, DynamicSS (cache. opt. solver); solve_kwargs... )
290+ # if sol.retcode ≠ ReturnCode.Success
291+ # # you may still accept Default or warn
292+ # end
293+ u_ext = sol. u
294+ u_final = u_ext[1 : n]
295+ return SciMLBase. build_solution (cache, cache. opt, u_final, cache. f (u_final, p);
296+ retcode = sol. retcode)
106297end
298+
299+
300+ function solve_dae_indexing (cache, dt, maxit, u0, p, differential_vars)
301+ if cache. f. cons === nothing
302+ return solve_ode (cache, dt, maxit, u0, p)
303+ end
304+ x= u0
305+ cons_vals = cache. f. cons (x, p)
306+ n = length (u0)
307+ m = length (cons_vals)
308+ u0_ext = vcat (u0, zeros (m))
309+ du0_ext = zeros (n + m)
310+
311+ if differential_vars === nothing
312+ differential_vars = vcat (fill (true , n), fill (false , m))
313+ else
314+ if length (differential_vars) == n
315+ differential_vars = vcat (differential_vars, fill (false , m))
316+ elseif length (differential_vars) == n + m
317+ # use as is
318+ else
319+ error (" differential_vars length must be number of variables ($n ) or extended size ($(n+ m) )" )
320+ end
321+ end
322+
323+ function dae_residual! (res, du, u, p_, t)
324+ x = @view u[1 : n]
325+ λ = @view u[n+ 1 : end ]
326+ du_x = @view du[1 : n]
327+ grad_f = similar (x)
328+ cache. f. grad (grad_f, x, p_)
329+ J = zeros (m, n)
330+ if cache. f. cons_j != = nothing
331+ cache. f. cons_j (J, x)
332+ else
333+ J .= finite_difference_jacobian (z -> cache. f. cons (z,p_), x)
334+ end
335+ @. res[1 : n] = du_x + grad_f + J' * λ
336+ consv = cache. f. cons (x, p_)
337+ @. res[n+ 1 : end ] = consv
338+ return nothing
339+ end
340+
341+ if m == 0
342+ optf = ODEFunction (dae_residual!, differential_vars = differential_vars)
343+ prob = ODEProblem (optf, du0_ext, (0.0 , 1.0 ), p)
344+ return solve (prob, HighOrderDescent (); dt= dt, maxiters= maxit)
345+ end
346+
347+ tspan = (0.0 , 10.0 )
348+ prob = DAEProblem (dae_residual!, du0_ext, u0_ext, tspan, p;
349+ differential_vars = differential_vars)
350+
351+ solve_kwargs = setup_progress_callback (cache, Dict ())
352+ if maxit != = nothing ; solve_kwargs[:maxiters ] = maxit; end
353+ if dt != = nothing ; solve_kwargs[:dt ] = dt; end
354+ if hasfield (typeof (cache. opt. solver), :initializealg )
355+ solve_kwargs[:initializealg ] = BrownFullBasicInit ()
356+ end
357+
358+ sol = solve (prob, cache. opt. solver; solve_kwargs... )
359+ u_ext = sol. u
360+ u_final = u_ext[end ][1 : n]
361+
362+ return SciMLBase. build_solution (cache, cache. opt, u_final, cache. f (u_final, p);
363+ retcode = sol. retcode)
364+ end
365+
366+
367+ end
0 commit comments