Skip to content

Commit 4d45266

Browse files
committed
refactor: add interface functions for CasADi
1 parent 9ccf898 commit 4d45266

File tree

5 files changed

+206
-298
lines changed

5 files changed

+206
-298
lines changed

ext/MTKCasADiDynamicOptExt.jl

Lines changed: 86 additions & 189 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,8 @@ using UnPack
66
using NaNMath
77
const MTK = ModelingToolkit
88

9-
# NaNMath
109
for ff in [acos, log1p, acosh, log2, asin, tan, atanh, cos, log, sin, log10, sqrt]
1110
f = nameof(ff)
12-
# These need to be defined so that JuMP can trace through functions built by Symbolics
1311
@eval NaNMath.$f(x::CasadiSymbolicObject) = Base.$f(x)
1412
end
1513

@@ -76,75 +74,47 @@ function MTK.CasADiDynamicOptProblem(sys::ODESystem, u0map, tspan, pmap;
7674
dt = nothing,
7775
steps = nothing,
7876
guesses = Dict(), kwargs...)
79-
MTK.warn_overdetermined(sys, u0map)
80-
_u0map = has_alg_eqs(sys) ? u0map : merge(Dict(u0map), Dict(guesses))
81-
f, u0, p = MTK.process_SciMLProblem(ODEInputFunction, sys, _u0map, pmap;
82-
t = tspan !== nothing ? tspan[1] : tspan, output_type = MX, kwargs...)
83-
84-
pmap = Dict{Any, Any}(pmap)
85-
steps, is_free_t = MTK.process_tspan(tspan, dt, steps)
86-
model = init_model(sys, tspan, steps, u0map, pmap, u0; is_free_t)
87-
88-
CasADiDynamicOptProblem(f, u0, tspan, p, model, kwargs...)
77+
process_DynamicOptProblem(CasADiDynamicOptProblem, CasADiModel, sys, u0map, tspan, pmap; dt, steps, guesses, kwargs...)
8978
end
9079

9180
MTK.generate_internal_model(::Type{CasADiModel}) = CasADi.opti()
92-
MTK.generate_state_variable(model, u0, ns, nt)
93-
MTK.generate_input_variable(model, c0, nc, nt) = 1
94-
MTK.generate_timescale(model, dims) = 1
9581

96-
function init_model(sys, tspan, steps, u0map, pmap, u0; is_free_t = false)
97-
ctrls = MTK.unbound_inputs(sys)
98-
states = unknowns(sys)
99-
opti = CasADi.Opti()
82+
function MTK.generate_state_variable(model::Opti, u0, ns, nt, tsteps)
83+
U = CasADi.variable!(model, ns, nt)
84+
set_initial!(opti, U, DM(repeat(u0, 1, steps)))
85+
MXLinearInterpolation(U, tsteps, tsteps[2] - tsteps[1])
86+
end
87+
88+
function MTK.generate_input_variable(model::Opti, c0, nc, nt, tsteps)
89+
V = CasADi.variable!(model, nc, nt)
90+
!isempty(c0) && set_initial!(opti, V, DM(repeat(c0, 1, steps)))
91+
MXLinearInterpolation(V, tsteps, tsteps[2] - tsteps[1])
92+
end
10093

94+
function MTK.generate_timescale(model::Opti, guess, is_free_t)
10195
if is_free_t
102-
(ts_sym, te_sym) = tspan
103-
MTK.symbolic_type(ts_sym) !== MTK.NotSymbolic() &&
104-
error("Free initial time problems are not currently supported in CasADiDynamicOptProblem.")
105-
tₛ = variable!(opti)
106-
set_initial!(opti, tₛ, pmap[te_sym])
107-
subject_to!(opti, tₛ >= ts_sym)
108-
hasbounds(te_sym) && begin
109-
lo, hi = getbounds(te_sym)
110-
subject_to!(opti, tₛ >= lo)
111-
subject_to!(opti, tₛ >= hi)
112-
end
113-
pmap[te_sym] = tₛ
114-
tsteps = LinRange(0, 1, steps)
96+
tₛ = variable!(model)
97+
set_initial!(model, tₛ, guess)
98+
subject_to!(model, tₛ >= 0)
99+
tₛ
115100
else
116-
tₛ = MX(1)
117-
tsteps = LinRange(tspan[1], tspan[2], steps)
101+
MX(1)
118102
end
103+
end
119104

