@@ -56,9 +56,9 @@ function (M::MXLinearInterpolation)(τ)
5656end
5757
5858"""
59- CasADiDynamicOptProblem(sys::ODESystem , u0, tspan, p; dt, steps)
59+ CasADiDynamicOptProblem(sys::System , u0, tspan, p; dt, steps)
6060
61- Convert an ODESystem representing an optimal control system into a CasADi model
61+ Convert an System representing an optimal control system into a CasADi model
6262for solving using optimization. Must provide either `dt`, the timestep between collocation
6363points (which, along with the timespan, determines the number of points), or directly
6464provide the number of points as `steps`.
@@ -68,10 +68,10 @@ The optimization variables:
6868- a vector-of-vectors V representing the controls as an interpolation array
6969
7070The constraints are:
71- - The set of user constraints passed to the ODESystem via `constraints`
71+ - The set of user constraints passed to the System via `constraints`
7272- The solver constraints that encode the time-stepping used by the solver
7373"""
74- function MTK. CasADiDynamicOptProblem (sys:: ODESystem , u0map, tspan, pmap;
74+ function MTK. CasADiDynamicOptProblem (sys:: System , u0map, tspan, pmap;
7575 dt = nothing ,
7676 steps = nothing ,
7777 guesses = Dict (), kwargs... )
@@ -80,7 +80,8 @@ function MTK.CasADiDynamicOptProblem(sys::ODESystem, u0map, tspan, pmap;
8080 f, u0, p = MTK. process_SciMLProblem (ODEInputFunction, sys, _u0map, pmap;
8181 t = tspan != = nothing ? tspan[1 ] : tspan, output_type = MX, kwargs... )
8282
83- pmap = Dict {Any, Any} (pmap)
83+ pmap = MTK. recursive_unwrap (MTK. AnyDict (pmap))
84+ MTK. evaluate_varmap! (pmap, keys (pmap))
8485 steps, is_free_t = MTK. process_tspan (tspan, dt, steps)
8586 model = init_model (sys, tspan, steps, u0map, pmap, u0; is_free_t)
8687
@@ -143,15 +144,15 @@ function set_casadi_bounds!(model, sys, pmap)
143144 for (i, u) in enumerate (unknowns (sys))
144145 if MTK. hasbounds (u)
145146 lo, hi = MTK. getbounds (u)
146- subject_to! (opti, Symbolics. fixpoint_sub (lo, pmap) <= U. u[i, :])
147- subject_to! (opti, U. u[i, :] <= Symbolics. fixpoint_sub (hi, pmap))
147+ subject_to! (opti, Symbolics. fast_substitute (lo, pmap) <= U. u[i, :])
148+ subject_to! (opti, U. u[i, :] <= Symbolics. fast_substitute (hi, pmap))
148149 end
149150 end
150151 for (i, v) in enumerate (MTK. unbound_inputs (sys))
151152 if MTK. hasbounds (v)
152153 lo, hi = MTK. getbounds (v)
153- subject_to! (opti, Symbolics. fixpoint_sub (lo, pmap) <= V. u[i, :])
154- subject_to! (opti, V. u[i, :] <= Symbolics. fixpoint_sub (hi, pmap))
154+ subject_to! (opti, Symbolics. fast_substitute (lo, pmap) <= V. u[i, :])
155+ subject_to! (opti, V. u[i, :] <= Symbolics. fast_substitute (hi, pmap))
155156 end
156157 end
157158end
@@ -167,15 +168,15 @@ function add_user_constraints!(model::CasADiModel, sys, tspan, pmap; is_free_t)
167168 @unpack opti, U, V, tₛ = model
168169
169170 iv = MTK. get_iv (sys)
170- conssys = MTK. get_constraintsystem (sys)
171- jconstraints = isnothing (conssys) ? nothing : MTK. get_constraints (conssys)
171+ jconstraints = MTK. get_constraints (sys)
172172 (isnothing (jconstraints) || isempty (jconstraints)) && return nothing
173173
174174 stidxmap = Dict ([v => i for (i, v) in enumerate (unknowns (sys))])
175175 ctidxmap = Dict ([v => i for (i, v) in enumerate (MTK. unbound_inputs (sys))])
176- cons_unknowns = map (MTK. default_toterm, unknowns (conssys))
176+ cons_dvs, cons_ps = MTK. process_constraint_system (
177+ jconstraints, Set (unknowns (sys)), parameters (sys), iv; validate = false )
177178
178- auxmap = Dict ([u => MTK. default_toterm (MTK. value (u)) for u in unknowns (conssys) ])
179+ auxmap = Dict ([u => MTK. default_toterm (MTK. value (u)) for u in cons_dvs ])
179180 jconstraints = substitute_casadi_vars (model, sys, pmap, jconstraints; is_free_t, auxmap)
180181 # Manually substitute fixed-t variables
181182 for (i, cons) in enumerate (jconstraints)
207208
208209function add_cost_function! (model:: CasADiModel , sys, tspan, pmap; is_free_t)
209210 @unpack opti, U, V, tₛ = model
210- jcosts = copy (MTK. get_costs (sys))
211- consolidate = MTK. get_consolidate (sys)
212- if isnothing (jcosts) || isempty (jcosts)
211+ jcosts = cost (sys)
212+ if Symbolics. _iszero (jcosts)
213213 minimize! (opti, MX (0 ))
214214 return
215215 end
@@ -218,24 +218,22 @@ function add_cost_function!(model::CasADiModel, sys, tspan, pmap; is_free_t)
218218 stidxmap = Dict ([v => i for (i, v) in enumerate (unknowns (sys))])
219219 ctidxmap = Dict ([v => i for (i, v) in enumerate (MTK. unbound_inputs (sys))])
220220
221- jcosts = substitute_casadi_vars (model, sys, pmap, jcosts; is_free_t)
221+ jcosts = substitute_casadi_vars (model, sys, pmap, [ jcosts] ; is_free_t)[ 1 ]
222222 # Substitute fixed-time variables.
223- for i in 1 : length (jcosts)
224- costvars = MTK. vars (jcosts[i])
225- for st in costvars
226- MTK. iscall (st) || continue
227- x = operation (st)
228- t = only (arguments (st))
229- MTK. symbolic_type (t) === MTK. NotSymbolic () || continue
230- if haskey (stidxmap, x (iv))
231- idx = stidxmap[x (iv)]
232- cv = U
233- else
234- idx = ctidxmap[x (iv)]
235- cv = V
236- end
237- jcosts[i] = Symbolics. substitute (jcosts[i], Dict (x (t) => cv (t)[idx]))
223+ costvars = MTK. vars (jcosts)
224+ for st in costvars
225+ MTK. iscall (st) || continue
226+ x = operation (st)
227+ t = only (arguments (st))
228+ MTK. symbolic_type (t) === MTK. NotSymbolic () || continue
229+ if haskey (stidxmap, x (iv))
230+ idx = stidxmap[x (iv)]
231+ cv = U
232+ else
233+ idx = ctidxmap[x (iv)]
234+ cv = V
238235 end
236+ jcosts = Symbolics. substitute (jcosts, Dict (x (t) => cv (t)[idx]))
239237 end
240238
241239 dt = U. t[2 ] - U. t[1 ]
@@ -249,9 +247,9 @@ function add_cost_function!(model::CasADiModel, sys, tspan, pmap; is_free_t)
249247 # Approximate integral as sum.
250248 intmap[int] = dt * tₛ * sum (arg)
251249 end
252- jcosts = map (c -> Symbolics. substitute (c , intmap), jcosts )
253- jcosts = MTK. value . (jcosts)
254- minimize! (opti, MX (MTK . value ( consolidate ( jcosts)) ))
250+ jcosts = Symbolics. substitute (jcosts , intmap)
251+ jcosts = MTK. value (jcosts)
252+ minimize! (opti, MX (jcosts))
255253end
256254
257255function substitute_casadi_vars (
@@ -264,20 +262,20 @@ function substitute_casadi_vars(
264262 x_ops = [MTK. operation (MTK. unwrap (st)) for st in sts]
265263 c_ops = [MTK. operation (MTK. unwrap (ct)) for ct in cts]
266264
267- exprs = map (c -> Symbolics. fixpoint_sub (c, auxmap), exprs)
268- exprs = map (c -> Symbolics. fixpoint_sub (c, Dict (pmap)), exprs)
265+ exprs = map (c -> Symbolics. fast_substitute (c, auxmap), exprs)
266+ exprs = map (c -> Symbolics. fast_substitute (c, Dict (pmap)), exprs)
269267 # tf means different things in different contexts; a [tf] in a cost function
270268 # should be tₛ, while a x(tf) should translate to x[1]
271269 if is_free_t
272270 free_t_map = Dict ([[x (tₛ) => U. u[i, end ] for (i, x) in enumerate (x_ops)];
273271 [c (tₛ) => V. u[i, end ] for (i, c) in enumerate (c_ops)]])
274- exprs = map (c -> Symbolics. fixpoint_sub (c, free_t_map), exprs)
272+ exprs = map (c -> Symbolics. fast_substitute (c, free_t_map), exprs)
275273 end
276274
277275 # for variables like x(t)
278276 whole_interval_map = Dict ([[v => U. u[i, :] for (i, v) in enumerate (sts)];
279277 [v => V. u[i, :] for (i, v) in enumerate (cts)]])
280- exprs = map (c -> Symbolics. fixpoint_sub (c, whole_interval_map), exprs)
278+ exprs = map (c -> Symbolics. fast_substitute (c, whole_interval_map), exprs)
281279 exprs
282280end
283281
0 commit comments