Skip to content

Commit 9ec68e4

Browse files
committed
correctly implement interface
1 parent 4d45266 commit 9ec68e4

File tree

4 files changed

+164
-127
lines changed

4 files changed

+164
-127
lines changed

ext/MTKCasADiDynamicOptExt.jl

Lines changed: 103 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,17 @@ struct MXLinearInterpolation
1818
dt::Float64
1919
end
2020

21-
struct CasADiModel
22-
opti::Opti
21+
mutable struct CasADiModel
22+
model::Opti
2323
U::MXLinearInterpolation
2424
V::MXLinearInterpolation
2525
tₛ::MX
2626
is_free_final::Bool
27+
solver_opti::Union{Nothing, Opti}
28+
29+
function CasADiModel(opti, U, V, tₛ, is_free_final, solver_opti = nothing)
30+
new(opti, U, V, tₛ, is_free_final, solver_opti)
31+
end
2732
end
2833

2934
struct CasADiDynamicOptProblem{uType, tType, isinplace, P, F, K} <:
@@ -74,24 +79,27 @@ function MTK.CasADiDynamicOptProblem(sys::ODESystem, u0map, tspan, pmap;
7479
dt = nothing,
7580
steps = nothing,
7681
guesses = Dict(), kwargs...)
77-
process_DynamicOptProblem(CasADiDynamicOptProblem, CasADiModel, sys, u0map, tspan, pmap; dt, steps, guesses, kwargs...)
82+
MTK.process_DynamicOptProblem(CasADiDynamicOptProblem, CasADiModel, sys, u0map, tspan, pmap; dt, steps, guesses, kwargs...)
7883
end
7984

80-
MTK.generate_internal_model(::Type{CasADiModel}) = CasADi.opti()
85+
MTK.generate_internal_model(::Type{CasADiModel}) = CasADi.Opti()
86+
MTK.generate_time_variable!(opti::Opti, args...) = nothing
8187

82-
function MTK.generate_state_variable(model::Opti, u0, ns, nt, tsteps)
88+
function MTK.generate_state_variable!(model::Opti, u0, ns, tsteps)
89+
nt = length(tsteps)
8390
U = CasADi.variable!(model, ns, nt)
84-
set_initial!(opti, U, DM(repeat(u0, 1, steps)))
91+
set_initial!(model, U, DM(repeat(u0, 1, nt)))
8592
MXLinearInterpolation(U, tsteps, tsteps[2] - tsteps[1])
8693
end
8794

88-
function MTK.generate_input_variable(model::Opti, c0, nc, nt, tsteps)
95+
function MTK.generate_input_variable!(model::Opti, c0, nc, tsteps)
96+
nt = length(tsteps)
8997
V = CasADi.variable!(model, nc, nt)
90-
!isempty(c0) && set_initial!(opti, V, DM(repeat(c0, 1, steps)))
98+
!isempty(c0) && set_initial!(model, V, DM(repeat(c0, 1, nt)))
9199
MXLinearInterpolation(V, tsteps, tsteps[2] - tsteps[1])
92100
end
93101

94-
function MTK.generate_timescale(model::Opti, guess, is_free_t)
102+
function MTK.generate_timescale!(model::Opti, guess, is_free_t)
95103
if is_free_t
96104
tₛ = variable!(model)
97105
set_initial!(model, tₛ, guess)
@@ -102,78 +110,73 @@ function MTK.generate_timescale(model::Opti, guess, is_free_t)
102110
end
103111
end
104112

105-
function MTK.add_constraint!(model::CasADiModel, expr)
106-
@unpack opti = model
107-
if cons isa Equation
108-
subject_to!(opti, expr.lhs - expr.rhs == 0)
109-
elseif cons.relational_op === Symbolics.geq
110-
subject_to!(opti, expr.lhs - expr.rhs 0)
113+
function MTK.add_constraint!(m::CasADiModel, expr)
114+
if expr isa Equation
115+
subject_to!(m.model, expr.lhs - expr.rhs == 0)
116+
elseif expr.relational_op === Symbolics.geq
117+
subject_to!(m.model, expr.lhs - expr.rhs 0)
111118
else
112-
subject_to!(opti, expr.lhs - expr.rhs 0)
119+
subject_to!(m.model, expr.lhs - expr.rhs 0)
113120
end
114121
end
115-
MTK.set_objective!(model::CasADiModel, expr) = minimize!(model.opti, MX(expr))
122+
MTK.set_objective!(m::CasADiModel, expr) = minimize!(m.model, MX(expr))
116123