120-
U = CasADi.variable!(opti, length(states), steps)
121-
V = CasADi.variable!(opti, length(ctrls), steps)
122-
set_initial!(opti, U, DM(repeat(u0, 1, steps)))
123-
c0 = MTK.value.([pmap[c] for c in ctrls])
124-
!isempty(c0) && set_initial!(opti, V, DM(repeat(c0, 1, steps)))
125-
126-
U_interp = MXLinearInterpolation(U, tsteps, tsteps[2] - tsteps[1])
127-
V_interp = MXLinearInterpolation(V, tsteps, tsteps[2] - tsteps[1])
128-
for (i, ct) in enumerate(ctrls)
129-
pmap[ct] = V[i, :]
105+
function MTK.add_constraint!(model::CasADiModel, expr)
106+
@unpack opti = model
107+
if cons isa Equation
108+
subject_to!(opti, expr.lhs - expr.rhs == 0)
109+
elseif cons.relational_op === Symbolics.geq
110+
subject_to!(opti, expr.lhs - expr.rhs 0)
111+
else
112+
subject_to!(opti, expr.lhs - expr.rhs 0)
130113
end
131-
132-
model = CasADiModel(opti, U_interp, V_interp, tₛ)
133-
134-
set_casadi_bounds!(model, sys, pmap)
135-
add_cost_function!(model, sys, tspan, pmap; is_free_t)
136-
add_user_constraints!(model, sys, tspan, pmap; is_free_t)
137-
138-
stidxmap = Dict([v => i for (i, v) in enumerate(states)])
139-
u0map = Dict([MTK.default_toterm(MTK.value(k)) => v for (k, v) in u0map])
140-
u0_idxs = has_alg_eqs(sys) ? collect(1:length(states)) :
141-
[stidxmap[MTK.default_toterm(k)] for (k, v) in u0map]
142-
add_initial_constraints!(model, u0, u0_idxs)
143-
144-
model
145114
end
115+
MTK.set_objective!(model::CasADiModel, expr) = minimize!(model.opti, MX(expr))
146116

147-
function set_casadi_bounds!(model, sys, pmap)
117+
function MTK.set_variable_bounds!(model, sys, pmap, tf)
148118
@unpack opti, U, V = model
149119
for (i, u) in enumerate(unknowns(sys))
150120
if MTK.hasbounds(u)
@@ -160,75 +130,53 @@ function set_casadi_bounds!(model, sys, pmap)
160130
subject_to!(opti, V.u[i, :] <= Symbolics.fixpoint_sub(hi, pmap))
161131
end
162132
end
133+
if MTK.symbolic_type(tf) === MTK.ScalarSymbolic() && hasbounds(tf)
134+
lo, hi = MTK.getbounds(tf)
135+
subject_to!(opti, model.tₛ >= lo)
136+
subject_to!(opti, model.tₛ <= hi)
137+
end
163138
end
164139

165-
function add_initial_constraints!(model::CasADiModel, u0, u0_idxs)
140+
function MTK.add_initial_constraints!(model::CasADiModel, u0, u0_idxs)
166141
@unpack opti, U = model
167142
for i in u0_idxs
168143
subject_to!(opti, U.u[i, 1] == u0[i])
169144
end
170145
end
171146

