Skip to content

Commit 912d71b

Browse files
committed
refactor: add interface functions for CasADi
1 parent df1a0e4 commit 912d71b

File tree

5 files changed

+211
-302
lines changed

5 files changed

+211
-302
lines changed

ext/MTKCasADiDynamicOptExt.jl

Lines changed: 91 additions & 193 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,8 @@ using UnPack
66
using NaNMath
77
const MTK = ModelingToolkit
88

9-
# NaNMath
109
for ff in [acos, log1p, acosh, log2, asin, tan, atanh, cos, log, sin, log10, sqrt]
1110
f = nameof(ff)
12-
# These need to be defined so that JuMP can trace through functions built by Symbolics
1311
@eval NaNMath.$f(x::CasadiSymbolicObject) = Base.$f(x)
1412
end
1513

@@ -76,78 +74,47 @@ function MTK.CasADiDynamicOptProblem(sys::System, u0map, tspan, pmap;
7674
dt = nothing,
7775
steps = nothing,
7876
guesses = Dict(), kwargs...)
79-
MTK.warn_overdetermined(sys, u0map)
80-
_u0map = has_alg_eqs(sys) ? MTK.to_varmap(u0map, unknowns(sys)) :
81-
merge(Dict(u0map), Dict(guesses))
82-
pmap = MTK.to_varmap(pmap, parameters(sys))
83-
f, u0, p = MTK.process_SciMLProblem(ODEInputFunction, sys, merge(_u0map, pmap);
84-
t = tspan !== nothing ? tspan[1] : tspan, output_type = MX, kwargs...)
85-
86-
pmap = MTK.recursive_unwrap(MTK.AnyDict(pmap))
87-
MTK.evaluate_varmap!(pmap, keys(pmap))
88-
steps, is_free_t = MTK.process_tspan(tspan, dt, steps)
89-
model = init_model(sys, tspan, steps, u0map, pmap, u0; is_free_t)
90-
91-
CasADiDynamicOptProblem(f, u0, tspan, p, model, kwargs...)
77+
process_DynamicOptProblem(CasADiDynamicOptProblem, CasADiModel, sys, u0map, tspan, pmap; dt, steps, guesses, kwargs...)
9278
end
9379

9480
MTK.generate_internal_model(::Type{CasADiModel}) = CasADi.opti()
95-
MTK.generate_state_variable(model, u0, ns, nt)
96-
MTK.generate_input_variable(model, c0, nc, nt) = 1
97-
MTK.generate_timescale(model, dims) = 1
9881

99-
function init_model(sys, tspan, steps, u0map, pmap, u0; is_free_t = false)
100-
ctrls = MTK.unbound_inputs(sys)
101-
states = unknowns(sys)
102-
opti = CasADi.Opti()
82+
function MTK.generate_state_variable(model::Opti, u0, ns, nt, tsteps)
83+
U = CasADi.variable!(model, ns, nt)
84+
set_initial!(opti, U, DM(repeat(u0, 1, steps)))
85+
MXLinearInterpolation(U, tsteps, tsteps[2] - tsteps[1])
86+
end
10387

88+
function MTK.generate_input_variable(model::Opti, c0, nc, nt, tsteps)
89+
V = CasADi.variable!(model, nc, nt)
90+
!isempty(c0) && set_initial!(opti, V, DM(repeat(c0, 1, steps)))
91+
MXLinearInterpolation(V, tsteps, tsteps[2] - tsteps[1])
92+
end
93+
94+
function MTK.generate_timescale(model::Opti, guess, is_free_t)
10495
if is_free_t
105-
(ts_sym, te_sym) = tspan
106-
MTK.symbolic_type(ts_sym) !== MTK.NotSymbolic() &&
107-
error("Free initial time problems are not currently supported in CasADiDynamicOptProblem.")
108-
tₛ = variable!(opti)
109-
set_initial!(opti, tₛ, pmap[te_sym])
110-
subject_to!(opti, tₛ >= ts_sym)
111-
hasbounds(te_sym) && begin
112-
lo, hi = getbounds(te_sym)
113-
subject_to!(opti, tₛ >= lo)
114-
subject_to!(opti, tₛ >= hi)
115-
end
116-
pmap[te_sym] = tₛ
117-
tsteps = LinRange(0, 1, steps)
96+
tₛ = variable!(model)
97+
set_initial!(model, tₛ, guess)
98+
subject_to!(model, tₛ >= 0)
99+
tₛ
118100
else
119-
tₛ = MX(1)
120-
tsteps = LinRange(tspan[1], tspan[2], steps)
101+
MX(1)
121102
end
103+
end
122104

