Skip to content

Commit 3308a1a

Browse files
fix: update CasADi extension to new semantics
1 parent e86dd15 commit 3308a1a

File tree

1 file changed

+27
-30
lines changed

1 file changed

+27
-30
lines changed

ext/MTKCasADiDynamicOptExt.jl

Lines changed: 27 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,9 @@ function (M::MXLinearInterpolation)(τ)
5656
end
5757

5858
"""
59-
CasADiDynamicOptProblem(sys::ODESystem, u0, tspan, p; dt, steps)
59+
CasADiDynamicOptProblem(sys::System, u0, tspan, p; dt, steps)
6060
61-
Convert an ODESystem representing an optimal control system into a CasADi model
61+
Convert an System representing an optimal control system into a CasADi model
6262
for solving using optimization. Must provide either `dt`, the timestep between collocation
6363
points (which, along with the timespan, determines the number of points), or directly
6464
provide the number of points as `steps`.
@@ -68,10 +68,10 @@ The optimization variables:
6868
- a vector-of-vectors V representing the controls as an interpolation array
6969
7070
The constraints are:
71-
- The set of user constraints passed to the ODESystem via `constraints`
71+
- The set of user constraints passed to the System via `constraints`
7272
- The solver constraints that encode the time-stepping used by the solver
7373
"""
74-
function MTK.CasADiDynamicOptProblem(sys::ODESystem, u0map, tspan, pmap;
74+
function MTK.CasADiDynamicOptProblem(sys::System, u0map, tspan, pmap;
7575
dt = nothing,
7676
steps = nothing,
7777
guesses = Dict(), kwargs...)
@@ -167,15 +167,15 @@ function add_user_constraints!(model::CasADiModel, sys, tspan, pmap; is_free_t)
167167
@unpack opti, U, V, tₛ = model
168168

169169
iv = MTK.get_iv(sys)
170-
conssys = MTK.get_constraintsystem(sys)
171-
jconstraints = isnothing(conssys) ? nothing : MTK.get_constraints(conssys)
170+
jconstraints = MTK.get_constraints(sys)
172171
(isnothing(jconstraints) || isempty(jconstraints)) && return nothing
173172

174173
stidxmap = Dict([v => i for (i, v) in enumerate(unknowns(sys))])
175174
ctidxmap = Dict([v => i for (i, v) in enumerate(MTK.unbound_inputs(sys))])
176-
cons_unknowns = map(MTK.default_toterm, unknowns(conssys))
175+
cons_dvs, cons_ps = MTK.process_constraint_system(
176+
jconstraints, Set(unknowns(sys)), parameters(sys), iv)
177177

178-
auxmap = Dict([u => MTK.default_toterm(MTK.value(u)) for u in unknowns(conssys)])
178+
auxmap = Dict([u => MTK.default_toterm(MTK.value(u)) for u in cons_dvs])
179179
jconstraints = substitute_casadi_vars(model, sys, pmap, jconstraints; is_free_t, auxmap)
180180
# Manually substitute fixed-t variables
181181
for (i, cons) in enumerate(jconstraints)
@@ -207,9 +207,8 @@ end
207207

208208
function add_cost_function!(model::CasADiModel, sys, tspan, pmap; is_free_t)
209209
@unpack opti, U, V, tₛ = model
210-
jcosts = copy(MTK.get_costs(sys))
211-
consolidate = MTK.get_consolidate(sys)
212-
if isnothing(jcosts) || isempty(jcosts)
210+
jcosts = cost(sys)
211+
if Symbolics._iszero(jcosts)
213212
minimize!(opti, MX(0))
214213
return
215214
end
@@ -218,24 +217,22 @@ function add_cost_function!(model::CasADiModel, sys, tspan, pmap; is_free_t)
218217
stidxmap = Dict([v => i for (i, v) in enumerate(unknowns(sys))])
219218
ctidxmap = Dict([v => i for (i, v) in enumerate(MTK.unbound_inputs(sys))])
220219

221-
jcosts = substitute_casadi_vars(model, sys, pmap, jcosts; is_free_t)
220+
jcosts = substitute_casadi_vars(model, sys, pmap, [jcosts]; is_free_t)[1]
222221
# Substitute fixed-time variables.
223-
for i in 1:length(jcosts)
224-
costvars = MTK.vars(jcosts[i])
225-
for st in costvars
226-
MTK.iscall(st) || continue
227-
x = operation(st)
228-
t = only(arguments(st))
229-
MTK.symbolic_type(t) === MTK.NotSymbolic() || continue
230-
if haskey(stidxmap, x(iv))
231-
idx = stidxmap[x(iv)]
232-
cv = U
233-
else
234-
idx = ctidxmap[x(iv)]
235-
cv = V
236-
end
237-
jcosts[i] = Symbolics.substitute(jcosts[i], Dict(x(t) => cv(t)[idx]))
222+
costvars = MTK.vars(jcosts)
223+
for st in costvars
224+
MTK.iscall(st) || continue
225+
x = operation(st)
226+
t = only(arguments(st))
227+
MTK.symbolic_type(t) === MTK.NotSymbolic() || continue
228+
if haskey(stidxmap, x(iv))
229+
idx = stidxmap[x(iv)]
230+
cv = U
231+
else
232+
idx = ctidxmap[x(iv)]
233+
cv = V
238234
end
235+
jcosts = Symbolics.substitute(jcosts, Dict(x(t) => cv(t)[idx]))
239236
end
240237

241238
dt = U.t[2] - U.t[1]
@@ -249,9 +246,9 @@ function add_cost_function!(model::CasADiModel, sys, tspan, pmap; is_free_t)
249246
# Approximate integral as sum.
250247
intmap[int] = dt * tₛ * sum(arg)
251248
end
252-
jcosts = map(c -> Symbolics.substitute(c, intmap), jcosts)
253-
jcosts = MTK.value.(jcosts)
254-
minimize!(opti, MX(MTK.value(consolidate(jcosts))))
249+
jcosts = Symbolics.substitute(jcosts, intmap)
250+
jcosts = MTK.value(jcosts)
251+
minimize!(opti, MX(jcosts))
255252
end
256253

257254
function substitute_casadi_vars(

0 commit comments

Comments
 (0)