172-
function add_user_constraints!(model::CasADiModel, sys, tspan, pmap; is_free_t)
147+
function MTK.substitute_model_vars(
148+
model::CasADiModel, sys, pmap, exprs; auxmap::Dict = Dict(), is_free_t)
173149
@unpack opti, U, V, tₛ = model
174-
175150
iv = MTK.get_iv(sys)
176-
conssys = MTK.get_constraintsystem(sys)
177-
jconstraints = isnothing(conssys) ? nothing : MTK.get_constraints(conssys)
178-
(isnothing(jconstraints) || isempty(jconstraints)) && return nothing
179-
180-
stidxmap = Dict([v => i for (i, v) in enumerate(unknowns(sys))])
181-
ctidxmap = Dict([v => i for (i, v) in enumerate(MTK.unbound_inputs(sys))])
182-
cons_unknowns = map(MTK.default_toterm, unknowns(conssys))
183-
184-
auxmap = Dict([u => MTK.default_toterm(MTK.value(u)) for u in unknowns(conssys)])
185-
jconstraints = substitute_casadi_vars(model, sys, pmap, jconstraints; is_free_t, auxmap)
186-
# Manually substitute fixed-t variables
187-
for (i, cons) in enumerate(jconstraints)
188-
consvars = MTK.vars(cons)
189-
for st in consvars
190-
MTK.iscall(st) || continue
191-
x = MTK.operation(st)
192-
t = only(MTK.arguments(st))
193-
MTK.symbolic_type(t) === MTK.NotSymbolic() || continue
194-
if haskey(stidxmap, x(iv))
195-
idx = stidxmap[x(iv)]
196-
cv = U
197-
else
198-
idx = ctidxmap[x(iv)]
199-
cv = V
200-
end
201-
cons = Symbolics.substitute(cons, Dict(x(t) => cv(t)[idx]))
202-
end
151+
sts = unknowns(sys)
152+
cts = MTK.unbound_inputs(sys)
203153

204-
if cons isa Equation
205-
subject_to!(opti, cons.lhs - cons.rhs == 0)
206-
elseif cons.relational_op === Symbolics.geq
207-
subject_to!(opti, cons.lhs - cons.rhs 0)
208-
else
209-
subject_to!(opti, cons.lhs - cons.rhs 0)
210-
end
211-
end
212-
end
154+
x_ops = [MTK.operation(MTK.unwrap(st)) for st in sts]
155+
c_ops = [MTK.operation(MTK.unwrap(ct)) for ct in cts]
213156

214-
function add_cost_function!(model::CasADiModel, sys, tspan, pmap; is_free_t)
215-
@unpack opti, U, V, tₛ = model
216-
jcosts = copy(MTK.get_costs(sys))
217-
consolidate = MTK.get_consolidate(sys)
218-
if isnothing(jcosts) || isempty(jcosts)
219-
minimize!(opti, MX(0))
220-
return
157+
exprs = map(c -> Symbolics.fast_substitute(c, auxmap), exprs)
158+
exprs = map(c -> Symbolics.fast_substitute(c, Dict(pmap)), exprs)
159+
# tf means different things in different contexts; a [tf] in a cost function
160+
# should be tₛ, while a x(tf) should translate to x[1]
161+
if is_free_t
162+
free_t_map = Dict([[x(tₛ) => U.u[i, end] for (i, x) in enumerate(x_ops)];
163+
[c(tₛ) => V.u[i, end] for (i, c) in enumerate(c_ops)]])
164+
exprs = map(c -> Symbolics.fast_substitute(c, free_t_map), exprs)
221165
end
222166

