@@ -42,14 +42,14 @@ struct CasADiDynamicOptProblem{uType, tType, isinplace, P, F, K} <:
4242 end
4343end
4444
45- function (M:: MXLinearInterpolation )(τ)
45+ function (M:: MXLinearInterpolation )(τ)
4646 nt = (τ - M. t[1 ]) / M. dt
4747 i = 1 + floor (Int, nt)
4848 Δ = nt - i + 1
4949
5050 (i > length (M. t) || i < 1 ) && error (" Cannot extrapolate past the tspan." )
5151 if i < length (M. t)
52- M. u[:, i] + Δ* (M. u[:, i + 1 ] - M. u[:, i])
52+ M. u[:, i] + Δ * (M. u[:, i + 1 ] - M. u[:, i])
5353 else
5454 M. u[:, i]
5555 end
@@ -74,7 +74,7 @@ The constraints are:
7474function MTK. CasADiDynamicOptProblem (sys:: ODESystem , u0map, tspan, pmap;
7575 dt = nothing ,
7676 steps = nothing ,
77- guesses = Dict (), kwargs... )
77+ guesses = Dict (), kwargs... )
7878 MTK. warn_overdetermined (sys, u0map)
7979 _u0map = has_alg_eqs (sys) ? u0map : merge (Dict (u0map), Dict (guesses))
8080 f, u0, p = MTK. process_SciMLProblem (ODEInputFunction, sys, _u0map, pmap;
@@ -104,21 +104,21 @@ function init_model(sys, tspan, steps, u0map, pmap, u0; is_free_t = false)
104104 subject_to! (opti, tₛ >= lo)
105105 subject_to! (opti, tₛ >= hi)
106106 end
107- pmap[te_sym] = tₛ
107+ pmap[te_sym] = tₛ
108108 tsteps = LinRange (0 , 1 , steps)
109109 else
110110 tₛ = MX (1 )
111111 tsteps = LinRange (tspan[1 ], tspan[2 ], steps)
112112 end
113-
113+
114114 U = CasADi. variable! (opti, length (states), steps)
115115 V = CasADi. variable! (opti, length (ctrls), steps)
116116 set_initial! (opti, U, DM (repeat (u0, 1 , steps)))
117117 c0 = MTK. value .([pmap[c] for c in ctrls])
118118 ! isempty (c0) && set_initial! (opti, V, DM (repeat (c0, 1 , steps)))
119119
120- U_interp = MXLinearInterpolation (U, tsteps, tsteps[2 ]- tsteps[1 ])
121- V_interp = MXLinearInterpolation (V, tsteps, tsteps[2 ]- tsteps[1 ])
120+ U_interp = MXLinearInterpolation (U, tsteps, tsteps[2 ] - tsteps[1 ])
121+ V_interp = MXLinearInterpolation (V, tsteps, tsteps[2 ] - tsteps[1 ])
122122 for (i, ct) in enumerate (ctrls)
123123 pmap[ct] = V[i, :]
124124 end
@@ -185,8 +185,8 @@ function add_user_constraints!(model::CasADiModel, sys, tspan, pmap; is_free_t)
185185 x = MTK. operation (st)
186186 t = only (MTK. arguments (st))
187187 MTK. symbolic_type (t) === MTK. NotSymbolic () || continue
188- if haskey (stidxmap, x (iv))
189- idx = stidxmap[x (iv)]
188+ if haskey (stidxmap, x (iv))
189+ idx = stidxmap[x (iv)]
190190 cv = U
191191 else
192192 idx = ctidxmap[x (iv)]
@@ -196,11 +196,11 @@ function add_user_constraints!(model::CasADiModel, sys, tspan, pmap; is_free_t)
196196 end
197197
198198 if cons isa Equation
199- subject_to! (opti, cons. lhs - cons. rhs== 0 )
199+ subject_to! (opti, cons. lhs - cons. rhs == 0 )
200200 elseif cons. relational_op === Symbolics. geq
201- subject_to! (opti, cons. lhs - cons. rhs≥ 0 )
201+ subject_to! (opti, cons. lhs - cons. rhs ≥ 0 )
202202 else
203- subject_to! (opti, cons. lhs - cons. rhs≤ 0 )
203+ subject_to! (opti, cons. lhs - cons. rhs ≤ 0 )
204204 end
205205 end
206206end
@@ -227,8 +227,8 @@ function add_cost_function!(model::CasADiModel, sys, tspan, pmap; is_free_t)
227227 x = operation (st)
228228 t = only (arguments (st))
229229 MTK. symbolic_type (t) === MTK. NotSymbolic () || continue
230- if haskey (stidxmap, x (iv))
231- idx = stidxmap[x (iv)]
230+ if haskey (stidxmap, x (iv))
231+ idx = stidxmap[x (iv)]
232232 cv = U
233233 else
234234 idx = ctidxmap[x (iv)]
@@ -244,7 +244,8 @@ function add_cost_function!(model::CasADiModel, sys, tspan, pmap; is_free_t)
244244 op = MTK. operation (int)
245245 arg = only (arguments (MTK. value (int)))
246246 lo, hi = (op. domain. domain. left, op. domain. domain. right)
247- ! isequal ((lo, hi), tspan) && error (" Non-whole interval bounds for integrals are not currently supported for CasADiDynamicOptProblem." )
247+ ! isequal ((lo, hi), tspan) &&
248+ error (" Non-whole interval bounds for integrals are not currently supported for CasADiDynamicOptProblem." )
248249 # Approximate integral as sum.
249250 intmap[int] = dt * tₛ * sum (arg)
250251 end
@@ -253,7 +254,8 @@ function add_cost_function!(model::CasADiModel, sys, tspan, pmap; is_free_t)
253254 minimize! (opti, MX (MTK. value (consolidate (jcosts))))
254255end
255256
256- function substitute_casadi_vars (model:: CasADiModel , sys, pmap, exprs; auxmap:: Dict = Dict (), is_free_t)
257+ function substitute_casadi_vars (
258+ model:: CasADiModel , sys, pmap, exprs; auxmap:: Dict = Dict (), is_free_t)
257259 @unpack opti, U, V, tₛ = model
258260 iv = MTK. get_iv (sys)
259261 sts = unknowns (sys)
@@ -281,44 +283,44 @@ end
281283
282284function add_solve_constraints (prob, tableau)
283285 @unpack A, α, c = tableau
284- @unpack model, f, p = prob
286+ @unpack model, f, p = prob
285287 @unpack opti, U, V, tₛ = model
286288 solver_opti = copy (opti)
287289
288- tsteps = U. t
290+ tsteps = U. t
289291 dt = tsteps[2 ] - tsteps[1 ]
290292
291293 nᵤ = size (U. u, 1 )
292294 nᵥ = size (V. u, 1 )
293295
294296 if MTK. is_explicit (tableau)
295297 K = MX[]
296- for k in 1 : length (tsteps)- 1
298+ for k in 1 : ( length (tsteps) - 1 )
297299 τ = tsteps[k]
298300 for (i, h) in enumerate (c)
299301 ΔU = sum ([A[i, j] * K[j] for j in 1 : (i - 1 )], init = MX (zeros (nᵤ)))
300- Uₙ = U. u[:, k] + ΔU* dt
302+ Uₙ = U. u[:, k] + ΔU * dt
301303 Vₙ = V. u[:, k]
302304 Kₙ = tₛ * f (Uₙ, Vₙ, p, τ + h * dt) # scale the time
303305 push! (K, Kₙ)
304306 end
305307 ΔU = dt * sum ([α[i] * K[i] for i in 1 : length (α)])
306- subject_to! (solver_opti, U. u[:, k] + ΔU == U. u[:, k+ 1 ])
308+ subject_to! (solver_opti, U. u[:, k] + ΔU == U. u[:, k + 1 ])
307309 empty! (K)
308310 end
309311 else
310- for k in 1 : length (tsteps)- 1
312+ for k in 1 : ( length (tsteps) - 1 )
311313 τ = tsteps[k]
312314 Kᵢ = variable! (solver_opti, nᵤ, length (α))
313315 ΔUs = A * Kᵢ' # the stepsize at each stage of the implicit method
314316 for (i, h) in enumerate (c)
315- ΔU = ΔUs[i,:]'
316- Uₙ = U. u[:,k] + ΔU* dt
317- Vₙ = V. u[:,k]
318- subject_to! (solver_opti, Kᵢ[:,i] == tₛ * f (Uₙ, Vₙ, p, τ + h* dt))
317+ ΔU = ΔUs[i, :]'
318+ Uₙ = U. u[:, k] + ΔU * dt
319+ Vₙ = V. u[:, k]
320+ subject_to! (solver_opti, Kᵢ[:, i] == tₛ * f (Uₙ, Vₙ, p, τ + h * dt))
319321 end
320- ΔU_tot = dt* (Kᵢ* α)
321- subject_to! (solver_opti, U. u[:, k] + ΔU_tot == U. u[:,k + 1 ])
322+ ΔU_tot = dt * (Kᵢ * α)
323+ subject_to! (solver_opti, U. u[:, k] + ΔU_tot == U. u[:, k + 1 ])
322324 end
323325 end
324326 solver_opti
331333
332334NOTE: the solver should be passed in as a string to CasADi. "ipopt"
333335"""
334- function DiffEqBase. solve (prob:: CasADiDynamicOptProblem , solver:: Union{String, Symbol} = " ipopt" , tableau_getter = MTK. constructDefault; plugin_options:: Dict = Dict (), solver_options:: Dict = Dict (), silent = false )
336+ function DiffEqBase. solve (
337+ prob:: CasADiDynamicOptProblem , solver:: Union{String, Symbol} = " ipopt" ,
338+ tableau_getter = MTK. constructDefault; plugin_options:: Dict = Dict (),
339+ solver_options:: Dict = Dict (), silent = false )
335340 @unpack model, u0, p, tspan, f = prob
336341 tableau = tableau_getter ()
337342 @unpack opti, U, V, tₛ = model
@@ -366,7 +371,8 @@ function DiffEqBase.solve(prob::CasADiDynamicOptProblem, solver::Union{String, S
366371 end
367372
368373 if failed
369- ode_sol = SciMLBase. solution_new_retcode (ode_sol, SciMLBase. ReturnCode. ConvergenceFailure)
374+ ode_sol = SciMLBase. solution_new_retcode (
375+ ode_sol, SciMLBase. ReturnCode. ConvergenceFailure)
370376 ! isnothing (input_sol) && (input_sol = SciMLBase. solution_new_retcode (
371377 input_sol, SciMLBase. ReturnCode. ConvergenceFailure))
372378 end
0 commit comments