@@ -18,12 +18,17 @@ struct MXLinearInterpolation
1818 dt:: Float64
1919end
2020
21- struct CasADiModel
22- opti :: Opti
21+ mutable struct CasADiModel
22+ model :: Opti
2323 U:: MXLinearInterpolation
2424 V:: MXLinearInterpolation
2525 tₛ:: MX
2626 is_free_final:: Bool
27+ solver_opti:: Union{Nothing, Opti}
28+
29+ function CasADiModel (opti, U, V, tₛ, is_free_final, solver_opti = nothing )
30+ new (opti, U, V, tₛ, is_free_final, solver_opti)
31+ end
2732end
2833
2934struct CasADiDynamicOptProblem{uType, tType, isinplace, P, F, K} < :
@@ -74,24 +79,27 @@ function MTK.CasADiDynamicOptProblem(sys::System, u0map, tspan, pmap;
7479 dt = nothing ,
7580 steps = nothing ,
7681 guesses = Dict (), kwargs... )
77- process_DynamicOptProblem (CasADiDynamicOptProblem, CasADiModel, sys, u0map, tspan, pmap; dt, steps, guesses, kwargs... )
82+ MTK . process_DynamicOptProblem (CasADiDynamicOptProblem, CasADiModel, sys, u0map, tspan, pmap; dt, steps, guesses, kwargs... )
7883end
7984
80- MTK. generate_internal_model (:: Type{CasADiModel} ) = CasADi. opti ()
85+ MTK. generate_internal_model (:: Type{CasADiModel} ) = CasADi. Opti ()
86+ MTK. generate_time_variable! (opti:: Opti , args... ) = nothing
8187
82- function MTK. generate_state_variable (model:: Opti , u0, ns, nt, tsteps)
88+ function MTK. generate_state_variable! (model:: Opti , u0, ns, tsteps)
89+ nt = length (tsteps)
8390 U = CasADi. variable! (model, ns, nt)
84- set_initial! (opti , U, DM (repeat (u0, 1 , steps )))
91+ set_initial! (model , U, DM (repeat (u0, 1 , nt )))
8592 MXLinearInterpolation (U, tsteps, tsteps[2 ] - tsteps[1 ])
8693end
8794
88- function MTK. generate_input_variable (model:: Opti , c0, nc, nt, tsteps)
95+ function MTK. generate_input_variable! (model:: Opti , c0, nc, tsteps)
96+ nt = length (tsteps)
8997 V = CasADi. variable! (model, nc, nt)
90- ! isempty (c0) && set_initial! (opti , V, DM (repeat (c0, 1 , steps )))
98+ ! isempty (c0) && set_initial! (model , V, DM (repeat (c0, 1 , nt )))
9199 MXLinearInterpolation (V, tsteps, tsteps[2 ] - tsteps[1 ])
92100end
93101
94- function MTK. generate_timescale (model:: Opti , guess, is_free_t)
102+ function MTK. generate_timescale! (model:: Opti , guess, is_free_t)
95103 if is_free_t
96104 tₛ = variable! (model)
97105 set_initial! (model, tₛ, guess)
@@ -102,78 +110,73 @@ function MTK.generate_timescale(model::Opti, guess, is_free_t)
102110 end
103111end
104112
105- function MTK. add_constraint! (model:: CasADiModel , expr)
106- @unpack opti = model
107- if cons isa Equation
108- subject_to! (opti, expr. lhs - expr. rhs == 0 )
109- elseif cons. relational_op === Symbolics. geq
110- subject_to! (opti, expr. lhs - expr. rhs ≥ 0 )
113+ function MTK. add_constraint! (m:: CasADiModel , expr)
114+ if expr isa Equation
115+ subject_to! (m. model, expr. lhs - expr. rhs == 0 )
116+ elseif expr. relational_op === Symbolics. geq
117+ subject_to! (m. model, expr. lhs - expr. rhs ≥ 0 )
111118 else
112- subject_to! (opti , expr. lhs - expr. rhs ≤ 0 )
119+ subject_to! (m . model , expr. lhs - expr. rhs ≤ 0 )
113120 end
114121end
115- MTK. set_objective! (model :: CasADiModel , expr) = minimize! (model . opti , MX (expr))
122+ MTK. set_objective! (m :: CasADiModel , expr) = minimize! (m . model , MX (expr))
116123
117- function MTK. set_variable_bounds! (model , sys, pmap, tf)
118- @unpack opti , U, V = model
124+ function MTK. set_variable_bounds! (m :: CasADiModel , sys, pmap, tf)
125+ @unpack model , U, tₛ, V = m
119126 for (i, u) in enumerate (unknowns (sys))
120127 if MTK. hasbounds (u)
121128 lo, hi = MTK. getbounds (u)
122- subject_to! (opti , Symbolics. fast_substitute (lo, pmap) <= U. u[i, :])
123- subject_to! (opti , U. u[i, :] <= Symbolics. fast_substitute (hi, pmap))
129+ subject_to! (model , Symbolics. fixpoint_sub (lo, pmap) <= U. u[i, :])
130+ subject_to! (model , U. u[i, :] <= Symbolics. fixpoint_sub (hi, pmap))
124131 end
125132 end
126133 for (i, v) in enumerate (MTK. unbound_inputs (sys))
127134 if MTK. hasbounds (v)
128135 lo, hi = MTK. getbounds (v)
129- subject_to! (opti , Symbolics. fast_substitute (lo, pmap) <= V. u[i, :])
130- subject_to! (opti , V. u[i, :] <= Symbolics. fast_substitute (hi, pmap))
136+ subject_to! (model , Symbolics. fixpoint_sub (lo, pmap) <= V. u[i, :])
137+ subject_to! (model , V. u[i, :] <= Symbolics. fixpoint_sub (hi, pmap))
131138 end
132139 end
133140 if MTK. symbolic_type (tf) === MTK. ScalarSymbolic () && hasbounds (tf)
134141 lo, hi = MTK. getbounds (tf)
135- subject_to! (opti, model . tₛ >= lo)
136- subject_to! (opti, model . tₛ <= hi)
142+ subject_to! (model, tₛ >= lo)
143+ subject_to! (model, tₛ <= hi)
137144 end
138145end
139146
140- function MTK. add_initial_constraints! (model :: CasADiModel , u0, u0_idxs)
141- @unpack opti , U = model
147+ function MTK. add_initial_constraints! (m :: CasADiModel , u0, u0_idxs, args ... )
148+ @unpack model , U = m
142149 for i in u0_idxs
143- subject_to! (opti , U. u[i, 1 ] == u0[i])
150+ subject_to! (model , U. u[i, 1 ] == u0[i])
144151 end
145152end
146153
147- function MTK. substitute_model_vars (
148- model:: CasADiModel , sys, pmap, exprs; auxmap:: Dict = Dict (), is_free_t)
149- @unpack opti, U, V, tₛ = model
154+ function MTK. substitute_model_vars (m:: CasADiModel , sys, exprs, tspan)
155+ @unpack model, U, V, tₛ = m
150156 iv = MTK. get_iv (sys)
151157 sts = unknowns (sys)
152158 cts = MTK. unbound_inputs (sys)
153-
154159 x_ops = [MTK. operation (MTK. unwrap (st)) for st in sts]
155160 c_ops = [MTK. operation (MTK. unwrap (ct)) for ct in cts]
156-
157- exprs = map (c -> Symbolics. fast_substitute (c, auxmap), exprs)
158- exprs = map (c -> Symbolics. fast_substitute (c, Dict (pmap)), exprs)
159- # tf means different things in different contexts; a [tf] in a cost function
160- # should be tₛ, while a x(tf) should translate to x[1]
161- if is_free_t
162- free_t_map = Dict ([[x (tₛ) => U. u[i, end ] for (i, x) in enumerate (x_ops)];
163- [c (tₛ) => V. u[i, end ] for (i, c) in enumerate (c_ops)]])
161+ (ti, tf) = tspan
162+ if MTK. is_free_final (m)
163+ _tf = tₛ + ti
164+ exprs = map (c -> Symbolics. fast_substitute (c, Dict (tf => _tf)), exprs)
165+ free_t_map = Dict ([[x (_tf) => U. u[i, end ] for (i, x) in enumerate (x_ops)];
166+ [c (_tf) => V. u[i, end ] for (i, c) in enumerate (c_ops)]])
164167 exprs = map (c -> Symbolics. fast_substitute (c, free_t_map), exprs)
165168 end
166169
167- exprs = substitute_fixed_t_vars (exprs)
168-
169- # for variables like x(t)
170+ exprs = substitute_fixed_t_vars (m, sys, exprs)
170171 whole_interval_map = Dict ([[v => U. u[i, :] for (i, v) in enumerate (sts)];
171172 [v => V. u[i, :] for (i, v) in enumerate (cts)]])
172173 exprs = map (c -> Symbolics. fast_substitute (c, whole_interval_map), exprs)
173- exprs
174174end
175175
176- function substitute_fixed_t_vars (exprs)
176+ function substitute_fixed_t_vars (model:: CasADiModel , sys, exprs)
177+ stidxmap = Dict ([v => i for (i, v) in enumerate (unknowns (sys))])
178+ ctidxmap = Dict ([v => i for (i, v) in enumerate (MTK. unbound_inputs (sys))])
179+ iv = MTK. get_iv (sys)
177180 for i in 1 : length (exprs)
178181 subvars = MTK. vars (exprs[i])
179182 for st in subvars
@@ -183,27 +186,28 @@ function substitute_fixed_t_vars(exprs)
183186 MTK. symbolic_type (t) === MTK. NotSymbolic () || continue
184187 if haskey (stidxmap, x (iv))
185188 idx = stidxmap[x (iv)]
186- cv = U
189+ cv = model . U
187190 else
188191 idx = ctidxmap[x (iv)]
189- cv = V
192+ cv = model . V
190193 end
191194 exprs[i] = Symbolics. fast_substitute (exprs[i], Dict (x (t) => cv (t)[idx]))
192195 end
193196 jcosts = Symbolics. substitute (jcosts, Dict (x (t) => cv (t)[idx]))
194197 end
198+ exprs
195199end
196200
197- MTK. substitute_differentials (model:: CasADiModel , exprs, args ... ) = exprs
201+ MTK. substitute_differentials (model:: CasADiModel , sys, eqs ) = exprs
198202
199- function MTK. substitute_integral (model :: CasADiModel , exprs)
200- @unpack U, opti = model
203+ function MTK. substitute_integral (m :: CasADiModel , exprs, tspan )
204+ @unpack U, model, tₛ = m
201205 dt = U. t[2 ] - U. t[1 ]
202206 intmap = Dict ()
203207 for int in MTK. collect_applied_operators (exprs, Symbolics. Integral)
204208 op = MTK. operation (int)
205209 arg = only (arguments (MTK. value (int)))
206- lo, hi = ( op. domain. domain. left, op. domain. domain. right)
210+ lo, hi = MTK . value .(( op. domain. domain. left, op. domain. domain. right) )
207211 ! isequal ((lo, hi), tspan) &&
208212 error (" Non-whole interval bounds for integrals are not currently supported for CasADiDynamicOptProblem." )
209213 # Approximate integral as sum.
@@ -213,11 +217,11 @@ function MTK.substitute_integral(model::CasADiModel, exprs)
213217 exprs = MTK. value .(exprs)
214218end
215219
216- function add_solve_constraints! (prob, tableau)
220+ function add_solve_constraints! (prob:: CasADiDynamicOptProblem , tableau)
217221 @unpack A, α, c = tableau
218222 @unpack model, f, p = prob
219- @unpack opti , U, V, tₛ = model
220- solver_opti = copy (opti )
223+ @unpack model , U, V, tₛ = model
224+ solver_opti = copy (model )
221225
222226 tsteps = U. t
223227 dt = tsteps[2 ] - tsteps[1 ]
@@ -258,29 +262,56 @@ function add_solve_constraints!(prob, tableau)
258262 solver_opti
259263end
260264
261- function MTK. prepare_solver ()
262- opti = add_solve_constraints (prob, tableau)
263- solver! (opti, " $solver " , plugin_options, solver_options)
265+ """
266+ CasADi Collocation solver.
267+ - solver: an optimization solver such as Ipopt. Should be given as a string or symbol in all lowercase, e.g. "ipopt"
268+ - tableau: An ODE RK tableau. Load a tableau by calling a function like `constructRK4` and may be found at https://docs.sciml.ai/DiffEqDevDocs/stable/internals/tableaus/. If this argument is not passed in, the solver will default to Radau second order.
269+ """
270+ struct CasADiCollocation <: AbstractCollocation
271+ solver:: Union{String, Symbol}
272+ tableau:: DiffEqBase.ODERKTableau
273+ end
274+ MTK. CasADiCollocation (solver, tableau = MTK. constructDefault ()) = CasADiCollocation (solver, tableau)
275+
276+ function MTK. prepare_and_optimize! (prob:: CasADiDynamicOptProblem , solver:: CasADiCollocation ; verbose = false , solver_options = Dict (), plugin_options = Dict (), kwargs... )
277+ solver_opti = add_solve_constraints! (prob, solver. tableau)
278+ verbose || (solver_options[" print_level" ] = 0 )
279+ solver! (solver_opti, " $(solver. solver) " , plugin_options, solver_options)
280+ try
281+ CasADi. solve! (solver_opti)
282+ catch ErrorException
283+ end
284+ prob. model. solver_opti = solver_opti
264285end
265- function MTK. get_U_values ()
266- U_vals = value_getter (U. u)
286+
287+ function MTK. get_U_values (model:: CasADiModel )
288+ value_getter = MTK. successful_solve (model) ? CasADi. debug_value : CasADi. value
289+ (nu, nt) = size (model. U. u)
290+ U_vals = value_getter (model. solver_opti, model. U. u)
267291 size (U_vals, 2 ) == 1 && (U_vals = U_vals' )
268- U_vals = [[U_vals[i, j] for i in 1 : size (U_vals, 1 ) ] for j in 1 : length (ts) ]
292+ U_vals = [[U_vals[i, j] for i in 1 : nu ] for j in 1 : nt ]
269293end
270- function MTK. get_V_values ()
294+
295+ function MTK. get_V_values (model:: CasADiModel )
296+ value_getter = MTK. successful_solve (model) ? CasADi. debug_value : CasADi. value
297+ (nu, nt) = size (model. V. u)
298+ if nu* nt != 0
299+ V_vals = value_getter (model. solver_opti, model. V. u)
300+ size (V_vals, 2 ) == 1 && (V_vals = V_vals' )
301+ V_vals = [[V_vals[i, j] for i in 1 : nu] for j in 1 : nt]
302+ else
303+ nothing
304+ end
271305end
272- function MTK. get_t_values ()
273- ts = value_getter (tₛ) * U. t
306+
307+ function MTK. get_t_values (model:: CasADiModel )
308+ value_getter = MTK. successful_solve (model) ? CasADi. debug_value : CasADi. value
309+ ts = value_getter (model. solver_opti, model. tₛ) .* model. U. t
274310end
275311
276- function MTK. optimize_model! ()
277- try
278- sol = CasADi. solve! (opti)
279- value_getter = x -> CasADi. value (sol, x)
280- catch ErrorException
281- value_getter = x -> CasADi. debug_value (opti, x)
282- failed = true
283- end
312+ function MTK. successful_solve (m:: CasADiModel )
313+ isnothing (m. solver_opti) && return false
314+ retcode = CasADi. return_status (m. solver_opti)
315+ retcode == " Solve_Succeeded" || retcode == " Solved_To_Acceptable_Level"
284316end
285- MTK. successful_solve () = true
286317end
0 commit comments