@@ -43,6 +43,22 @@ function (M::MXLinearInterpolation)(τ)
4343 M. u[i] + Δ* (M. u[i + 1 ] - M. u[i])
4444end
4545
46+ """
47+ CasADiDynamicOptProblem(sys::ODESystem, u0, tspan, p; dt, steps)
48+
49+ Convert an ODESystem representing an optimal control system into a CasADi model
50+ for solving using optimization. Must provide either `dt`, the timestep between collocation
51+ points (which, along with the timespan, determines the number of points), or directly
52+ provide the number of points as `steps`.
53+
54+ The optimization variables:
55+ - a vector-of-vectors U representing the unknowns as an interpolation array
56+ - a vector-of-vectors V representing the controls as an interpolation array
57+
58+ The constraints are:
59+ - The set of user constraints passed to the ODESystem via `constraints`
60+ - The solver constraints that encode the time-stepping used by the solver
61+ """
4662function MTK. CasADiDynamicOptProblem (sys:: ODESystem , u0map, tspan, pmap;
4763 dt = nothing ,
4864 steps = nothing ,
@@ -240,16 +256,11 @@ end
240256
241257`plugin_options` and `solver_options` get propagated to the Opti object in CasADi.
242258"""
243- function DiffEqBase. solve (prob:: CasADiDynamicOptProblem , solver:: Union{String, Symbol} , ode_solver :: Symbol = :Default ; plugin_options:: Dict = Dict (), solver_options:: Dict = Dict (), silent = false )
259+ function DiffEqBase. solve (prob:: CasADiDynamicOptProblem , solver:: Union{String, Symbol} , tableau_getter = constructDefault ; plugin_options:: Dict = Dict (), solver_options:: Dict = Dict (), silent = false )
244260 model = prob. model
261+ tableau = tableau_getter ()
245262 opti = model. opti
246263
247- if ode_solver == :Default
248- tableau = MTK. constructDefault ()
249- else
250- tableau_getter = Symbol (:construct , ode_solver)
251- tableau = @eval Main. tableau_getter ()
252- end
253264 solver! (opti, solver, plugin_options, solver_options)
254265 add_casadi_solve_constraints! (prob, tableau)
255266 solver! (cmodel, " $solver " , plugin_options, solver_options)
@@ -266,13 +277,13 @@ function DiffEqBase.solve(prob::CasADiDynamicOptProblem, solver::Union{String, S
266277 ts = value_getter (tₛ) * U. t
267278 U_vals = value_getter (U)
268279 U_vals = [[U_vals[i][j] for i in 1 : length (U_vals)] for j in 1 : length (ts)]
269- sol = DiffEqBase. build_solution (prob, ode_solver , ts, U_vals)
280+ sol = DiffEqBase. build_solution (prob, tableau_getter , ts, U_vals)
270281
271282 input_sol = nothing
272283 if ! isempty (V)
273284 V_vals = value_getter (V)
274285 V_vals = [[V_vals[i][j] for i in 1 : length (V_vals)] for j in 1 : length (ts)]
275- input_sol = DiffEqBase. build_solution (prob, ode_solver , ts, V_vals)
286+ input_sol = DiffEqBase. build_solution (prob, tableau_getter , ts, V_vals)
276287 end
277288
278289 if failed
0 commit comments