Skip to content

Commit 4ca744a

Browse files
committed
correctly implement interface
1 parent 912d71b commit 4ca744a

File tree

4 files changed

+164
-127
lines changed

4 files changed

+164
-127
lines changed

ext/MTKCasADiDynamicOptExt.jl

Lines changed: 103 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,17 @@ struct MXLinearInterpolation
1818
dt::Float64
1919
end
2020

21-
struct CasADiModel
22-
opti::Opti
21+
mutable struct CasADiModel
22+
model::Opti
2323
U::MXLinearInterpolation
2424
V::MXLinearInterpolation
2525
tₛ::MX
2626
is_free_final::Bool
27+
solver_opti::Union{Nothing, Opti}
28+
29+
function CasADiModel(opti, U, V, tₛ, is_free_final, solver_opti = nothing)
30+
new(opti, U, V, tₛ, is_free_final, solver_opti)
31+
end
2732
end
2833

2934
struct CasADiDynamicOptProblem{uType, tType, isinplace, P, F, K} <:
@@ -74,24 +79,27 @@ function MTK.CasADiDynamicOptProblem(sys::System, u0map, tspan, pmap;
7479
dt = nothing,
7580
steps = nothing,
7681
guesses = Dict(), kwargs...)
77-
process_DynamicOptProblem(CasADiDynamicOptProblem, CasADiModel, sys, u0map, tspan, pmap; dt, steps, guesses, kwargs...)
82+
MTK.process_DynamicOptProblem(CasADiDynamicOptProblem, CasADiModel, sys, u0map, tspan, pmap; dt, steps, guesses, kwargs...)
7883
end
7984

80-
MTK.generate_internal_model(::Type{CasADiModel}) = CasADi.opti()
85+
MTK.generate_internal_model(::Type{CasADiModel}) = CasADi.Opti()
86+
MTK.generate_time_variable!(opti::Opti, args...) = nothing
8187

82-
function MTK.generate_state_variable(model::Opti, u0, ns, nt, tsteps)
88+
function MTK.generate_state_variable!(model::Opti, u0, ns, tsteps)
89+
nt = length(tsteps)
8390
U = CasADi.variable!(model, ns, nt)
84-
set_initial!(opti, U, DM(repeat(u0, 1, steps)))
91+
set_initial!(model, U, DM(repeat(u0, 1, nt)))
8592
MXLinearInterpolation(U, tsteps, tsteps[2] - tsteps[1])
8693
end
8794

88-
function MTK.generate_input_variable(model::Opti, c0, nc, nt, tsteps)
95+
function MTK.generate_input_variable!(model::Opti, c0, nc, tsteps)
96+
nt = length(tsteps)
8997
V = CasADi.variable!(model, nc, nt)
90-
!isempty(c0) && set_initial!(opti, V, DM(repeat(c0, 1, steps)))
98+
!isempty(c0) && set_initial!(model, V, DM(repeat(c0, 1, nt)))
9199
MXLinearInterpolation(V, tsteps, tsteps[2] - tsteps[1])
92100
end
93101

94-
function MTK.generate_timescale(model::Opti, guess, is_free_t)
102+
function MTK.generate_timescale!(model::Opti, guess, is_free_t)
95103
if is_free_t
96104
tₛ = variable!(model)
97105
set_initial!(model, tₛ, guess)
@@ -102,78 +110,73 @@ function MTK.generate_timescale(model::Opti, guess, is_free_t)
102110
end
103111
end
104112

105-
function MTK.add_constraint!(model::CasADiModel, expr)
106-
@unpack opti = model
107-
if cons isa Equation
108-
subject_to!(opti, expr.lhs - expr.rhs == 0)
109-
elseif cons.relational_op === Symbolics.geq
110-
subject_to!(opti, expr.lhs - expr.rhs 0)
113+
function MTK.add_constraint!(m::CasADiModel, expr)
114+
if expr isa Equation
115+
subject_to!(m.model, expr.lhs - expr.rhs == 0)
116+
elseif expr.relational_op === Symbolics.geq
117+
subject_to!(m.model, expr.lhs - expr.rhs 0)
111118
else
112-
subject_to!(opti, expr.lhs - expr.rhs 0)
119+
subject_to!(m.model, expr.lhs - expr.rhs 0)
113120
end
114121
end
115-
MTK.set_objective!(model::CasADiModel, expr) = minimize!(model.opti, MX(expr))
122+
MTK.set_objective!(m::CasADiModel, expr) = minimize!(m.model, MX(expr))
116123