117-
function MTK.set_variable_bounds!(model, sys, pmap, tf)
118-
@unpack opti, U, V = model
124+
function MTK.set_variable_bounds!(m::CasADiModel, sys, pmap, tf)
125+
@unpack model, U, tₛ, V = m
119126
for (i, u) in enumerate(unknowns(sys))
120127
if MTK.hasbounds(u)
121128
lo, hi = MTK.getbounds(u)
122-
subject_to!(opti, Symbolics.fixpoint_sub(lo, pmap) <= U.u[i, :])
123-
subject_to!(opti, U.u[i, :] <= Symbolics.fixpoint_sub(hi, pmap))
129+
subject_to!(model, Symbolics.fixpoint_sub(lo, pmap) <= U.u[i, :])
130+
subject_to!(model, U.u[i, :] <= Symbolics.fixpoint_sub(hi, pmap))
124131
end
125132
end
126133
for (i, v) in enumerate(MTK.unbound_inputs(sys))
127134
if MTK.hasbounds(v)
128135
lo, hi = MTK.getbounds(v)
129-
subject_to!(opti, Symbolics.fixpoint_sub(lo, pmap) <= V.u[i, :])
130-
subject_to!(opti, V.u[i, :] <= Symbolics.fixpoint_sub(hi, pmap))
136+
subject_to!(model, Symbolics.fixpoint_sub(lo, pmap) <= V.u[i, :])
137+
subject_to!(model, V.u[i, :] <= Symbolics.fixpoint_sub(hi, pmap))
131138
end
132139
end
133140
if MTK.symbolic_type(tf) === MTK.ScalarSymbolic() && hasbounds(tf)
134141
lo, hi = MTK.getbounds(tf)
135-
subject_to!(opti, model.tₛ >= lo)
136-
subject_to!(opti, model.tₛ <= hi)
142+
subject_to!(model, tₛ >= lo)
143+
subject_to!(model, tₛ <= hi)
137144
end
138145
end
139146

140-
function MTK.add_initial_constraints!(model::CasADiModel, u0, u0_idxs)
141-
@unpack opti, U = model
147+
function MTK.add_initial_constraints!(m::CasADiModel, u0, u0_idxs, args...)
148+
@unpack model, U = m
142149
for i in u0_idxs
143-
subject_to!(opti, U.u[i, 1] == u0[i])
150+
subject_to!(model, U.u[i, 1] == u0[i])
144151
end
145152
end
146153

147-
function MTK.substitute_model_vars(
148-
model::CasADiModel, sys, pmap, exprs; auxmap::Dict = Dict(), is_free_t)
149-
@unpack opti, U, V, tₛ = model
154+
function MTK.substitute_model_vars(m::CasADiModel, sys, exprs, tspan)
155+
@unpack model, U, V, tₛ = m
150156
iv = MTK.get_iv(sys)
151157
sts = unknowns(sys)
152158
cts = MTK.unbound_inputs(sys)
153-
154159
x_ops = [MTK.operation(MTK.unwrap(st)) for st in sts]
155160
c_ops = [MTK.operation(MTK.unwrap(ct)) for ct in cts]
156-
157-
exprs = map(c -> Symbolics.fast_substitute(c, auxmap), exprs)
158-
exprs = map(c -> Symbolics.fast_substitute(c, Dict(pmap)), exprs)
159-
# tf means different things in different contexts; a [tf] in a cost function
160-
# should be tₛ, while a x(tf) should translate to x[1]
161-
if is_free_t
162-
free_t_map = Dict([[x(tₛ) => U.u[i, end] for (i, x) in enumerate(x_ops)];
163-
[c(tₛ) => V.u[i, end] for (i, c) in enumerate(c_ops)]])
161+
(ti, tf) = tspan
162+
if MTK.is_free_final(m)
163+
_tf = tₛ + ti
164+
exprs = map(c -> Symbolics.fast_substitute(c, Dict(tf => _tf)), exprs)
165+
free_t_map = Dict([[x(_tf) => U.u[i, end] for (i, x) in enumerate(x_ops)];
166+
[c(_tf) => V.u[i, end] for (i, c) in enumerate(c_ops)]])
164167
exprs = map(c -> Symbolics.fast_substitute(c, free_t_map), exprs)
165168
end
166169

