@@ -3,6 +3,7 @@ using ModelingToolkit
33using CasADi
44using DiffEqDevTools, DiffEqBase
55using DataInterpolations
6+ using UnPack
67const MTK = MOdelingToolkit
78
89struct CasADiDynamicOptProblem{uType, tType, isinplace, P, F, K} < :
2223
2324struct CasADiModel
2425 opti:: Opti
25- U:: MX
26- V:: MX
27- end
28-
29- struct TimedMX
26+ U:: AbstractInterpolation
27+ V:: AbstractInterpolation
28+ tₛ:: Union{Number, MX}
3029end
3130
3231function MTK. CasADiDynamicOptProblem (sys:: ODESystem , u0map, tspan, pmap;
3332 dt = nothing ,
34- steps = nothing ,
33+ steps = nothing ,
34+ interpolation_method:: AbstractInterpolation = LinearInterpolation,
3535 guesses = Dict (), kwargs... )
3636 MTK. warn_overdetermined (sys, u0map)
3737 _u0map = has_alg_eqs (sys) ? u0map : merge (Dict (u0map), Dict (guesses))
@@ -43,73 +43,228 @@ function MTK.CasADiDynamicOptProblem(sys::ODESystem, u0map, tspan, pmap;
4343 model = init_model ()
4444end
4545
46- function init_model (sys, tspan, steps, u0map, pmap, u0; is_free_t)
46+ function init_model (sys, tspan, steps, u0map, pmap, u0; is_free_t = false , interp_type :: AbstractInterpolation )
4747 ctrls = MTK. unbound_inputs (sys)
4848 states = unknowns (sys)
49- model = CasADi. Opti ()
49+ opti = CasADi. Opti ()
50+
51+ if is_free_t
52+ (ts_sym, te_sym) = tspan
53+ MTK. symbolic_type (ts_sym) != = MTK. NotSymbolic () &&
54+ error (" Free initial time problems are not currently supported." )
55+ tₛ = variable! (opti)
56+ tsteps = LinRange (0 , 1 , steps)
57+ else
58+ tₛ = 1
59+ tsteps = LinRange (tspan[1 ], tspan[2 ], steps)
60+ end
5061
51- U = CasADi. variable! (model, length (states), steps)
52- V = CasADi. variable! (model, length (ctrls), steps)
62+ U = CasADi. variable! (opti, length (states), steps)
63+ V = CasADi. variable! (opti, length (ctrls), steps)
64+
65+ U_interp = interp_type (Matrix (U), tsteps)
66+ V_interp = interp_type (Matrix (V), tsteps)
67+
68+ CasADiModel (opti, U_interp, V_interp, tₛ)
5369end
5470
55- function add_initial_constraints! ()
56-
71+ function set_casadi_bounds! (model, sys, pmap)
72+ @unpack opti, U, V = model
73+ for (i, u) in enumerate (unknowns (sys))
74+ if MTK. hasbounds (u)
75+ lo, hi = MTK. getbounds (u)
76+ subject_to! (opti, lo <= U[i, :] <= hi)
77+ end
78+ end
79+ for (i, v) in enumerate (MTK. unbound_inputs (sys))
80+ if MTK. hasbounds (v)
81+ lo, hi = MTK. getbounds (v)
82+ subject_to! (opti, lo <= V[i, :] <= hi)
83+ end
84+ end
85+ end
86+
87+ function add_initial_constraints! (model:: CasADiModel , u0, u0_idxs, ts)
88+ @unpack opti, U = model
89+ for i in u0_idxs
90+ subject_to! (opti, U. u[i, 1 ] == u0[i])
91+ end
5792end
5893
5994function add_user_constraints! (model:: CasADiModel , sys, pmap; is_free_t = false )
60-
95+ @unpack opti, U, V, tₛ = model
96+
97+ iv = get_iv (sys)
98+ conssys = MTK. get_constraintsystem (sys)
99+ jconstraints = isnothing (conssys) ? nothing : MTK. get_constraints (conssys)
100+ (isnothing (jconstraints) || isempty (jconstraints)) && return nothing
101+
102+ stidxmap = Dict ([v => i for (i, v) in enumerate (sts)])
103+ pidxmap = Dict ([v => i for (i, v) in enumerate (ps)])
104+ cons_unknowns = map (MTK. default_toterm, unknowns (conssys))
105+ for st in cons_unknowns
106+ x = operation (st)
107+ t = only (argments (st))
108+ idx = stidxmap[x (iv)]
109+
110+ jconstraints = map (c -> Symbolics. substitute (c, Dict (x (t) => U (t)[idx])), jconstraints)
111+ end
112+ jconstraints = substitute_casadi_vars (model, sys, pmap, jconstraints)
113+
114+ for (i, cons) in enumerate (jconstraints)
115+ if cons isa Equation
116+ subject_to! (opti, cons. lhs - cons. rhs== 0 )
117+ elseif cons. relational_op === Symbolics. geq
118+ subject_to! (model, cons. lhs - cons. rhs≥ 0 )
119+ else
120+ subject_to! (model, cons. lhs - cons. rhs≤ 0 )
121+ end
122+ end
61123end
62124
63- function add_cost_function! (model)
125+ function add_cost_function! (model:: CasADiModel , sys, tspan, pmap)
126+ @unpack opti, U, V, tₛ = model
127+ jcosts = MTK. get_costs (sys)
128+ consolidate = MTK. get_consolidate (sys)
129+
130+ if isnothing (jcosts) || isempty (jcosts)
131+ minimize! (opti, 0 )
132+ return
133+ end
134+ stidxmap = Dict ([v => i for (i, v) in enumerate (sts)])
135+ pidxmap = Dict ([v => i for (i, v) in enumerate (ps)])
64136
137+ for i in 1 : length (jcosts)
138+ vars = vars (jcosts[i])
139+ for st in vars
140+ x = operation (st)
141+ t = only (arguments (st))
142+ t isa Union{Num, MTK. Symbolic} && continue
143+ idx = stidxmap[x (iv)]
144+ jcosts[i] = Symbolics. substitute (jcosts[i], Dict (x (t) => U (t)[idx]))
145+ end
146+ end
147+ jcosts = substitute_casadi_vars (model:: CasADiModel , sys, pmap, jcosts; auxmap)
148+ jcosts = map (
149+ c -> Symbolics. substitute (c, MTK.∫ () => Symbolics. Integral (iv in tspan)), jcosts)
150+
151+ dt = U. t[2 ] - U. t[1 ]
152+ intmap = Dict ()
153+ for int in MTK. collect_applied_operators (jcosts, Symbolics. Integral)
154+ op = MTK. operation (int)
155+ arg = only (arguments (MTK. value (int)))
156+ lo, hi = (op. domain. domain. left, op. domain. domain. right)
157+ (lo, hi) != = tspan && error (" Non-whole interval bounds for integrals are not currently supported." )
158+ intmap[int] = dt * tₛ * sum (arg)
159+ end
160+ jcosts = map (c -> Symbolics. substitute (c, intmap), jcosts)
161+ minimize! (opti, consolidate (jcosts))
162+ end
163+
164+ function substitute_casadi_vars (model:: CasADiModel , sys, pmap, exprs; auxmap = Dict ())
165+ @unpack opti, U, V = model
166+ iv = MTK. get_iv (sys)
167+ sts = unknowns (sys)
168+ cts = MTK. unbound_inputs (sys)
169+
170+ x_ops = [MTK. operation (MTK. unwrap (st)) for st in sts]
171+ c_ops = [MTK. operation (MTK. unwrap (ct)) for ct in cts]
172+
173+ exprs = map (c -> Symbolics. fixpoint_sub (c, auxmap), exprs)
174+ exprs = map (c -> Symbolics. fixpoint_sub (c, Dict (pmap)), exprs)
175+
176+ # for variables like x(t)
177+ whole_interval_map = Dict ([[v => U. u[i, :] for (i, v) in enumerate (sts)];
178+ [v => V. u[i, :] for (i, v) in enumerate (cts)]])
179+ exprs = map (c -> Symbolics. fixpoint_sub (c, whole_interval_map), exprs)
180+ exprs
65181end
66182
67183function add_solve_constraints! (prob, tableau; is_free_t)
68- A = tableau. A
69- α = tableau. α
70- c = tableau. c
71- model = prob. model
72- f = prob. f
73- p = prob. p
184+ @unpack A, α, c = tableau
185+ @unpack model, f, p = prob
186+ @unpack opti, U, V, tₛ = model
74187
75- opti = model. opti
76- t = model[:t ]
77- tsteps = supports (t)
78- tmax = tsteps[end ]
79- pop! (tsteps)
80- tₛ = is_free_t ? model[:tf ] : 1
188+ tsteps = U. t
81189 dt = tsteps[2 ] - tsteps[1 ]
82190
83- U = model. U
84- V = model. V
85191 nᵤ = length (U)
86192 nᵥ = length (V)
87193
88194 if is_explicit (tableau)
89195 K = Any[]
90- for τ in tsteps
196+ for k in 1 : length ( tsteps) - 1
91197 for (i, h) in enumerate (c)
92198 ΔU = sum ([A[i, j] * K[j] for j in 1 : (i - 1 )], init = zeros (nᵤ))
93- Uₙ = [U[i](τ) + ΔU[i] * dt for i in 1 : nᵤ]
94- Vₙ = [V[i](τ) for i in 1 : nᵥ ]
199+ Uₙ = U . u[:, k] + ΔU* dt
200+ Vₙ = V . u[:, k ]
95201 Kₙ = tₛ * f (Uₙ, Vₙ, p, τ + h * dt) # scale the time
96202 push! (K, Kₙ)
97203 end
98204 ΔU = dt * sum ([α[i] * K[i] for i in 1 : length (α)])
99- subject_to! (model, U[n](τ) + ΔU[n]== U[n](τ + dt))
100- empty! (K)
205+ subject_to! (opti, U. u[:, k] + ΔU == U. u[:, k+ 1 ])
101206 end
102207 else
208+ ΔU_tot = dt * (K' * α)
209+ for k in 1 : length (tsteps)- 1
210+ Kᵢ = variable! (opti, length (α), nᵤ)
211+ ΔUs = A * Kᵢ # the stepsize at each stage of the implicit method
212+ for (i, h) in enumerate (c)
213+ ΔU = @view ΔUs[i, :]
214+ Uₙ = U. u[:,k] + ΔU
215+ Vₙ = V. u[:,k]
216+ subject_to! (opti, K[i,:] == tₛ * f (Uₙ, Vₙ, p, τ + h* dt))
217+ end
218+ ΔU_tot = dt* (Kᵢ' * α)
219+ subject_to! (opti, U. u[:, k] + ΔU_tot == U. u[:,k+ 1 ])
220+ end
103221 end
104222end
105223
106- function DiffEqBase. solve (prob:: CasADiDynamicOptProblem , solver:: Union{String, Symbol} , ode_solver:: Union{String, Symbol} ; silent = false )
224+ is_explicit (tableau) = tableau isa DiffEqDevTools. ExplicitRKTableau
225+
226+ """
227+ solve(prob::CasADiDynamicOptProblem, casadi_solver, ode_solver; plugin_options, solver_options)
228+
229+ `plugin_options` and `solver_options` get propagated to the Opti object in CasADi.
230+ """
231+ function DiffEqBase. solve (prob:: CasADiDynamicOptProblem , solver:: Union{String, Symbol} , ode_solver:: Union{String, Symbol} ; plugin_options:: Dict = Dict (), solver_options:: Dict = Dict (), silent = false )
107232 model = prob. model
108233 opti = model. opti
109234
110- solver! (opti, solver)
111- sol = solve (opti)
112- DynamicOptSolution (model, sol, input_sol)
235+ solver! (opti, solver, plugin_options, solver_options)
236+ add_casadi_solve_constraints! (prob, tableau)
237+ solver! (cmodel, " $solver " , plugin_options, solver_options)
238+
239+ failed = false
240+ try
241+ sol = solve (opti)
242+ value_getter = x -> CasADi. value (sol, x)
243+ catch ErrorException
244+ value_getter = x -> CasADi. debug_value (opti, x)
245+ failed = true
246+ continue
247+ end
248+
249+ ts = value_getter (tₛ) * U. t
250+ U_vals = value_getter (U)
251+ U_vals = [[U_vals[i][j] for i in 1 : length (U_vals)] for j in 1 : length (ts)]
252+ sol = DiffEqBase. build_solution (prob, ode_solver, ts, U_vals)
253+
254+ input_sol = nothing
255+ if ! isempty (V)
256+ V_vals = value_getter (V)
257+ V_vals = [[V_vals[i][j] for i in 1 : length (V_vals)] for j in 1 : length (ts)]
258+ input_sol = DiffEqBase. build_solution (prob, ode_solver, ts, V_vals)
259+ end
260+
261+ if failed
262+ sol = SciMLBase. solution_new_retcode (sol, SciMLBase. ReturnCode. ConvergenceFailure)
263+ ! isnothing (input_sol) && (input_sol = SciMLBase. solution_new_retcode (
264+ input_sol, SciMLBase. ReturnCode. ConvergenceFailure))
265+ end
266+
267+ DynamicOptSolution (cmodel, sol, input_sol)
113268end
114269
115270end
0 commit comments