117-
function MTK.set_variable_bounds!(model, sys, pmap, tf)
118-
@unpack opti, U, V = model
124+
function MTK.set_variable_bounds!(m::CasADiModel, sys, pmap, tf)
125+
@unpack model, U, tₛ, V = m
119126
for (i, u) in enumerate(unknowns(sys))
120127
if MTK.hasbounds(u)
121128
lo, hi = MTK.getbounds(u)
122-
subject_to!(opti, Symbolics.fast_substitute(lo, pmap) <= U.u[i, :])
123-
subject_to!(opti, U.u[i, :] <= Symbolics.fast_substitute(hi, pmap))
129+
subject_to!(model, Symbolics.fixpoint_sub(lo, pmap) <= U.u[i, :])
130+
subject_to!(model, U.u[i, :] <= Symbolics.fixpoint_sub(hi, pmap))
124131
end
125132
end
126133
for (i, v) in enumerate(MTK.unbound_inputs(sys))
127134
if MTK.hasbounds(v)
128135
lo, hi = MTK.getbounds(v)
129-
subject_to!(opti, Symbolics.fast_substitute(lo, pmap) <= V.u[i, :])
130-
subject_to!(opti, V.u[i, :] <= Symbolics.fast_substitute(hi, pmap))
136+
subject_to!(model, Symbolics.fixpoint_sub(lo, pmap) <= V.u[i, :])
137+
subject_to!(model, V.u[i, :] <= Symbolics.fixpoint_sub(hi, pmap))
131138
end
132139
end
133140
if MTK.symbolic_type(tf) === MTK.ScalarSymbolic() && hasbounds(tf)
134141
lo, hi = MTK.getbounds(tf)
135-
subject_to!(opti, model.tₛ >= lo)
136-
subject_to!(opti, model.tₛ <= hi)
142+
subject_to!(model, tₛ >= lo)
143+
subject_to!(model, tₛ <= hi)
137144
end
138145
end
139146

140-
function MTK.add_initial_constraints!(model::CasADiModel, u0, u0_idxs)
141-
@unpack opti, U = model
147+
function MTK.add_initial_constraints!(m::CasADiModel, u0, u0_idxs, args...)
148+
@unpack model, U = m
142149
for i in u0_idxs
143-
subject_to!(opti, U.u[i, 1] == u0[i])
150+
subject_to!(model, U.u[i, 1] == u0[i])
144151
end
145152
end
146153

147-
function MTK.substitute_model_vars(
148-
model::CasADiModel, sys, pmap, exprs; auxmap::Dict = Dict(), is_free_t)
149-
@unpack opti, U, V, tₛ = model
154+
function MTK.substitute_model_vars(m::CasADiModel, sys, exprs, tspan)
155+
@unpack model, U, V, tₛ = m
150156
iv = MTK.get_iv(sys)
151157
sts = unknowns(sys)
152158
cts = MTK.unbound_inputs(sys)
153-
154159
x_ops = [MTK.operation(MTK.unwrap(st)) for st in sts]
155160
c_ops = [MTK.operation(MTK.unwrap(ct)) for ct in cts]
156-
157-
exprs = map(c -> Symbolics.fast_substitute(c, auxmap), exprs)
158-
exprs = map(c -> Symbolics.fast_substitute(c, Dict(pmap)), exprs)
159-
# tf means different things in different contexts; a [tf] in a cost function
160-
# should be tₛ, while a x(tf) should translate to x[1]
161-
if is_free_t
162-
free_t_map = Dict([[x(tₛ) => U.u[i, end] for (i, x) in enumerate(x_ops)];
163-
[c(tₛ) => V.u[i, end] for (i, c) in enumerate(c_ops)]])
161+
(ti, tf) = tspan
162+
if MTK.is_free_final(m)
163+
_tf = tₛ + ti
164+
exprs = map(c -> Symbolics.fast_substitute(c, Dict(tf => _tf)), exprs)
165+
free_t_map = Dict([[x(_tf) => U.u[i, end] for (i, x) in enumerate(x_ops)];
166+
[c(_tf) => V.u[i, end] for (i, c) in enumerate(c_ops)]])
164167
exprs = map(c -> Symbolics.fast_substitute(c, free_t_map), exprs)
165168
end
166169