167-
exprs = substitute_fixed_t_vars(exprs)
168-
169-
# for variables like x(t)
170+
exprs = substitute_fixed_t_vars(m, sys, exprs)
170171
whole_interval_map = Dict([[v => U.u[i, :] for (i, v) in enumerate(sts)];
171172
[v => V.u[i, :] for (i, v) in enumerate(cts)]])
172173
exprs = map(c -> Symbolics.fast_substitute(c, whole_interval_map), exprs)
173-
exprs
174174
end
175175

176-
function substitute_fixed_t_vars(exprs)
176+
function substitute_fixed_t_vars(model::CasADiModel, sys, exprs)
177+
stidxmap = Dict([v => i for (i, v) in enumerate(unknowns(sys))])
178+
ctidxmap = Dict([v => i for (i, v) in enumerate(MTK.unbound_inputs(sys))])
179+
iv = MTK.get_iv(sys)
177180
for i in 1:length(exprs)
178181
subvars = MTK.vars(exprs[i])
179182
for st in subvars
@@ -183,26 +186,27 @@ function substitute_fixed_t_vars(exprs)
183186
MTK.symbolic_type(t) === MTK.NotSymbolic() || continue
184187
if haskey(stidxmap, x(iv))
185188
idx = stidxmap[x(iv)]
186-
cv = U
189+
cv = model.U
187190
else
188191
idx = ctidxmap[x(iv)]
189-
cv = V
192+
cv = model.V
190193
end
191194
exprs[i] = Symbolics.fast_substitute(exprs[i], Dict(x(t) => cv(t)[idx]))
192195
end
193196
end
197+
exprs
194198
end
195199

196-
MTK.substitute_differentials(model::CasADiModel, exprs, args...) = exprs
200+
MTK.substitute_differentials(model::CasADiModel, sys, eqs) = exprs
197201

198-
function MTK.substitute_integral(model::CasADiModel, exprs)
199-
@unpack U, opti = model
202+
function MTK.substitute_integral(m::CasADiModel, exprs, tspan)
203+
@unpack U, model, tₛ = m
200204
dt = U.t[2] - U.t[1]
201205
intmap = Dict()
202206
for int in MTK.collect_applied_operators(exprs, Symbolics.Integral)
203207
op = MTK.operation(int)
204208
arg = only(arguments(MTK.value(int)))
205-
lo, hi = (op.domain.domain.left, op.domain.domain.right)
209+
lo, hi = MTK.value.((op.domain.domain.left, op.domain.domain.right))
206210
!isequal((lo, hi), tspan) &&
207211
error("Non-whole interval bounds for integrals are not currently supported for CasADiDynamicOptProblem.")
208212
# Approximate integral as sum.
@@ -212,11 +216,11 @@ function MTK.substitute_integral(model::CasADiModel, exprs)
212216
exprs = MTK.value.(exprs)
213217
end
214218

215-
function add_solve_constraints!(prob, tableau)
219+
function add_solve_constraints!(prob::CasADiDynamicOptProblem, tableau)
216220
@unpack A, α, c = tableau
217221
@unpack model, f, p = prob
218-
@unpack opti, U, V, tₛ = model
219-
solver_opti = copy(opti)
222+
@unpack model, U, V, tₛ = model
223+
solver_opti = copy(model)
220224

