5656
5757function SciMLBase. __init (prob:: OptimizationProblem , opt:: DAEOptimizer ;
5858 callback= Optimization. DEFAULT_CALLBACK, progress= false , dt= nothing ,
59- maxiters= nothing , differential_vars = nothing , kwargs... )
59+ maxiters= nothing , kwargs... )
6060 return OptimizationCache (prob, opt; callback= callback, progress= progress, dt= dt,
61- maxiters= maxiters, differential_vars = differential_vars, kwargs... )
61+ maxiters= maxiters, kwargs... )
6262end
6363
6464function SciMLBase. __solve (
@@ -67,15 +67,14 @@ function SciMLBase.__solve(
6767
6868 dt = get (cache. solver_args, :dt , nothing )
6969 maxit = get (cache. solver_args, :maxiters , nothing )
70- differential_vars = get (cache. solver_args, :differential_vars , nothing )
7170 u0 = copy (cache. u0)
7271 p = cache. p # Properly handle NullParameters
7372
7473 if cache. opt isa ODEOptimizer
7574 return solve_ode (cache, dt, maxit, u0, p)
7675 else
7776 if cache. opt. solver isa SciMLBase. AbstractDAEAlgorithm
78- return solve_dae_implicit (cache, dt, maxit, u0, p, differential_vars )
77+ return solve_dae_implicit (cache, dt, maxit, u0, p)
7978 else
8079 return solve_dae_mass_matrix (cache, dt, maxit, u0, p)
8180 end
@@ -139,43 +138,39 @@ end
139138
140139function solve_dae_mass_matrix (cache, dt, maxit, u0, p)
141140 if cache. f. cons === nothing
142- return solve_ode (cache, dt, maxit, u0, p )
141+ error ( " DAEOptimizer requires constraints. Please provide a function with `cons` defined. " )
143142 end
144- x= u0
145- cons_vals = cache. f. cons (x, p)
146143 n = length (u0)
147- m = length (cons_vals)
148- u0_extended = vcat (u0, zeros (m))
149- M = Diagonal (ones (n + m))
150-
144+ m = length (cache. ucons)
151145
146+ if m > n
147+ error (" DAEOptimizer with mass matrix method requires the number of constraints to be less than or equal to the number of variables." )
148+ end
149+ M = Diagonal ([ones (n- m); zeros (m)])
152150 function f_mass! (du, u, p_, t)
153- x = @view u[1 : n]
154- λ = @view u[n+ 1 : end ]
155- grad_f = similar (x)
156- if cache. f. grad != = nothing
157- cache. f. grad (grad_f, x, p_)
158- else
159- grad_f .= ForwardDiff. gradient (z -> cache. f. f (z, p_), x)
160- end
161- J = Matrix {eltype(x)} (undef, m, n)
162- cache. f. cons_j != = nothing && cache. f. cons_j (J, x)
163-
164- @. du[1 : n] = - grad_f - (J' * λ)
165- consv = cache. f. cons (x, p_)
166- @. du[n+ 1 : end ] = consv
151+ cache. f. grad (du, u, p)
152+ @. du = - du
153+ consout = @view du[(n- m)+ 1 : end ]
154+ cache. f. cons (consout, u)
167155 return nothing
168156 end
169157
170- if m == 0
171- optf = ODEFunction (f_mass!)
172- prob = ODEProblem (optf, u0, (0.0 , 1.0 ), p)
173- return solve (prob, cache. opt. solver; dt= dt, maxiters= maxit)
174- end
175-
176- ss_prob = SteadyStateProblem (ODEFunction (f_mass!, mass_matrix = M), u0_extended, p)
158+ ss_prob = SteadyStateProblem (ODEFunction (f_mass!, mass_matrix = M), u0, p)
177159
178- solve_kwargs = setup_progress_callback (cache, Dict ())
160+ if cache. callback != = Optimization. DEFAULT_CALLBACK
161+ condition = (u, t, integrator) -> true
162+ affect! = (integrator) -> begin
163+ u_opt = integrator. u isa AbstractArray ? integrator. u : integrator. u. u
164+ l = cache. f (integrator. u, integrator. p)
165+ cache. callback (integrator. u, l)
166+ end
167+ cb = DiscreteCallback (condition, affect!)
168+ solve_kwargs = Dict {Symbol, Any} (:callback => cb)
169+ else
170+ solve_kwargs = Dict {Symbol, Any} ()
171+ end
172+
173+ solve_kwargs[:progress ] = cache. progress
179174 if maxit != = nothing ; solve_kwargs[:maxiters ] = maxit; end
180175 if dt != = nothing ; solve_kwargs[:dt ] = dt; end
181176
@@ -189,61 +184,48 @@ function solve_dae_mass_matrix(cache, dt, maxit, u0, p)
189184 retcode = sol. retcode)
190185end
191186
192-
193- function solve_dae_implicit (cache, dt, maxit, u0, p, differential_vars)
187+ function solve_dae_implicit (cache, dt, maxit, u0, p)
194188 if cache. f. cons === nothing
195- return solve_ode (cache, dt, maxit, u0, p )
189+ error ( " DAEOptimizer requires constraints. Please provide a function with `cons` defined. " )
196190 end
197- x= u0
198- cons_vals = cache. f. cons (x, p)
191+
199192 n = length (u0)
200- m = length (cons_vals)
201- u0_ext = vcat (u0, zeros (m))
202- du0_ext = zeros (n + m)
193+ m = length (cache. ucons)
203194
204- if differential_vars === nothing
205- differential_vars = vcat (fill (true , n), fill (false , m))
206- else
207- if length (differential_vars) == n
208- differential_vars = vcat (differential_vars, fill (false , m))
209- elseif length (differential_vars) == n + m
210- # use as is
211- else
212- error (" differential_vars length must be number of variables ($n ) or extended size ($(n+ m) )" )
213- end
195+ if m > n
196+ error (" DAEOptimizer with mass matrix method requires the number of constraints to be less than or equal to the number of variables." )
214197 end
215198
216199 function dae_residual! (res, du, u, p_, t)
217- x = @view u[1 : n]
218- λ = @view u[n+ 1 : end ]
219- du_x = @view du[1 : n]
220- grad_f = similar (x)
221- cache. f. grad (grad_f, x, p_)
222- J = zeros (m, n)
223- cache. f. cons_j != = nothing && cache. f. cons_j (J, x)
224-
225- @. res[1 : n] = du_x + grad_f + J' * λ
226- consv = cache. f. cons (x, p_)
227- @. res[n+ 1 : end ] = consv
200+ cache. f. grad (res, u, p)
201+ @. res = du- res
202+ consout = @view res[(n- m)+ 1 : end ]
203+ cache. f. cons (consout, u)
228204 return nothing
229205 end
230206
231- if m == 0
232- optf = ODEFunction (dae_residual!, differential_vars = differential_vars)
233- prob = ODEProblem (optf, du0_ext, (0.0 , 1.0 ), p)
234- return solve (prob, HighOrderDescent (); dt= dt, maxiters= maxit)
235- end
236-
237207 tspan = (0.0 , 10.0 )
238- prob = DAEProblem (dae_residual!, du0_ext, u0_ext, tspan, p;
239- differential_vars = differential_vars)
208+ du0 = zero (u0)
209+ prob = DAEProblem (dae_residual!, du0, u0, tspan, p)
210+
211+ if cache. callback != = Optimization. DEFAULT_CALLBACK
212+ condition = (u, t, integrator) -> true
213+ affect! = (integrator) -> begin
214+ u_opt = integrator. u isa AbstractArray ? integrator. u : integrator. u. u
215+ l = cache. f (integrator. u, integrator. p)
216+ cache. callback (integrator. u, l)
217+ end
218+ cb = DiscreteCallback (condition, affect!)
219+ solve_kwargs = Dict {Symbol, Any} (:callback => cb)
220+ else
221+ solve_kwargs = Dict {Symbol, Any} ()
222+ end
223+
224+ solve_kwargs[:progress ] = cache. progress
240225
241- solve_kwargs = setup_progress_callback (cache, Dict ())
242226 if maxit != = nothing ; solve_kwargs[:maxiters ] = maxit; end
243227 if dt != = nothing ; solve_kwargs[:dt ] = dt; end
244- if hasfield (typeof (cache. opt. solver), :initializealg )
245- solve_kwargs[:initializealg ] = BrownFullBasicInit ()
246- end
228+ solve_kwargs[:initializealg ] = ShampineCollocationInit ()
247229
248230 sol = solve (prob, cache. opt. solver; solve_kwargs... )
249231 u_ext = sol. u
0 commit comments