123-
U = CasADi.variable!(opti, length(states), steps)
124-
V = CasADi.variable!(opti, length(ctrls), steps)
125-
set_initial!(opti, U, DM(repeat(u0, 1, steps)))
126-
c0 = MTK.value.([pmap[c] for c in ctrls])
127-
!isempty(c0) && set_initial!(opti, V, DM(repeat(c0, 1, steps)))
128-
129-
U_interp = MXLinearInterpolation(U, tsteps, tsteps[2] - tsteps[1])
130-
V_interp = MXLinearInterpolation(V, tsteps, tsteps[2] - tsteps[1])
131-
for (i, ct) in enumerate(ctrls)
132-
pmap[ct] = V[i, :]
105+
function MTK.add_constraint!(model::CasADiModel, expr)
106+
@unpack opti = model
107+
if cons isa Equation
108+
subject_to!(opti, expr.lhs - expr.rhs == 0)
109+
elseif cons.relational_op === Symbolics.geq
110+
subject_to!(opti, expr.lhs - expr.rhs 0)
111+
else
112+
subject_to!(opti, expr.lhs - expr.rhs 0)
133113
end
134-
135-
model = CasADiModel(opti, U_interp, V_interp, tₛ)
136-
137-
set_casadi_bounds!(model, sys, pmap)
138-
add_cost_function!(model, sys, tspan, pmap; is_free_t)
139-
add_user_constraints!(model, sys, tspan, pmap; is_free_t)
140-
141-
stidxmap = Dict([v => i for (i, v) in enumerate(states)])
142-
u0map = Dict([MTK.default_toterm(MTK.value(k)) => v for (k, v) in u0map])
143-
u0_idxs = has_alg_eqs(sys) ? collect(1:length(states)) :
144-
[stidxmap[MTK.default_toterm(k)] for (k, v) in u0map]
145-
add_initial_constraints!(model, u0, u0_idxs)
146-
147-
model
148114
end
115+
MTK.set_objective!(model::CasADiModel, expr) = minimize!(model.opti, MX(expr))
149116

150-
function set_casadi_bounds!(model, sys, pmap)
117+
function MTK.set_variable_bounds!(model, sys, pmap, tf)
151118
@unpack opti, U, V = model
152119
for (i, u) in enumerate(unknowns(sys))
153120
if MTK.hasbounds(u)
@@ -163,36 +130,56 @@ function set_casadi_bounds!(model, sys, pmap)
163130
subject_to!(opti, V.u[i, :] <= Symbolics.fast_substitute(hi, pmap))
164131
end
165132
end
133+
if MTK.symbolic_type(tf) === MTK.ScalarSymbolic() && hasbounds(tf)
134+
lo, hi = MTK.getbounds(tf)
135+
subject_to!(opti, model.tₛ >= lo)
136+
subject_to!(opti, model.tₛ <= hi)
137+
end
166138
end
167139

168-
function add_initial_constraints!(model::CasADiModel, u0, u0_idxs)
140+
function MTK.add_initial_constraints!(model::CasADiModel, u0, u0_idxs)
169141
@unpack opti, U = model
170142
for i in u0_idxs
171143
subject_to!(opti, U.u[i, 1] == u0[i])
172144
end
173145
end
174146