223-
iv = MTK.get_iv(sys)
224-
stidxmap = Dict([v => i for (i, v) in enumerate(unknowns(sys))])
225-
ctidxmap = Dict([v => i for (i, v) in enumerate(MTK.unbound_inputs(sys))])
226-
227-
jcosts = substitute_casadi_vars(model, sys, pmap, jcosts; is_free_t)
228-
# Substitute fixed-time variables.
229-
for i in 1:length(jcosts)
230-
costvars = MTK.vars(jcosts[i])
231-
for st in costvars
167+
exprs = substitute_fixed_t_vars(exprs)
168+
169+
# for variables like x(t)
170+
whole_interval_map = Dict([[v => U.u[i, :] for (i, v) in enumerate(sts)];
171+
[v => V.u[i, :] for (i, v) in enumerate(cts)]])
172+
exprs = map(c -> Symbolics.fast_substitute(c, whole_interval_map), exprs)
173+
exprs
174+
end
175+
176+
function substitute_fixed_t_vars(exprs)
177+
for i in 1:length(exprs)
178+
subvars = MTK.vars(exprs[i])
179+
for st in subvars
232180
MTK.iscall(st) || continue
233181
x = operation(st)
234182
t = only(arguments(st))
@@ -240,13 +188,18 @@ function add_cost_function!(model::CasADiModel, sys, tspan, pmap; is_free_t)
240188
idx = ctidxmap[x(iv)]
241189
cv = V
242190
end
243-
jcosts[i] = Symbolics.substitute(jcosts[i], Dict(x(t) => cv(t)[idx]))
191+
exprs[i] = Symbolics.fast_substitute(exprs[i], Dict(x(t) => cv(t)[idx]))
244192
end
245193
end
194+
end
195+
196+
MTK.substitute_differentials(model::CasADiModel, exprs, args...) = exprs
246197

198+
function MTK.substitute_integral(model::CasADiModel, exprs)
199+
@unpack U, opti = model
247200
dt = U.t[2] - U.t[1]
248201
intmap = Dict()
249-
for int in MTK.collect_applied_operators(jcosts, Symbolics.Integral)
202+
for int in MTK.collect_applied_operators(exprs, Symbolics.Integral)
250203
op = MTK.operation(int)
251204
arg = only(arguments(MTK.value(int)))
252205
lo, hi = (op.domain.domain.left, op.domain.domain.right)
@@ -255,39 +208,11 @@ function add_cost_function!(model::CasADiModel, sys, tspan, pmap; is_free_t)
255208
# Approximate integral as sum.
256209
intmap[int] = dt * tₛ * sum(arg)
257210
end
258-
jcosts = map(c -> Symbolics.substitute(c, intmap), jcosts)
259-
jcosts = MTK.value.(jcosts)
260-
minimize!(opti, MX(MTK.value(consolidate(jcosts))))
211+
exprs = map(c -> Symbolics.substitute(c, intmap), exprs)
212+
exprs = MTK.value.(exprs)
261213
end
262214

263-
function substitute_casadi_vars(
264-
model::CasADiModel, sys, pmap, exprs; auxmap::Dict = Dict(), is_free_t)
265-
@unpack opti, U, V, tₛ = model
266-
iv = MTK.get_iv(sys)
267-
sts = unknowns(sys)
268-
cts = MTK.unbound_inputs(sys)
269-
270-
x_ops = [MTK.operation(MTK.unwrap(st)) for st in sts]
271-
c_ops = [MTK.operation(MTK.unwrap(ct)) for ct in cts]
272-
273-
exprs = map(c -> Symbolics.fixpoint_sub(c, auxmap), exprs)
274-
exprs = map(c -> Symbolics.fixpoint_sub(c, Dict(pmap)), exprs)
275-
# tf means different things in different contexts; a [tf] in a cost function
276-
# should be tₛ, while a x(tf) should translate to x[1]
277-
if is_free_t
278-
free_t_map = Dict([[x(tₛ) => U.u[i, end] for (i, x) in enumerate(x_ops)];
279-
[c(tₛ) => V.u[i, end] for (i, c) in enumerate(c_ops)]])
280-
exprs = map(c -> Symbolics.fixpoint_sub(c, free_t_map), exprs)
281-
end
282-
283-
# for variables like x(t)
284-
whole_interval_map = Dict([[v => U.u[i, :] for (i, v) in enumerate(sts)];
285-
[v => V.u[i, :] for (i, v) in enumerate(cts)]])
286-
exprs = map(c -> Symbolics.fixpoint_sub(c, whole_interval_map), exprs)
287-
exprs
288-
end
289-
290-
function add_solve_constraints(prob, tableau)
215+
function add_solve_constraints!(prob, tableau)
291216
@unpack A, α, c = tableau
292217
@unpack model, f, p = prob
293218
@unpack opti, U, V, tₛ = model
@@ -332,57 +257,29 @@ function add_solve_constraints(prob, tableau)
332257
solver_opti
333258
end
334259

