@@ -42,14 +42,14 @@ struct CasADiDynamicOptProblem{uType, tType, isinplace, P, F, K} <:
4242    end 
4343end 
4444
45- function  (M:: MXLinearInterpolation )(τ)  
45+ function  (M:: MXLinearInterpolation )(τ)
4646    nt =  (τ -  M. t[1 ]) /  M. dt
4747    i =  1  +  floor (Int, nt)
4848    Δ =  nt -  i +  1 
4949
5050    (i >  length (M. t) ||  i <  1 ) &&  error (" Cannot extrapolate past the tspan." 
5151    if  i <  length (M. t)
52-         M. u[:, i] +  Δ* (M. u[:, i +  1 ] -  M. u[:, i])
52+         M. u[:, i] +  Δ  *   (M. u[:, i +  1 ] -  M. u[:, i])
5353    else 
5454        M. u[:, i]
5555    end 
@@ -74,7 +74,7 @@ The constraints are:
7474function  MTK. CasADiDynamicOptProblem (sys:: ODESystem , u0map, tspan, pmap;
7575        dt =  nothing ,
7676        steps =  nothing ,
77-         guesses =  Dict (), kwargs... )  
77+         guesses =  Dict (), kwargs... )
7878    MTK. warn_overdetermined (sys, u0map)
7979    _u0map =  has_alg_eqs (sys) ?  u0map :  merge (Dict (u0map), Dict (guesses))
8080    f, u0, p =  MTK. process_SciMLProblem (ODEInputFunction, sys, _u0map, pmap;
@@ -104,21 +104,21 @@ function init_model(sys, tspan, steps, u0map, pmap, u0; is_free_t = false)
104104            subject_to! (opti, tₛ >=  lo)
105105            subject_to! (opti, tₛ >=  hi)
106106        end 
107-         pmap[te_sym] =  tₛ  
107+         pmap[te_sym] =  tₛ
108108        tsteps =  LinRange (0 , 1 , steps)
109109    else 
110110        tₛ =  MX (1 )
111111        tsteps =  LinRange (tspan[1 ], tspan[2 ], steps)
112112    end 
113-      
113+ 
114114    U =  CasADi. variable! (opti, length (states), steps)
115115    V =  CasADi. variable! (opti, length (ctrls), steps)
116116    set_initial! (opti, U, DM (repeat (u0, 1 , steps)))
117117    c0 =  MTK. value .([pmap[c] for  c in  ctrls])
118118    ! isempty (c0) &&  set_initial! (opti, V, DM (repeat (c0, 1 , steps)))
119119
120-     U_interp =  MXLinearInterpolation (U, tsteps, tsteps[2 ]- tsteps[1 ])
121-     V_interp =  MXLinearInterpolation (V, tsteps, tsteps[2 ]- tsteps[1 ])
120+     U_interp =  MXLinearInterpolation (U, tsteps, tsteps[2 ]  -   tsteps[1 ])
121+     V_interp =  MXLinearInterpolation (V, tsteps, tsteps[2 ]  -   tsteps[1 ])
122122    for  (i, ct) in  enumerate (ctrls)
123123        pmap[ct] =  V[i, :]
124124    end 
@@ -185,8 +185,8 @@ function add_user_constraints!(model::CasADiModel, sys, tspan, pmap; is_free_t)
185185            x =  MTK. operation (st)
186186            t =  only (MTK. arguments (st))
187187            MTK. symbolic_type (t) ===  MTK. NotSymbolic () ||  continue 
188-             if  haskey (stidxmap, x (iv))  
189-                 idx =  stidxmap[x (iv)]  
188+             if  haskey (stidxmap, x (iv))
189+                 idx =  stidxmap[x (iv)]
190190                cv =  U
191191            else 
192192                idx =  ctidxmap[x (iv)]
@@ -196,11 +196,11 @@ function add_user_constraints!(model::CasADiModel, sys, tspan, pmap; is_free_t)
196196        end 
197197
198198        if  cons isa  Equation
199-             subject_to! (opti, cons. lhs -  cons. rhs== 0 )
199+             subject_to! (opti, cons. lhs -  cons. rhs  ==   0 )
200200        elseif  cons. relational_op ===  Symbolics. geq
201-             subject_to! (opti, cons. lhs -  cons. rhs≥ 0 )
201+             subject_to! (opti, cons. lhs -  cons. rhs  ≥   0 )
202202        else 
203-             subject_to! (opti, cons. lhs -  cons. rhs≤ 0 )
203+             subject_to! (opti, cons. lhs -  cons. rhs  ≤   0 )
204204        end 
205205    end 
206206end 
@@ -227,8 +227,8 @@ function add_cost_function!(model::CasADiModel, sys, tspan, pmap; is_free_t)
227227            x =  operation (st)
228228            t =  only (arguments (st))
229229            MTK. symbolic_type (t) ===  MTK. NotSymbolic () ||  continue 
230-             if  haskey (stidxmap, x (iv))  
231-                 idx =  stidxmap[x (iv)]  
230+             if  haskey (stidxmap, x (iv))
231+                 idx =  stidxmap[x (iv)]
232232                cv =  U
233233            else 
234234                idx =  ctidxmap[x (iv)]
@@ -244,7 +244,8 @@ function add_cost_function!(model::CasADiModel, sys, tspan, pmap; is_free_t)
244244        op =  MTK. operation (int)
245245        arg =  only (arguments (MTK. value (int)))
246246        lo, hi =  (op. domain. domain. left, op. domain. domain. right)
247-         ! isequal ((lo, hi), tspan) &&  error (" Non-whole interval bounds for integrals are not currently supported for CasADiDynamicOptProblem." 
247+         ! isequal ((lo, hi), tspan) && 
248+             error (" Non-whole interval bounds for integrals are not currently supported for CasADiDynamicOptProblem." 
248249        #  Approximate integral as sum.
249250        intmap[int] =  dt *  tₛ *  sum (arg)
250251    end 
@@ -253,7 +254,8 @@ function add_cost_function!(model::CasADiModel, sys, tspan, pmap; is_free_t)
253254    minimize! (opti, MX (MTK. value (consolidate (jcosts))))
254255end 
255256
256- function  substitute_casadi_vars (model:: CasADiModel , sys, pmap, exprs; auxmap:: Dict  =  Dict (), is_free_t)
257+ function  substitute_casadi_vars (
258+         model:: CasADiModel , sys, pmap, exprs; auxmap:: Dict  =  Dict (), is_free_t)
257259    @unpack  opti, U, V, tₛ =  model
258260    iv =  MTK. get_iv (sys)
259261    sts =  unknowns (sys)
@@ -281,44 +283,44 @@ end
281283
282284function  add_solve_constraints (prob, tableau)
283285    @unpack  A, α, c =  tableau
284-     @unpack  model, f, p =  prob  
286+     @unpack  model, f, p =  prob
285287    @unpack  opti, U, V, tₛ =  model
286288    solver_opti =  copy (opti)
287289
288-     tsteps =  U. t  
290+     tsteps =  U. t
289291    dt =  tsteps[2 ] -  tsteps[1 ]
290292
291293    nᵤ =  size (U. u, 1 )
292294    nᵥ =  size (V. u, 1 )
293295
294296    if  MTK. is_explicit (tableau)
295297        K =  MX[]
296-         for  k in  1 : length (tsteps)- 1 
298+         for  k in  1 : ( length (tsteps)  -   1 ) 
297299            τ =  tsteps[k]
298300            for  (i, h) in  enumerate (c)
299301                ΔU =  sum ([A[i, j] *  K[j] for  j in  1 : (i -  1 )], init =  MX (zeros (nᵤ)))
300-                 Uₙ =  U. u[:, k] +  ΔU* dt
302+                 Uₙ =  U. u[:, k] +  ΔU  *   dt
301303                Vₙ =  V. u[:, k]
302304                Kₙ =  tₛ *  f (Uₙ, Vₙ, p, τ +  h *  dt) #  scale the time
303305                push! (K, Kₙ)
304306            end 
305307            ΔU =  dt *  sum ([α[i] *  K[i] for  i in  1 : length (α)])
306-             subject_to! (solver_opti, U. u[:, k] +  ΔU ==  U. u[:, k+ 1 ])
308+             subject_to! (solver_opti, U. u[:, k] +  ΔU ==  U. u[:, k  +   1 ])
307309            empty! (K)
308310        end 
309311    else 
310-         for  k in  1 : length (tsteps)- 1 
312+         for  k in  1 : ( length (tsteps)  -   1 ) 
311313            τ =  tsteps[k]
312314            Kᵢ =  variable! (solver_opti, nᵤ, length (α))
313315            ΔUs =  A *  Kᵢ'  #  the stepsize at each stage of the implicit method
314316            for  (i, h) in  enumerate (c)
315-                 ΔU =  ΔUs[i,:]' 
316-                 Uₙ =  U. u[:,k] +  ΔU* dt
317-                 Vₙ =  V. u[:,k]
318-                 subject_to! (solver_opti, Kᵢ[:,i] ==  tₛ *  f (Uₙ, Vₙ, p, τ +  h* dt))
317+                 ΔU =  ΔUs[i,  :]' 
318+                 Uₙ =  U. u[:,  k] +  ΔU  *   dt
319+                 Vₙ =  V. u[:,  k]
320+                 subject_to! (solver_opti, Kᵢ[:,  i] ==  tₛ *  f (Uₙ, Vₙ, p, τ +  h  *   dt))
319321            end 
320-             ΔU_tot =  dt* (Kᵢ* α)
321-             subject_to! (solver_opti, U. u[:, k] +  ΔU_tot ==  U. u[:,k + 1 ])
322+             ΔU_tot =  dt  *   (Kᵢ  *   α)
323+             subject_to! (solver_opti, U. u[:, k] +  ΔU_tot ==  U. u[:, k  +   1 ])
322324        end 
323325    end 
324326    solver_opti
331333
332334NOTE: the solver should be passed in as a string to CasADi. "ipopt" 
333335""" 
334- function  DiffEqBase. solve (prob:: CasADiDynamicOptProblem , solver:: Union{String, Symbol}  =  " ipopt" =  MTK. constructDefault; plugin_options:: Dict  =  Dict (), solver_options:: Dict  =  Dict (), silent =  false )
336+ function  DiffEqBase. solve (
337+         prob:: CasADiDynamicOptProblem , solver:: Union{String, Symbol}  =  " ipopt" 
338+         tableau_getter =  MTK. constructDefault; plugin_options:: Dict  =  Dict (),
339+         solver_options:: Dict  =  Dict (), silent =  false )
335340    @unpack  model, u0, p, tspan, f =  prob
336341    tableau =  tableau_getter ()
337342    @unpack  opti, U, V, tₛ =  model
@@ -366,7 +371,8 @@ function DiffEqBase.solve(prob::CasADiDynamicOptProblem, solver::Union{String, S
366371    end 
367372
368373    if  failed
369-         ode_sol =  SciMLBase. solution_new_retcode (ode_sol, SciMLBase. ReturnCode. ConvergenceFailure)
374+         ode_sol =  SciMLBase. solution_new_retcode (
375+             ode_sol, SciMLBase. ReturnCode. ConvergenceFailure)
370376        ! isnothing (input_sol) &&  (input_sol =  SciMLBase. solution_new_retcode (
371377            input_sol, SciMLBase. ReturnCode. ConvergenceFailure))
372378    end 
0 commit comments