175-
function add_user_constraints!(model::CasADiModel, sys, tspan, pmap; is_free_t)
147+
function MTK.substitute_model_vars(
148+
model::CasADiModel, sys, pmap, exprs; auxmap::Dict = Dict(), is_free_t)
176149
@unpack opti, U, V, tₛ = model
177-
178150
iv = MTK.get_iv(sys)
179-
jconstraints = MTK.get_constraints(sys)
180-
(isnothing(jconstraints) || isempty(jconstraints)) && return nothing
181-
182-
stidxmap = Dict([v => i for (i, v) in enumerate(unknowns(sys))])
183-
ctidxmap = Dict([v => i for (i, v) in enumerate(MTK.unbound_inputs(sys))])
184-
cons_dvs, cons_ps = MTK.process_constraint_system(
185-
jconstraints, Set(unknowns(sys)), parameters(sys), iv; validate = false)
186-
187-
auxmap = Dict([u => MTK.default_toterm(MTK.value(u)) for u in cons_dvs])
188-
jconstraints = substitute_casadi_vars(model, sys, pmap, jconstraints; is_free_t, auxmap)
189-
# Manually substitute fixed-t variables
190-
for (i, cons) in enumerate(jconstraints)
191-
consvars = MTK.vars(cons)
192-
for st in consvars
151+
sts = unknowns(sys)
152+
cts = MTK.unbound_inputs(sys)
153+
154+
x_ops = [MTK.operation(MTK.unwrap(st)) for st in sts]
155+
c_ops = [MTK.operation(MTK.unwrap(ct)) for ct in cts]
156+
157+
exprs = map(c -> Symbolics.fast_substitute(c, auxmap), exprs)
158+
exprs = map(c -> Symbolics.fast_substitute(c, Dict(pmap)), exprs)
159+
# tf means different things in different contexts; a [tf] in a cost function
160+
# should be tₛ, while a x(tf) should translate to x[1]
161+
if is_free_t
162+
free_t_map = Dict([[x(tₛ) => U.u[i, end] for (i, x) in enumerate(x_ops)];
163+
[c(tₛ) => V.u[i, end] for (i, c) in enumerate(c_ops)]])
164+
exprs = map(c -> Symbolics.fast_substitute(c, free_t_map), exprs)
165+
end
166+
167+
exprs = substitute_fixed_t_vars(exprs)
168+
169+
# for variables like x(t)
170+
whole_interval_map = Dict([[v => U.u[i, :] for (i, v) in enumerate(sts)];
171+
[v => V.u[i, :] for (i, v) in enumerate(cts)]])
172+
exprs = map(c -> Symbolics.fast_substitute(c, whole_interval_map), exprs)
173+
exprs
174+
end
175+
176+
function substitute_fixed_t_vars(exprs)
177+
for i in 1:length(exprs)
178+
subvars = MTK.vars(exprs[i])
179+
for st in subvars
193180
MTK.iscall(st) || continue
194-
x = MTK.operation(st)
195-
t = only(MTK.arguments(st))
181+
x = operation(st)
182+
t = only(arguments(st))
196183
MTK.symbolic_type(t) === MTK.NotSymbolic() || continue
197184
if haskey(stidxmap, x(iv))
198185
idx = stidxmap[x(iv)]
@@ -201,52 +188,19 @@ function add_user_constraints!(model::CasADiModel, sys, tspan, pmap; is_free_t)
201188
idx = ctidxmap[x(iv)]
202189
cv = V
203190
end
204-
cons = Symbolics.substitute(cons, Dict(x(t) => cv(t)[idx]))
205-
end
206-
207-
if cons isa Equation
208-
subject_to!(opti, cons.lhs - cons.rhs == 0)
209-
elseif cons.relational_op === Symbolics.geq
210-
subject_to!(opti, cons.lhs - cons.rhs 0)
211-
else
212-
subject_to!(opti, cons.lhs - cons.rhs 0)
191+
exprs[i] = Symbolics.fast_substitute(exprs[i], Dict(x(t) => cv(t)[idx]))
213192
end
193+
jcosts = Symbolics.substitute(jcosts, Dict(x(t) => cv(t)[idx]))
214194
end
215195
end
216196

