@@ -18,12 +18,17 @@ struct MXLinearInterpolation
1818    dt:: Float64 
1919end 
2020
21- struct  CasADiModel
22-     opti :: Opti 
21+ mutable  struct
22+     model :: Opti 
2323    U:: MXLinearInterpolation 
2424    V:: MXLinearInterpolation 
2525    tₛ:: MX 
2626    is_free_final:: Bool 
27+     solver_opti:: Union{Nothing, Opti} 
28+ 
29+     function  CasADiModel (opti, U, V, tₛ, is_free_final, solver_opti =  nothing )
30+         new (opti, U, V, tₛ, is_free_final, solver_opti)
31+     end 
2732end 
2833
2934struct  CasADiDynamicOptProblem{uType, tType, isinplace, P, F, K} < :
@@ -74,24 +79,27 @@ function MTK.CasADiDynamicOptProblem(sys::ODESystem, u0map, tspan, pmap;
7479        dt =  nothing ,
7580        steps =  nothing ,
7681        guesses =  Dict (), kwargs... )
77-     process_DynamicOptProblem (CasADiDynamicOptProblem, CasADiModel, sys, u0map, tspan, pmap; dt, steps, guesses, kwargs... )
82+     MTK . process_DynamicOptProblem (CasADiDynamicOptProblem, CasADiModel, sys, u0map, tspan, pmap; dt, steps, guesses, kwargs... )
7883end 
7984
80- MTK. generate_internal_model (:: Type{CasADiModel} ) =  CasADi. opti ()
85+ MTK. generate_internal_model (:: Type{CasADiModel} ) =  CasADi. Opti ()
86+ MTK. generate_time_variable! (opti:: Opti , args... ) =  nothing 
8187
82- function  MTK. generate_state_variable (model:: Opti , u0, ns, nt, tsteps)
88+ function  MTK. generate_state_variable! (model:: Opti , u0, ns, tsteps)
89+     nt =  length (tsteps)
8390    U =  CasADi. variable! (model, ns, nt)
84-     set_initial! (opti , U, DM (repeat (u0, 1 , steps )))
91+     set_initial! (model , U, DM (repeat (u0, 1 , nt )))
8592    MXLinearInterpolation (U, tsteps, tsteps[2 ] -  tsteps[1 ])
8693end 
8794
88- function  MTK. generate_input_variable (model:: Opti , c0, nc, nt, tsteps)
95+ function  MTK. generate_input_variable! (model:: Opti , c0, nc, tsteps)
96+     nt =  length (tsteps)
8997    V =  CasADi. variable! (model, nc, nt)
90-     ! isempty (c0) &&  set_initial! (opti , V, DM (repeat (c0, 1 , steps )))
98+     ! isempty (c0) &&  set_initial! (model , V, DM (repeat (c0, 1 , nt )))
9199    MXLinearInterpolation (V, tsteps, tsteps[2 ] -  tsteps[1 ])
92100end 
93101
94- function  MTK. generate_timescale (model:: Opti , guess, is_free_t)
102+ function  MTK. generate_timescale!  (model:: Opti , guess, is_free_t)
95103    if  is_free_t
96104        tₛ =  variable! (model)
97105        set_initial! (model, tₛ, guess)
@@ -102,78 +110,73 @@ function MTK.generate_timescale(model::Opti, guess, is_free_t)
102110    end 
103111end 
104112
105- function  MTK. add_constraint! (model:: CasADiModel , expr)
106-     @unpack  opti =  model
107-     if  cons isa  Equation
108-         subject_to! (opti, expr. lhs -  expr. rhs ==  0 )
109-     elseif  cons. relational_op ===  Symbolics. geq
110-         subject_to! (opti, expr. lhs -  expr. rhs ≥  0 )
113+ function  MTK. add_constraint! (m:: CasADiModel , expr)
114+     if  expr isa  Equation
115+         subject_to! (m. model, expr. lhs -  expr. rhs ==  0 )
116+     elseif  expr. relational_op ===  Symbolics. geq
117+         subject_to! (m. model, expr. lhs -  expr. rhs ≥  0 )
111118    else 
112-         subject_to! (opti , expr. lhs -  expr. rhs ≤  0 )
119+         subject_to! (m . model , expr. lhs -  expr. rhs ≤  0 )
113120    end 
114121end 
115- MTK. set_objective! (model :: CasADiModel , expr) =  minimize! (model . opti , MX (expr))
122+ MTK. set_objective! (m :: CasADiModel , expr) =  minimize! (m . model , MX (expr))
116123
117- function  MTK. set_variable_bounds! (model , sys, pmap, tf)
118-     @unpack  opti , U, V =  model 
124+ function  MTK. set_variable_bounds! (m :: CasADiModel , sys, pmap, tf)
125+     @unpack  model , U, tₛ,  V =  m 
119126    for  (i, u) in  enumerate (unknowns (sys))
120127        if  MTK. hasbounds (u)
121128            lo, hi =  MTK. getbounds (u)
122-             subject_to! (opti , Symbolics. fixpoint_sub (lo, pmap) <=  U. u[i, :])
123-             subject_to! (opti , U. u[i, :] <=  Symbolics. fixpoint_sub (hi, pmap))
129+             subject_to! (model , Symbolics. fixpoint_sub (lo, pmap) <=  U. u[i, :])
130+             subject_to! (model , U. u[i, :] <=  Symbolics. fixpoint_sub (hi, pmap))
124131        end 
125132    end 
126133    for  (i, v) in  enumerate (MTK. unbound_inputs (sys))
127134        if  MTK. hasbounds (v)
128135            lo, hi =  MTK. getbounds (v)
129-             subject_to! (opti , Symbolics. fixpoint_sub (lo, pmap) <=  V. u[i, :])
130-             subject_to! (opti , V. u[i, :] <=  Symbolics. fixpoint_sub (hi, pmap))
136+             subject_to! (model , Symbolics. fixpoint_sub (lo, pmap) <=  V. u[i, :])
137+             subject_to! (model , V. u[i, :] <=  Symbolics. fixpoint_sub (hi, pmap))
131138        end 
132139    end 
133140    if  MTK. symbolic_type (tf) ===  MTK. ScalarSymbolic () &&  hasbounds (tf)
134141        lo, hi =  MTK. getbounds (tf)
135-         subject_to! (opti, model . tₛ >=  lo)
136-         subject_to! (opti, model . tₛ <=  hi)
142+         subject_to! (model,  tₛ >=  lo)
143+         subject_to! (model,  tₛ <=  hi)
137144    end 
138145end 
139146
140- function  MTK. add_initial_constraints! (model :: CasADiModel , u0, u0_idxs)
141-     @unpack  opti , U =  model 
147+ function  MTK. add_initial_constraints! (m :: CasADiModel , u0, u0_idxs, args ... )
148+     @unpack  model , U =  m 
142149    for  i in  u0_idxs
143-         subject_to! (opti , U. u[i, 1 ] ==  u0[i])
150+         subject_to! (model , U. u[i, 1 ] ==  u0[i])
144151    end 
145152end 
146153
147- function  MTK. substitute_model_vars (
148-         model:: CasADiModel , sys, pmap, exprs; auxmap:: Dict  =  Dict (), is_free_t)
149-     @unpack  opti, U, V, tₛ =  model
154+ function  MTK. substitute_model_vars (m:: CasADiModel , sys, exprs, tspan)
155+     @unpack  model, U, V, tₛ =  m
150156    iv =  MTK. get_iv (sys)
151157    sts =  unknowns (sys)
152158    cts =  MTK. unbound_inputs (sys)
153- 
154159    x_ops =  [MTK. operation (MTK. unwrap (st)) for  st in  sts]
155160    c_ops =  [MTK. operation (MTK. unwrap (ct)) for  ct in  cts]
156- 
157-     exprs =  map (c ->  Symbolics. fast_substitute (c, auxmap), exprs)
158-     exprs =  map (c ->  Symbolics. fast_substitute (c, Dict (pmap)), exprs)
159-     #  tf means different things in different contexts; a [tf] in a cost function
160-     #  should be tₛ, while a x(tf) should translate to x[1]
161-     if  is_free_t
162-         free_t_map =  Dict ([[x (tₛ) =>  U. u[i, end ] for  (i, x) in  enumerate (x_ops)];
163-                            [c (tₛ) =>  V. u[i, end ] for  (i, c) in  enumerate (c_ops)]])
161+     (ti, tf) =  tspan
162+     if  MTK. is_free_final (m)
163+         _tf =  tₛ +  ti
164+         exprs =  map (c ->  Symbolics. fast_substitute (c, Dict (tf =>  _tf)), exprs)
165+         free_t_map =  Dict ([[x (_tf) =>  U. u[i, end ] for  (i, x) in  enumerate (x_ops)];
166+                            [c (_tf) =>  V. u[i, end ] for  (i, c) in  enumerate (c_ops)]])
164167        exprs =  map (c ->  Symbolics. fast_substitute (c, free_t_map), exprs)
165168    end 
166169
167-     exprs =  substitute_fixed_t_vars (exprs)
168- 
169-     #  for variables like x(t)
170+     exprs =  substitute_fixed_t_vars (m, sys, exprs)
170171    whole_interval_map =  Dict ([[v =>  U. u[i, :] for  (i, v) in  enumerate (sts)];
171172                               [v =>  V. u[i, :] for  (i, v) in  enumerate (cts)]])
172173    exprs =  map (c ->  Symbolics. fast_substitute (c, whole_interval_map), exprs)
173-     exprs
174174end 
175175
176- function  substitute_fixed_t_vars (exprs)
176+ function  substitute_fixed_t_vars (model:: CasADiModel , sys, exprs)
177+     stidxmap =  Dict ([v =>  i for  (i, v) in  enumerate (unknowns (sys))])
178+     ctidxmap =  Dict ([v =>  i for  (i, v) in  enumerate (MTK. unbound_inputs (sys))])
179+     iv =  MTK. get_iv (sys)
177180    for  i in  1 : length (exprs)
178181        subvars =  MTK. vars (exprs[i])
179182        for  st in  subvars 
@@ -183,26 +186,27 @@ function substitute_fixed_t_vars(exprs)
183186            MTK. symbolic_type (t) ===  MTK. NotSymbolic () ||  continue 
184187            if  haskey (stidxmap, x (iv))
185188                idx =  stidxmap[x (iv)]
186-                 cv =  U
189+                 cv =  model . U
187190            else 
188191                idx =  ctidxmap[x (iv)]
189-                 cv =  V
192+                 cv =  model . V
190193            end 
191194            exprs[i] =  Symbolics. fast_substitute (exprs[i], Dict (x (t) =>  cv (t)[idx]))
192195        end 
193196    end 
197+     exprs
194198end 
195199
196- MTK. substitute_differentials (model:: CasADiModel , exprs, args ... ) =  exprs
200+ MTK. substitute_differentials (model:: CasADiModel , sys, eqs ) =  exprs
197201
198- function  MTK. substitute_integral (model :: CasADiModel , exprs)
199-     @unpack  U, opti  =  model 
202+ function  MTK. substitute_integral (m :: CasADiModel , exprs, tspan )
203+     @unpack  U, model, tₛ  =  m 
200204    dt =  U. t[2 ] -  U. t[1 ]
201205    intmap =  Dict ()
202206    for  int in  MTK. collect_applied_operators (exprs, Symbolics. Integral)
203207        op =  MTK. operation (int)
204208        arg =  only (arguments (MTK. value (int)))
205-         lo, hi =  ( op. domain. domain. left, op. domain. domain. right)
209+         lo, hi =  MTK . value .(( op. domain. domain. left, op. domain. domain. right) )
206210        ! isequal ((lo, hi), tspan) && 
207211            error (" Non-whole interval bounds for integrals are not currently supported for CasADiDynamicOptProblem." 
208212        #  Approximate integral as sum.
@@ -212,11 +216,11 @@ function MTK.substitute_integral(model::CasADiModel, exprs)
212216    exprs =  MTK. value .(exprs)
213217end 
214218
215- function  add_solve_constraints! (prob, tableau)
219+ function  add_solve_constraints! (prob:: CasADiDynamicOptProblem , tableau)
216220    @unpack  A, α, c =  tableau
217221    @unpack  model, f, p =  prob
218-     @unpack  opti , U, V, tₛ =  model
219-     solver_opti =  copy (opti )
222+     @unpack  model , U, V, tₛ =  model
223+     solver_opti =  copy (model )
220224
221225    tsteps =  U. t
222226    dt =  tsteps[2 ] -  tsteps[1 ]
@@ -257,29 +261,56 @@ function add_solve_constraints!(prob, tableau)
257261    solver_opti
258262end 
259263
260- function  MTK. prepare_solver ()
261-     opti =  add_solve_constraints (prob, tableau)
262-     solver! (opti, " $solver " 
264+ """ 
265+ CasADi Collocation solver. 
266+ - solver: an optimization solver such as Ipopt. Should be given as a string or symbol in all lowercase, e.g. "ipopt" 
267+ - tableau: An ODE RK tableau. Load a tableau by calling a function like `constructRK4` and may be found at https://docs.sciml.ai/DiffEqDevDocs/stable/internals/tableaus/. If this argument is not passed in, the solver will default to Radau second order. 
268+ """ 
269+ struct  CasADiCollocation <:  AbstractCollocation 
270+     solver:: Union{String, Symbol} 
271+     tableau:: DiffEqBase.ODERKTableau 
272+ end 
273+ MTK. CasADiCollocation (solver, tableau =  MTK. constructDefault ()) =  CasADiCollocation (solver, tableau)
274+ 
275+ function  MTK. prepare_and_optimize! (prob:: CasADiDynamicOptProblem , solver:: CasADiCollocation ; verbose =  false , solver_options =  Dict (), plugin_options =  Dict (), kwargs... )
276+     solver_opti =  add_solve_constraints! (prob, solver. tableau)
277+     verbose ||  (solver_options[" print_level" =  0 )
278+     solver! (solver_opti, " $(solver. solver) " 
279+     try 
280+         CasADi. solve! (solver_opti)
281+     catch  ErrorException
282+     end 
283+     prob. model. solver_opti =  solver_opti
263284end 
264- function  MTK. get_U_values ()
265-     U_vals =  value_getter (U. u)
285+ 
286+ function  MTK. get_U_values (model:: CasADiModel )
287+     value_getter =  MTK. successful_solve (model) ?  CasADi. debug_value :  CasADi. value
288+     (nu, nt) =  size (model. U. u)
289+     U_vals =  value_getter (model. solver_opti, model. U. u)
266290    size (U_vals, 2 ) ==  1  &&  (U_vals =  U_vals' )
267-     U_vals =  [[U_vals[i, j] for  i in  1 : size (U_vals,  1 ) ] for  j in  1 : length (ts) ]
291+     U_vals =  [[U_vals[i, j] for  i in  1 : nu ] for  j in  1 : nt ]
268292end 
269- function  MTK. get_V_values ()
293+ 
294+ function  MTK. get_V_values (model:: CasADiModel )
295+     value_getter =  MTK. successful_solve (model) ?  CasADi. debug_value :  CasADi. value
296+     (nu, nt) =  size (model. V. u)
297+     if  nu* nt !=  0 
298+         V_vals =  value_getter (model. solver_opti, model. V. u)
299+         size (V_vals, 2 ) ==  1  &&  (V_vals =  V_vals' )
300+         V_vals =  [[V_vals[i, j] for  i in  1 : nu] for  j in  1 : nt]
301+     else 
302+         nothing 
303+     end 
270304end 
271- function  MTK. get_t_values ()
272-     ts =  value_getter (tₛ) *  U. t
305+ 
306+ function  MTK. get_t_values (model:: CasADiModel )
307+     value_getter =  MTK. successful_solve (model) ?  CasADi. debug_value :  CasADi. value
308+     ts =  value_getter (model. solver_opti, model. tₛ) .*  model. U. t
273309end 
274310
275- function  MTK. optimize_model! ()
276-     try 
277-         sol =  CasADi. solve! (opti)
278-         value_getter =  x ->  CasADi. value (sol, x)
279-     catch  ErrorException
280-         value_getter =  x ->  CasADi. debug_value (opti, x)
281-         failed =  true 
282-     end 
311+ function  MTK. successful_solve (m:: CasADiModel ) 
312+     isnothing (m. solver_opti) &&  return  false 
313+     retcode =  CasADi. return_status (m. solver_opti)
314+     retcode ==  " Solve_Succeeded" ||  retcode ==  " Solved_To_Acceptable_Level" 
283315end 
284- MTK. successful_solve () =  true 
285316end 
0 commit comments