@@ -5,10 +5,10 @@ using Reexport
55using LinearAlgebra, ForwardDiff
66
77using NonlinearSolve
8- using OrdinaryDiffEq, DifferentialEquations, SteadyStateDiffEq, Sundials
8+ using OrdinaryDiffEq, SteadyStateDiffEq, Sundials
99
1010export ODEOptimizer, ODEGradientDescent, RKChebyshevDescent, RKAccelerated, HighOrderDescent
11- export DAEOptimizer, DAEMassMatrix, DAEIndexing
11+ export DAEOptimizer, DAEMassMatrix
1212
1313struct ODEOptimizer{T}
1414 solver:: T
@@ -23,8 +23,7 @@ struct DAEOptimizer{T}
2323 solver:: T
2424end
2525
26- DAEMassMatrix () = DAEOptimizer (Rosenbrock23 (autodiff = false ))
27- DAEIndexing () = DAEOptimizer (IDA ())
26+ DAEMassMatrix () = DAEOptimizer (Rodas5P (autodiff = false ))
2827
2928
3029SciMLBase. requiresbounds (:: ODEOptimizer ) = false
@@ -62,29 +61,6 @@ function SciMLBase.__init(prob::OptimizationProblem, opt::DAEOptimizer;
6261 maxiters= maxiters, differential_vars= differential_vars, kwargs... )
6362end
6463
65-
66- function handle_parameters (p)
67- if p isa SciMLBase. NullParameters
68- return Float64[]
69- else
70- return p
71- end
72- end
73-
74- function setup_progress_callback (cache, solve_kwargs)
75- if get (cache. solver_args, :progress , false )
76- condition = (u, t, integrator) -> true
77- affect! = (integrator) -> begin
78- u_opt = integrator. u isa AbstractArray ? integrator. u : integrator. u. u
79- cache. solver_args[:callback ](u_opt, integrator. p, integrator. t)
80- end
81- cb = DiscreteCallback (condition, affect!)
82- solve_kwargs[:callback ] = cb
83- end
84- return solve_kwargs
85- end
86-
87-
8864function SciMLBase. __solve (
8965 cache:: OptimizationCache{F,RC,LB,UB,LC,UC,S,O,D,P,C}
9066 ) where {F,RC,LB,UB,LC,UC,S,O<: Union{ODEOptimizer,DAEOptimizer} ,D,P,C}
@@ -93,15 +69,15 @@ function SciMLBase.__solve(
9369 maxit = get (cache. solver_args, :maxiters , nothing )
9470 differential_vars = get (cache. solver_args, :differential_vars , nothing )
9571 u0 = copy (cache. u0)
96- p = handle_parameters ( cache. p) # Properly handle NullParameters
72+ p = cache. p # Properly handle NullParameters
9773
9874 if cache. opt isa ODEOptimizer
9975 return solve_ode (cache, dt, maxit, u0, p)
10076 else
101- if cache. opt. solver == Rosenbrock23 (autodiff = false )
102- return solve_dae_mass_matrix (cache, dt, maxit, u0, p)
77+ if cache. opt. solver isa SciMLBase . AbstractDAEAlgorithm
78+ return solve_dae_implicit (cache, dt, maxit, u0, p, differential_vars )
10379 else
104- return solve_dae_indexing (cache, dt, maxit, u0, p, differential_vars )
80+ return solve_dae_mass_matrix (cache, dt, maxit, u0, p)
10581 end
10682 end
10783end
@@ -112,41 +88,37 @@ function solve_ode(cache, dt, maxit, u0, p)
11288 end
11389
11490 function f! (du, u, p, t)
115- grad_vec = similar (u)
116- if isempty (p)
117- cache. f. grad (grad_vec, u)
118- else
119- cache. f. grad (grad_vec, u, p)
120- end
121- @. du = - grad_vec
91+ cache. f. grad (du, u, p)
92+ @. du = - du
12293 return nothing
12394 end
12495
12596 ss_prob = SteadyStateProblem (f!, u0, p)
12697
12798 algorithm = DynamicSS (cache. opt. solver)
12899
129- cb = cache. callback
130- if cb != Optimization . DEFAULT_CALLBACK || get (cache . solver_args, :progress , false )
131- function condition (u, t, integrator) true end
132- function affect! ( integrator)
133- u_now = integrator. u
134- cache. callback (u_now, integrator. p, integrator . t )
100+ if cache. callback != = Optimization . DEFAULT_CALLBACK
101+ condition = (u, t, integrator) -> true
102+ affect! = ( integrator) -> begin
103+ u_opt = integrator. u isa AbstractArray ? integrator . u : integrator . u . u
104+ l = cache . f ( integrator. u, integrator . p)
105+ cache. callback (integrator. u, l )
135106 end
136- cb_struct = DiscreteCallback (condition, affect!)
137- callback = CallbackSet (cb_struct )
107+ cb = DiscreteCallback (condition, affect!)
108+ solve_kwargs = Dict {Symbol, Any} ( :callback => cb )
138109 else
139- callback = nothing
110+ solve_kwargs = Dict {Symbol, Any} ()
140111 end
141-
142- solve_kwargs = Dict {Symbol, Any} (:callback => callback)
112+
143113 if ! isnothing (maxit)
144114 solve_kwargs[:maxiters ] = maxit
145115 end
146116 if dt != = nothing
147117 solve_kwargs[:dt ] = dt
148118 end
149119
120+ solve_kwargs[:progress ] = cache. progress
121+
150122 sol = solve (ss_prob, algorithm; solve_kwargs... )
151123 has_destats = hasproperty (sol, :destats )
152124 has_t = hasproperty (sol, :t ) && ! isempty (sol. t)
@@ -218,7 +190,7 @@ function solve_dae_mass_matrix(cache, dt, maxit, u0, p)
218190end
219191
220192
221- function solve_dae_indexing (cache, dt, maxit, u0, p, differential_vars)
193+ function solve_dae_implicit (cache, dt, maxit, u0, p, differential_vars)
222194 if cache. f. cons === nothing
223195 return solve_ode (cache, dt, maxit, u0, p)
224196 end
0 commit comments