217-
function add_cost_function!(model::CasADiModel, sys, tspan, pmap; is_free_t)
218-
@unpack opti, U, V, tₛ = model
219-
jcosts = cost(sys)
220-
if Symbolics._iszero(jcosts)
221-
minimize!(opti, MX(0))
222-
return
223-
end
224-
225-
iv = MTK.get_iv(sys)
226-
stidxmap = Dict([v => i for (i, v) in enumerate(unknowns(sys))])
227-
ctidxmap = Dict([v => i for (i, v) in enumerate(MTK.unbound_inputs(sys))])
228-
229-
jcosts = substitute_casadi_vars(model, sys, pmap, [jcosts]; is_free_t)[1]
230-
# Substitute fixed-time variables.
231-
costvars = MTK.vars(jcosts)
232-
for st in costvars
233-
MTK.iscall(st) || continue
234-
x = operation(st)
235-
t = only(arguments(st))
236-
MTK.symbolic_type(t) === MTK.NotSymbolic() || continue
237-
if haskey(stidxmap, x(iv))
238-
idx = stidxmap[x(iv)]
239-
cv = U
240-
else
241-
idx = ctidxmap[x(iv)]
242-
cv = V
243-
end
244-
jcosts = Symbolics.substitute(jcosts, Dict(x(t) => cv(t)[idx]))
245-
end
197+
MTK.substitute_differentials(model::CasADiModel, exprs, args...) = exprs
246198

199+
function MTK.substitute_integral(model::CasADiModel, exprs)
200+
@unpack U, opti = model
247201
dt = U.t[2] - U.t[1]
248202
intmap = Dict()
249-
for int in MTK.collect_applied_operators(jcosts, Symbolics.Integral)
203+
for int in MTK.collect_applied_operators(exprs, Symbolics.Integral)
250204
op = MTK.operation(int)
251205
arg = only(arguments(MTK.value(int)))
252206
lo, hi = (op.domain.domain.left, op.domain.domain.right)
@@ -255,39 +209,11 @@ function add_cost_function!(model::CasADiModel, sys, tspan, pmap; is_free_t)
255209
# Approximate integral as sum.
256210
intmap[int] = dt * tₛ * sum(arg)
257211
end
258-
jcosts = Symbolics.substitute(jcosts, intmap)
259-
jcosts = MTK.value(jcosts)
260-
minimize!(opti, MX(jcosts))
261-
end
262-
263-
function substitute_casadi_vars(
264-
model::CasADiModel, sys, pmap, exprs; auxmap::Dict = Dict(), is_free_t)
265-
@unpack opti, U, V, tₛ = model
266-
iv = MTK.get_iv(sys)
267-
sts = unknowns(sys)
268-
cts = MTK.unbound_inputs(sys)
269-
270-
x_ops = [MTK.operation(MTK.unwrap(st)) for st in sts]
271-
c_ops = [MTK.operation(MTK.unwrap(ct)) for ct in cts]
272-
273-
exprs = map(c -> Symbolics.fast_substitute(c, auxmap), exprs)
274-
exprs = map(c -> Symbolics.fast_substitute(c, Dict(pmap)), exprs)
275-
# tf means different things in different contexts; a [tf] in a cost function
276-
# should be tₛ, while a x(tf) should translate to x[1]
277-
if is_free_t
278-
free_t_map = Dict([[x(tₛ) => U.u[i, end] for (i, x) in enumerate(x_ops)];
279-
[c(tₛ) => V.u[i, end] for (i, c) in enumerate(c_ops)]])
280-
exprs = map(c -> Symbolics.fast_substitute(c, free_t_map), exprs)
281-
end
282-
283-
# for variables like x(t)
284-
whole_interval_map = Dict([[v => U.u[i, :] for (i, v) in enumerate(sts)];
285-
[v => V.u[i, :] for (i, v) in enumerate(cts)]])
286-
exprs = map(c -> Symbolics.fast_substitute(c, whole_interval_map), exprs)
287-
exprs
212+
exprs = map(c -> Symbolics.substitute(c, intmap), exprs)
213+
exprs = MTK.value.(exprs)
288214
end
289215