167-
exprs = substitute_fixed_t_vars(exprs)
168-
169-
# for variables like x(t)
170+
exprs = substitute_fixed_t_vars(m, sys, exprs)
170171
whole_interval_map = Dict([[v => U.u[i, :] for (i, v) in enumerate(sts)];
171172
[v => V.u[i, :] for (i, v) in enumerate(cts)]])
172173
exprs = map(c -> Symbolics.fast_substitute(c, whole_interval_map), exprs)
173-
exprs
174174
end
175175

176-
function substitute_fixed_t_vars(exprs)
176+
function substitute_fixed_t_vars(model::CasADiModel, sys, exprs)
177+
stidxmap = Dict([v => i for (i, v) in enumerate(unknowns(sys))])
178+
ctidxmap = Dict([v => i for (i, v) in enumerate(MTK.unbound_inputs(sys))])
179+
iv = MTK.get_iv(sys)
177180
for i in 1:length(exprs)
178181
subvars = MTK.vars(exprs[i])
179182
for st in subvars
@@ -183,27 +186,28 @@ function substitute_fixed_t_vars(exprs)
183186
MTK.symbolic_type(t) === MTK.NotSymbolic() || continue
184187
if haskey(stidxmap, x(iv))
185188
idx = stidxmap[x(iv)]
186-
cv = U
189+
cv = model.U
187190
else
188191
idx = ctidxmap[x(iv)]
189-
cv = V
192+
cv = model.V
190193
end
191194
exprs[i] = Symbolics.fast_substitute(exprs[i], Dict(x(t) => cv(t)[idx]))
192195
end
193196
jcosts = Symbolics.substitute(jcosts, Dict(x(t) => cv(t)[idx]))
194197
end
198+
exprs
195199
end
196200

197-
MTK.substitute_differentials(model::CasADiModel, exprs, args...) = exprs
201+
MTK.substitute_differentials(model::CasADiModel, sys, eqs) = exprs
198202

199-
function MTK.substitute_integral(model::CasADiModel, exprs)
200-
@unpack U, opti = model
203+
function MTK.substitute_integral(m::CasADiModel, exprs, tspan)
204+
@unpack U, model, tₛ = m
201205
dt = U.t[2] - U.t[1]
202206
intmap = Dict()
203207
for int in MTK.collect_applied_operators(exprs, Symbolics.Integral)
204208
op = MTK.operation(int)
205209
arg = only(arguments(MTK.value(int)))
206-
lo, hi = (op.domain.domain.left, op.domain.domain.right)
210+
lo, hi = MTK.value.((op.domain.domain.left, op.domain.domain.right))
207211
!isequal((lo, hi), tspan) &&
208212
error("Non-whole interval bounds for integrals are not currently supported for CasADiDynamicOptProblem.")
209213
# Approximate integral as sum.
@@ -213,11 +217,11 @@ function MTK.substitute_integral(model::CasADiModel, exprs)
213217
exprs = MTK.value.(exprs)
214218
end
215219

216-
function add_solve_constraints!(prob, tableau)
220+
function add_solve_constraints!(prob::CasADiDynamicOptProblem, tableau)
217221
@unpack A, α, c = tableau
218222
@unpack model, f, p = prob
219-
@unpack opti, U, V, tₛ = model
220-
solver_opti = copy(opti)
223+
@unpack model, U, V, tₛ = model
224+
solver_opti = copy(model)
221225