335-
"""
336-
solve(prob::CasADiDynamicOptProblem, casadi_solver, ode_solver; plugin_options, solver_options, silent)
337-
338-
`plugin_options` and `solver_options` get propagated to the Opti object in CasADi.
339-
340-
NOTE: the solver should be passed in as a string to CasADi. "ipopt"
341-
"""
342-
function DiffEqBase.solve(
343-
prob::CasADiDynamicOptProblem, solver::Union{String, Symbol} = "ipopt",
344-
tableau_getter = MTK.constructDefault; plugin_options::Dict = Dict(),
345-
solver_options::Dict = Dict(), silent = false)
346-
@unpack model, u0, p, tspan, f = prob
347-
tableau = tableau_getter()
348-
@unpack opti, U, V, tₛ = model
349-
260+
function MTK.prepare_solver()
350261
opti = add_solve_constraints(prob, tableau)
351-
silent && (solver_options["print_level"] = 0)
352262
solver!(opti, "$solver", plugin_options, solver_options)
263+
end
264+
function MTK.get_U_values()
265+
U_vals = value_getter(U.u)
266+
size(U_vals, 2) == 1 && (U_vals = U_vals')
267+
U_vals = [[U_vals[i, j] for i in 1:size(U_vals, 1)] for j in 1:length(ts)]
268+
end
269+
function MTK.get_V_values()
270+
end
271+
function MTK.get_t_values()
272+
ts = value_getter(tₛ) * U.t
273+
end
353274

354-
failed = false
355-
value_getter = nothing
356-
sol = nothing
275+
function MTK.optimize_model!()
357276
try
358277
sol = CasADi.solve!(opti)
359278
value_getter = x -> CasADi.value(sol, x)
360279
catch ErrorException
361280
value_getter = x -> CasADi.debug_value(opti, x)
362281
failed = true
363282
end
364-
365-
ts = value_getter(tₛ) * U.t
366-
U_vals = value_getter(U.u)
367-
size(U_vals, 2) == 1 && (U_vals = U_vals')
368-
U_vals = [[U_vals[i, j] for i in 1:size(U_vals, 1)] for j in 1:length(ts)]
369-
ode_sol = DiffEqBase.build_solution(prob, tableau_getter, ts, U_vals)
370-
371-
input_sol = nothing
372-
if prod(size(V.u)) != 0
373-
V_vals = value_getter(V.u)
374-
size(V_vals, 2) == 1 && (V_vals = V_vals')
375-
V_vals = [[V_vals[i, j] for i in 1:size(V_vals, 1)] for j in 1:length(ts)]
376-
input_sol = DiffEqBase.build_solution(prob, tableau_getter, ts, V_vals)
377-
end
378-
379-
if failed
380-
ode_sol = SciMLBase.solution_new_retcode(
381-
ode_sol, SciMLBase.ReturnCode.ConvergenceFailure)
382-
!isnothing(input_sol) && (input_sol = SciMLBase.solution_new_retcode(
383-
input_sol, SciMLBase.ReturnCode.ConvergenceFailure))
384-
end
385-
386-
DynamicOptSolution(model, ode_sol, input_sol)
387283
end
284+
MTK.successful_solve() = true
388285
end

0 commit comments

Comments
 (0)