@@ -56,9 +56,9 @@ function (M::MXLinearInterpolation)(τ)
5656end 
5757
5858""" 
59-     CasADiDynamicOptProblem(sys::ODESystem , u0, tspan, p; dt, steps) 
59+     CasADiDynamicOptProblem(sys::System , u0, tspan, p; dt, steps) 
6060
61- Convert an ODESystem  representing an optimal control system into a CasADi model 
61+ Convert an System  representing an optimal control system into a CasADi model 
6262for solving using optimization. Must provide either `dt`, the timestep between collocation  
6363points (which, along with the timespan, determines the number of points), or directly  
6464provide the number of points as `steps`. 
@@ -68,10 +68,10 @@ The optimization variables:
6868- a vector-of-vectors V representing the controls as an interpolation array 
6969
7070The constraints are: 
71- - The set of user constraints passed to the ODESystem  via `constraints` 
71+ - The set of user constraints passed to the System  via `constraints` 
7272- The solver constraints that encode the time-stepping used by the solver 
7373""" 
74- function  MTK. CasADiDynamicOptProblem (sys:: ODESystem , u0map, tspan, pmap;
74+ function  MTK. CasADiDynamicOptProblem (sys:: System , u0map, tspan, pmap;
7575        dt =  nothing ,
7676        steps =  nothing ,
7777        guesses =  Dict (), kwargs... )
@@ -80,7 +80,8 @@ function MTK.CasADiDynamicOptProblem(sys::ODESystem, u0map, tspan, pmap;
8080    f, u0, p =  MTK. process_SciMLProblem (ODEInputFunction, sys, _u0map, pmap;
8181        t =  tspan != =  nothing  ?  tspan[1 ] :  tspan, output_type =  MX, kwargs... )
8282
83-     pmap =  Dict {Any, Any} (pmap)
83+     pmap =  MTK. recursive_unwrap (MTK. AnyDict (pmap))
84+     MTK. evaluate_varmap! (pmap, keys (pmap))
8485    steps, is_free_t =  MTK. process_tspan (tspan, dt, steps)
8586    model =  init_model (sys, tspan, steps, u0map, pmap, u0; is_free_t)
8687
@@ -143,15 +144,15 @@ function set_casadi_bounds!(model, sys, pmap)
143144    for  (i, u) in  enumerate (unknowns (sys))
144145        if  MTK. hasbounds (u)
145146            lo, hi =  MTK. getbounds (u)
146-             subject_to! (opti, Symbolics. fixpoint_sub (lo, pmap) <=  U. u[i, :])
147-             subject_to! (opti, U. u[i, :] <=  Symbolics. fixpoint_sub (hi, pmap))
147+             subject_to! (opti, Symbolics. fast_substitute (lo, pmap) <=  U. u[i, :])
148+             subject_to! (opti, U. u[i, :] <=  Symbolics. fast_substitute (hi, pmap))
148149        end 
149150    end 
150151    for  (i, v) in  enumerate (MTK. unbound_inputs (sys))
151152        if  MTK. hasbounds (v)
152153            lo, hi =  MTK. getbounds (v)
153-             subject_to! (opti, Symbolics. fixpoint_sub (lo, pmap) <=  V. u[i, :])
154-             subject_to! (opti, V. u[i, :] <=  Symbolics. fixpoint_sub (hi, pmap))
154+             subject_to! (opti, Symbolics. fast_substitute (lo, pmap) <=  V. u[i, :])
155+             subject_to! (opti, V. u[i, :] <=  Symbolics. fast_substitute (hi, pmap))
155156        end 
156157    end 
157158end 
@@ -167,15 +168,15 @@ function add_user_constraints!(model::CasADiModel, sys, tspan, pmap; is_free_t)
167168    @unpack  opti, U, V, tₛ =  model
168169
169170    iv =  MTK. get_iv (sys)
170-     conssys =  MTK. get_constraintsystem (sys)
171-     jconstraints =  isnothing (conssys) ?  nothing  :  MTK. get_constraints (conssys)
171+     jconstraints =  MTK. get_constraints (sys)
172172    (isnothing (jconstraints) ||  isempty (jconstraints)) &&  return  nothing 
173173
174174    stidxmap =  Dict ([v =>  i for  (i, v) in  enumerate (unknowns (sys))])
175175    ctidxmap =  Dict ([v =>  i for  (i, v) in  enumerate (MTK. unbound_inputs (sys))])
176-     cons_unknowns =  map (MTK. default_toterm, unknowns (conssys))
176+     cons_dvs, cons_ps =  MTK. process_constraint_system (
177+         jconstraints, Set (unknowns (sys)), parameters (sys), iv; validate =  false )
177178
178-     auxmap =  Dict ([u =>  MTK. default_toterm (MTK. value (u)) for  u in  unknowns (conssys) ])
179+     auxmap =  Dict ([u =>  MTK. default_toterm (MTK. value (u)) for  u in  cons_dvs ])
179180    jconstraints =  substitute_casadi_vars (model, sys, pmap, jconstraints; is_free_t, auxmap)
180181    #  Manually substitute fixed-t variables
181182    for  (i, cons) in  enumerate (jconstraints)
207208
208209function  add_cost_function! (model:: CasADiModel , sys, tspan, pmap; is_free_t)
209210    @unpack  opti, U, V, tₛ =  model
210-     jcosts =  copy (MTK. get_costs (sys))
211-     consolidate =  MTK. get_consolidate (sys)
212-     if  isnothing (jcosts) ||  isempty (jcosts)
211+     jcosts =  cost (sys)
212+     if  Symbolics. _iszero (jcosts)
213213        minimize! (opti, MX (0 ))
214214        return 
215215    end 
@@ -218,24 +218,22 @@ function add_cost_function!(model::CasADiModel, sys, tspan, pmap; is_free_t)
218218    stidxmap =  Dict ([v =>  i for  (i, v) in  enumerate (unknowns (sys))])
219219    ctidxmap =  Dict ([v =>  i for  (i, v) in  enumerate (MTK. unbound_inputs (sys))])
220220
221-     jcosts =  substitute_casadi_vars (model, sys, pmap, jcosts; is_free_t)
221+     jcosts =  substitute_casadi_vars (model, sys, pmap, [ jcosts] ; is_free_t)[ 1 ] 
222222    #  Substitute fixed-time variables.
223-     for  i in  1 : length (jcosts)
224-         costvars =  MTK. vars (jcosts[i])
225-         for  st in  costvars
226-             MTK. iscall (st) ||  continue 
227-             x =  operation (st)
228-             t =  only (arguments (st))
229-             MTK. symbolic_type (t) ===  MTK. NotSymbolic () ||  continue 
230-             if  haskey (stidxmap, x (iv))
231-                 idx =  stidxmap[x (iv)]
232-                 cv =  U
233-             else 
234-                 idx =  ctidxmap[x (iv)]
235-                 cv =  V
236-             end 
237-             jcosts[i] =  Symbolics. substitute (jcosts[i], Dict (x (t) =>  cv (t)[idx]))
223+     costvars =  MTK. vars (jcosts)
224+     for  st in  costvars
225+         MTK. iscall (st) ||  continue 
226+         x =  operation (st)
227+         t =  only (arguments (st))
228+         MTK. symbolic_type (t) ===  MTK. NotSymbolic () ||  continue 
229+         if  haskey (stidxmap, x (iv))
230+             idx =  stidxmap[x (iv)]
231+             cv =  U
232+         else 
233+             idx =  ctidxmap[x (iv)]
234+             cv =  V
238235        end 
236+         jcosts =  Symbolics. substitute (jcosts, Dict (x (t) =>  cv (t)[idx]))
239237    end 
240238
241239    dt =  U. t[2 ] -  U. t[1 ]
@@ -249,9 +247,9 @@ function add_cost_function!(model::CasADiModel, sys, tspan, pmap; is_free_t)
249247        #  Approximate integral as sum.
250248        intmap[int] =  dt *  tₛ *  sum (arg)
251249    end 
252-     jcosts =  map (c  ->   Symbolics. substitute (c , intmap), jcosts )
253-     jcosts =  MTK. value . (jcosts)
254-     minimize! (opti, MX (MTK . value ( consolidate ( jcosts)) ))
250+     jcosts =  Symbolics. substitute (jcosts , intmap)
251+     jcosts =  MTK. value (jcosts)
252+     minimize! (opti, MX (jcosts))
255253end 
256254
257255function  substitute_casadi_vars (
@@ -264,20 +262,20 @@ function substitute_casadi_vars(
264262    x_ops =  [MTK. operation (MTK. unwrap (st)) for  st in  sts]
265263    c_ops =  [MTK. operation (MTK. unwrap (ct)) for  ct in  cts]
266264
267-     exprs =  map (c ->  Symbolics. fixpoint_sub (c, auxmap), exprs)
268-     exprs =  map (c ->  Symbolics. fixpoint_sub (c, Dict (pmap)), exprs)
265+     exprs =  map (c ->  Symbolics. fast_substitute (c, auxmap), exprs)
266+     exprs =  map (c ->  Symbolics. fast_substitute (c, Dict (pmap)), exprs)
269267    #  tf means different things in different contexts; a [tf] in a cost function
270268    #  should be tₛ, while a x(tf) should translate to x[1]
271269    if  is_free_t
272270        free_t_map =  Dict ([[x (tₛ) =>  U. u[i, end ] for  (i, x) in  enumerate (x_ops)];
273271                           [c (tₛ) =>  V. u[i, end ] for  (i, c) in  enumerate (c_ops)]])
274-         exprs =  map (c ->  Symbolics. fixpoint_sub (c, free_t_map), exprs)
272+         exprs =  map (c ->  Symbolics. fast_substitute (c, free_t_map), exprs)
275273    end 
276274
277275    #  for variables like x(t)
278276    whole_interval_map =  Dict ([[v =>  U. u[i, :] for  (i, v) in  enumerate (sts)];
279277                               [v =>  V. u[i, :] for  (i, v) in  enumerate (cts)]])
280-     exprs =  map (c ->  Symbolics. fixpoint_sub (c, whole_interval_map), exprs)
278+     exprs =  map (c ->  Symbolics. fast_substitute (c, whole_interval_map), exprs)
281279    exprs
282280end 
283281
0 commit comments