222226
tsteps = U.t
223227
dt = tsteps[2] - tsteps[1]
@@ -258,29 +262,56 @@ function add_solve_constraints!(prob, tableau)
258262
solver_opti
259263
end
260264

261-
function MTK.prepare_solver()
262-
opti = add_solve_constraints(prob, tableau)
263-
solver!(opti, "$solver", plugin_options, solver_options)
265+
"""
266+
CasADi Collocation solver.
267+
- solver: an optimization solver such as Ipopt. Should be given as a string or symbol in all lowercase, e.g. "ipopt"
268+
- tableau: An ODE RK tableau. Load a tableau by calling a function like `constructRK4` and may be found at https://docs.sciml.ai/DiffEqDevDocs/stable/internals/tableaus/. If this argument is not passed in, the solver will default to Radau second order.
269+
"""
270+
struct CasADiCollocation <: AbstractCollocation
271+
solver::Union{String, Symbol}
272+
tableau::DiffEqBase.ODERKTableau
273+
end
274+
MTK.CasADiCollocation(solver, tableau = MTK.constructDefault()) = CasADiCollocation(solver, tableau)
275+
276+
function MTK.prepare_and_optimize!(prob::CasADiDynamicOptProblem, solver::CasADiCollocation; verbose = false, solver_options = Dict(), plugin_options = Dict(), kwargs...)
277+
solver_opti = add_solve_constraints!(prob, solver.tableau)
278+
verbose || (solver_options["print_level"] = 0)
279+
solver!(solver_opti, "$(solver.solver)", plugin_options, solver_options)
280+
try
281+
CasADi.solve!(solver_opti)
282+
catch ErrorException
283+
end
284+
prob.model.solver_opti = solver_opti
264285
end
265-
function MTK.get_U_values()
266-
U_vals = value_getter(U.u)
286+
287+
function MTK.get_U_values(model::CasADiModel)
288+
value_getter = MTK.successful_solve(model) ? CasADi.debug_value : CasADi.value
289+
(nu, nt) = size(model.U.u)
290+
U_vals = value_getter(model.solver_opti, model.U.u)
267291
size(U_vals, 2) == 1 && (U_vals = U_vals')
268-
U_vals = [[U_vals[i, j] for i in 1:size(U_vals, 1)] for j in 1:length(ts)]
292+
U_vals = [[U_vals[i, j] for i in 1:nu] for j in 1:nt]
269293
end
270-
function MTK.get_V_values()
294+
295+
function MTK.get_V_values(model::CasADiModel)
296+
value_getter = MTK.successful_solve(model) ? CasADi.debug_value : CasADi.value
297+
(nu, nt) = size(model.V.u)
298+
if nu*nt != 0
299+
V_vals = value_getter(model.solver_opti, model.V.u)
300+
size(V_vals, 2) == 1 && (V_vals = V_vals')
301+
V_vals = [[V_vals[i, j] for i in 1:nu] for j in 1:nt]
302+
else
303+
nothing
304+
end
271305
end
272-
function MTK.get_t_values()
273-
ts = value_getter(tₛ) * U.t
306+
307+
function MTK.get_t_values(model::CasADiModel)
308+
value_getter = MTK.successful_solve(model) ? CasADi.debug_value : CasADi.value
309+
ts = value_getter(model.solver_opti, model.tₛ) .* model.U.t
274310
end
275311

276-
function MTK.optimize_model!()
277-
try
278-
sol = CasADi.solve!(opti)
279-
value_getter = x -> CasADi.value(sol, x)
280-
catch ErrorException
281-
value_getter = x -> CasADi.debug_value(opti, x)
282-
failed = true
283-
end
312+
function MTK.successful_solve(m::CasADiModel)
313+
isnothing(m.solver_opti) && return false
314+
retcode = CasADi.return_status(m.solver_opti)
315+
retcode == "Solve_Succeeded" || retcode == "Solved_To_Acceptable_Level"
284316
end
285-
MTK.successful_solve() = true
286317
end

0 commit comments

Comments
 (0)