Skip to content

Commit 02dfb13

Browse files
committed
fix: don't use eval
1 parent e9c429c commit 02dfb13

File tree

1 file changed

+20
-9
lines changed

1 file changed

+20
-9
lines changed

ext/MTKCasADiDynamicOptExt.jl

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,22 @@ function (M::MXLinearInterpolation)(τ)
4343
M.u[i] + Δ*(M.u[i + 1] - M.u[i])
4444
end
4545

46+
"""
47+
CasADiDynamicOptProblem(sys::ODESystem, u0, tspan, p; dt, steps)
48+
49+
Convert an ODESystem representing an optimal control system into a CasADi model
50+
for solving using optimization. Must provide either `dt`, the timestep between collocation
51+
points (which, along with the timespan, determines the number of points), or directly
52+
provide the number of points as `steps`.
53+
54+
The optimization variables:
55+
- a vector-of-vectors U representing the unknowns as an interpolation array
56+
- a vector-of-vectors V representing the controls as an interpolation array
57+
58+
The constraints are:
59+
- The set of user constraints passed to the ODESystem via `constraints`
60+
- The solver constraints that encode the time-stepping used by the solver
61+
"""
4662
function MTK.CasADiDynamicOptProblem(sys::ODESystem, u0map, tspan, pmap;
4763
dt = nothing,
4864
steps = nothing,
@@ -240,16 +256,11 @@ end
240256
241257
`plugin_options` and `solver_options` get propagated to the Opti object in CasADi.
242258
"""
243-
function DiffEqBase.solve(prob::CasADiDynamicOptProblem, solver::Union{String, Symbol}, ode_solver::Symbol = :Default; plugin_options::Dict = Dict(), solver_options::Dict = Dict(), silent = false)
259+
function DiffEqBase.solve(prob::CasADiDynamicOptProblem, solver::Union{String, Symbol}, tableau_getter = constructDefault; plugin_options::Dict = Dict(), solver_options::Dict = Dict(), silent = false)
244260
model = prob.model
261+
tableau = tableau_getter()
245262
opti = model.opti
246263

247-
if ode_solver == :Default
248-
tableau = MTK.constructDefault()
249-
else
250-
tableau_getter = Symbol(:construct, ode_solver)
251-
tableau = @eval Main.tableau_getter()
252-
end
253264
solver!(opti, solver, plugin_options, solver_options)
254265
add_casadi_solve_constraints!(prob, tableau)
255266
solver!(cmodel, "$solver", plugin_options, solver_options)
@@ -266,13 +277,13 @@ function DiffEqBase.solve(prob::CasADiDynamicOptProblem, solver::Union{String, S
266277
ts = value_getter(tₛ) * U.t
267278
U_vals = value_getter(U)
268279
U_vals = [[U_vals[i][j] for i in 1:length(U_vals)] for j in 1:length(ts)]
269-
sol = DiffEqBase.build_solution(prob, ode_solver, ts, U_vals)
280+
sol = DiffEqBase.build_solution(prob, tableau_getter, ts, U_vals)
270281

271282
input_sol = nothing
272283
if !isempty(V)
273284
V_vals = value_getter(V)
274285
V_vals = [[V_vals[i][j] for i in 1:length(V_vals)] for j in 1:length(ts)]
275-
input_sol = DiffEqBase.build_solution(prob, ode_solver, ts, V_vals)
286+
input_sol = DiffEqBase.build_solution(prob, tableau_getter, ts, V_vals)
276287
end
277288

278289
if failed

0 commit comments

Comments
 (0)