221225
tsteps = U.t
222226
dt = tsteps[2] - tsteps[1]
@@ -257,29 +261,56 @@ function add_solve_constraints!(prob, tableau)
257261
solver_opti
258262
end
259263

260-
function MTK.prepare_solver()
261-
opti = add_solve_constraints(prob, tableau)
262-
solver!(opti, "$solver", plugin_options, solver_options)
264+
"""
265+
CasADi Collocation solver.
266+
- solver: an optimization solver such as Ipopt. Should be given as a string or symbol in all lowercase, e.g. "ipopt"
267+
- tableau: An ODE RK tableau. Load a tableau by calling a function like `constructRK4` and may be found at https://docs.sciml.ai/DiffEqDevDocs/stable/internals/tableaus/. If this argument is not passed in, the solver will default to Radau second order.
268+
"""
269+
struct CasADiCollocation <: AbstractCollocation
270+
solver::Union{String, Symbol}
271+
tableau::DiffEqBase.ODERKTableau
272+
end
273+
MTK.CasADiCollocation(solver, tableau = MTK.constructDefault()) = CasADiCollocation(solver, tableau)
274+
275+
function MTK.prepare_and_optimize!(prob::CasADiDynamicOptProblem, solver::CasADiCollocation; verbose = false, solver_options = Dict(), plugin_options = Dict(), kwargs...)
276+
solver_opti = add_solve_constraints!(prob, solver.tableau)
277+
verbose || (solver_options["print_level"] = 0)
278+
solver!(solver_opti, "$(solver.solver)", plugin_options, solver_options)
279+
try
280+
CasADi.solve!(solver_opti)
281+
catch ErrorException
282+
end
283+
prob.model.solver_opti = solver_opti
263284
end
264-
function MTK.get_U_values()
265-
U_vals = value_getter(U.u)
285+
286+
function MTK.get_U_values(model::CasADiModel)
287+
value_getter = MTK.successful_solve(model) ? CasADi.debug_value : CasADi.value
288+
(nu, nt) = size(model.U.u)
289+
U_vals = value_getter(model.solver_opti, model.U.u)
266290
size(U_vals, 2) == 1 && (U_vals = U_vals')
267-
U_vals = [[U_vals[i, j] for i in 1:size(U_vals, 1)] for j in 1:length(ts)]
291+
U_vals = [[U_vals[i, j] for i in 1:nu] for j in 1:nt]
268292
end
269-
function MTK.get_V_values()
293+
294+
function MTK.get_V_values(model::CasADiModel)
295+
value_getter = MTK.successful_solve(model) ? CasADi.debug_value : CasADi.value
296+
(nu, nt) = size(model.V.u)
297+
if nu*nt != 0
298+
V_vals = value_getter(model.solver_opti, model.V.u)
299+
size(V_vals, 2) == 1 && (V_vals = V_vals')
300+
V_vals = [[V_vals[i, j] for i in 1:nu] for j in 1:nt]
301+
else
302+
nothing
303+
end
270304
end
271-
function MTK.get_t_values()
272-
ts = value_getter(tₛ) * U.t
305+
306+
function MTK.get_t_values(model::CasADiModel)
307+
value_getter = MTK.successful_solve(model) ? CasADi.debug_value : CasADi.value
308+
ts = value_getter(model.solver_opti, model.tₛ) .* model.U.t
273309
end
274310

275-
function MTK.optimize_model!()
276-
try
277-
sol = CasADi.solve!(opti)
278-
value_getter = x -> CasADi.value(sol, x)
279-
catch ErrorException
280-
value_getter = x -> CasADi.debug_value(opti, x)
281-
failed = true
282-
end
311+
function MTK.successful_solve(m::CasADiModel)
312+
isnothing(m.solver_opti) && return false
313+
retcode = CasADi.return_status(m.solver_opti)
314+
retcode == "Solve_Succeeded" || retcode == "Solved_To_Acceptable_Level"
283315
end
284-
MTK.successful_solve() = true
285316
end

0 commit comments

Comments
 (0)