290-
function add_solve_constraints(prob, tableau)
216+
function add_solve_constraints!(prob, tableau)
291217
@unpack A, α, c = tableau
292218
@unpack model, f, p = prob
293219
@unpack opti, U, V, tₛ = model
@@ -332,57 +258,29 @@ function add_solve_constraints(prob, tableau)
332258
solver_opti
333259
end
334260

335-
"""
336-
solve(prob::CasADiDynamicOptProblem, casadi_solver, ode_solver; plugin_options, solver_options, silent)
337-
338-
`plugin_options` and `solver_options` get propagated to the Opti object in CasADi.
339-
340-
NOTE: the solver should be passed in as a string to CasADi. "ipopt"
341-
"""
342-
function DiffEqBase.solve(
343-
prob::CasADiDynamicOptProblem, solver::Union{String, Symbol} = "ipopt",
344-
tableau_getter = MTK.constructDefault; plugin_options::Dict = Dict(),
345-
solver_options::Dict = Dict(), silent = false)
346-
@unpack model, u0, p, tspan, f = prob
347-
tableau = tableau_getter()
348-
@unpack opti, U, V, tₛ = model
349-
261+
function MTK.prepare_solver()
350262
opti = add_solve_constraints(prob, tableau)
351-
silent && (solver_options["print_level"] = 0)
352263
solver!(opti, "$solver", plugin_options, solver_options)
264+
end
265+
function MTK.get_U_values()
266+
U_vals = value_getter(U.u)
267+
size(U_vals, 2) == 1 && (U_vals = U_vals')
268+
U_vals = [[U_vals[i, j] for i in 1:size(U_vals, 1)] for j in 1:length(ts)]
269+
end
270+
function MTK.get_V_values()
271+
end
272+
function MTK.get_t_values()
273+
ts = value_getter(tₛ) * U.t
274+
end
353275

354-
failed = false
355-
value_getter = nothing
356-
sol = nothing
276+
function MTK.optimize_model!()
357277
try
358278
sol = CasADi.solve!(opti)
359279
value_getter = x -> CasADi.value(sol, x)
360280
catch ErrorException
361281
value_getter = x -> CasADi.debug_value(opti, x)
362282
failed = true
363283
end
364-
365-
ts = value_getter(tₛ) * U.t
366-
U_vals = value_getter(U.u)
367-
size(U_vals, 2) == 1 && (U_vals = U_vals')
368-
U_vals = [[U_vals[i, j] for i in 1:size(U_vals, 1)] for j in 1:length(ts)]
369-
ode_sol = DiffEqBase.build_solution(prob, tableau_getter, ts, U_vals)
370-
371-
input_sol = nothing
372-
if prod(size(V.u)) != 0
373-
V_vals = value_getter(V.u)
374-
size(V_vals, 2) == 1 && (V_vals = V_vals')
375-
V_vals = [[V_vals[i, j] for i in 1:size(V_vals, 1)] for j in 1:length(ts)]
376-
input_sol = DiffEqBase.build_solution(prob, tableau_getter, ts, V_vals)
377-
end
378-
379-
if failed
380-
ode_sol = SciMLBase.solution_new_retcode(
381-
ode_sol, SciMLBase.ReturnCode.ConvergenceFailure)
382-
!isnothing(input_sol) && (input_sol = SciMLBase.solution_new_retcode(
383-
input_sol, SciMLBase.ReturnCode.ConvergenceFailure))
384-
end
385-
386-
DynamicOptSolution(model, ode_sol, input_sol)
387284
end
285+
MTK.successful_solve() = true
388286
end

0 commit comments

Comments
 (0)