Skip to content

Commit 85e3ed6

Browse files
committed
feat: implementation for CasADiProblem
1 parent b1bc160 commit 85e3ed6

File tree

3 files changed

+219
-42
lines changed

3 files changed

+219
-42
lines changed

Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,9 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
6565

6666
[weakdeps]
6767
BifurcationKit = "0f109fa4-8a5d-4b75-95aa-f515264e7665"
68+
CasADi = "c49709b8-5c63-11e9-2fb2-69db5844192f"
6869
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
70+
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
6971
DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6"
7072
DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d"
7173
FMI = "14a09403-18e3-468f-ad8a-74f8dda2d9ac"
@@ -75,6 +77,7 @@ LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
7577

7678
[extensions]
7779
MTKBifurcationKitExt = "BifurcationKit"
80+
MTKCasADiDynamicOptExt = ["CasADi", "DataInterpolations", "DiffEqDevTools"]
7881
MTKChainRulesCoreExt = "ChainRulesCore"
7982
MTKDeepDiffsExt = "DeepDiffs"
8083
MTKFMIExt = "FMI"
@@ -89,6 +92,7 @@ BifurcationKit = "0.4"
8992
BlockArrays = "1.1"
9093
BoundaryValueDiffEqAscher = "1.1.0"
9194
BoundaryValueDiffEqMIRK = "1.4.0"
95+
CasADi = "1.0.0"
9296
ChainRulesCore = "1"
9397
Combinatorics = "1"
9498
CommonSolve = "0.2.4"

ext/MTKCasADiDynamicOptExt.jl

Lines changed: 192 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ using ModelingToolkit
33
using CasADi
44
using DiffEqDevTools, DiffEqBase
55
using DataInterpolations
6+
using UnPack
67
const MTK = MOdelingToolkit
78

89
struct CasADiDynamicOptProblem{uType, tType, isinplace, P, F, K} <:
@@ -22,16 +23,15 @@ end
2223

2324
struct CasADiModel
2425
opti::Opti
25-
U::MX
26-
V::MX
27-
end
28-
29-
struct TimedMX
26+
U::AbstractInterpolation
27+
V::AbstractInterpolation
28+
tₛ::Union{Number, MX}
3029
end
3130

3231
function MTK.CasADiDynamicOptProblem(sys::ODESystem, u0map, tspan, pmap;
3332
dt = nothing,
34-
steps = nothing,
33+
steps = nothing,
34+
interpolation_method::AbstractInterpolation = LinearInterpolation,
3535
guesses = Dict(), kwargs...)
3636
MTK.warn_overdetermined(sys, u0map)
3737
_u0map = has_alg_eqs(sys) ? u0map : merge(Dict(u0map), Dict(guesses))
@@ -43,73 +43,228 @@ function MTK.CasADiDynamicOptProblem(sys::ODESystem, u0map, tspan, pmap;
4343
model = init_model()
4444
end
4545

46-
function init_model(sys, tspan, steps, u0map, pmap, u0; is_free_t)
46+
function init_model(sys, tspan, steps, u0map, pmap, u0; is_free_t = false, interp_type::AbstractInterpolation)
4747
ctrls = MTK.unbound_inputs(sys)
4848
states = unknowns(sys)
49-
model = CasADi.Opti()
49+
opti = CasADi.Opti()
50+
51+
if is_free_t
52+
(ts_sym, te_sym) = tspan
53+
MTK.symbolic_type(ts_sym) !== MTK.NotSymbolic() &&
54+
error("Free initial time problems are not currently supported.")
55+
tₛ = variable!(opti)
56+
tsteps = LinRange(0, 1, steps)
57+
else
58+
tₛ = 1
59+
tsteps = LinRange(tspan[1], tspan[2], steps)
60+
end
5061

51-
U = CasADi.variable!(model, length(states), steps)
52-
V = CasADi.variable!(model, length(ctrls), steps)
62+
U = CasADi.variable!(opti, length(states), steps)
63+
V = CasADi.variable!(opti, length(ctrls), steps)
64+
65+
U_interp = interp_type(Matrix(U), tsteps)
66+
V_interp = interp_type(Matrix(V), tsteps)
67+
68+
CasADiModel(opti, U_interp, V_interp, tₛ)
5369
end
5470

55-
function add_initial_constraints!()
56-
71+
function set_casadi_bounds!(model, sys, pmap)
72+
@unpack opti, U, V = model
73+
for (i, u) in enumerate(unknowns(sys))
74+
if MTK.hasbounds(u)
75+
lo, hi = MTK.getbounds(u)
76+
subject_to!(opti, lo <= U[i, :] <= hi)
77+
end
78+
end
79+
for (i, v) in enumerate(MTK.unbound_inputs(sys))
80+
if MTK.hasbounds(v)
81+
lo, hi = MTK.getbounds(v)
82+
subject_to!(opti, lo <= V[i, :] <= hi)
83+
end
84+
end
85+
end
86+
87+
function add_initial_constraints!(model::CasADiModel, u0, u0_idxs, ts)
88+
@unpack opti, U = model
89+
for i in u0_idxs
90+
subject_to!(opti, U.u[i, 1] == u0[i])
91+
end
5792
end
5893

5994
function add_user_constraints!(model::CasADiModel, sys, pmap; is_free_t = false)
60-
95+
@unpack opti, U, V, tₛ = model
96+
97+
iv = get_iv(sys)
98+
conssys = MTK.get_constraintsystem(sys)
99+
jconstraints = isnothing(conssys) ? nothing : MTK.get_constraints(conssys)
100+
(isnothing(jconstraints) || isempty(jconstraints)) && return nothing
101+
102+
stidxmap = Dict([v => i for (i, v) in enumerate(sts)])
103+
pidxmap = Dict([v => i for (i, v) in enumerate(ps)])
104+
cons_unknowns = map(MTK.default_toterm, unknowns(conssys))
105+
for st in cons_unknowns
106+
x = operation(st)
107+
t = only(argments(st))
108+
idx = stidxmap[x(iv)]
109+
110+
jconstraints = map(c -> Symbolics.substitute(c, Dict(x(t) => U(t)[idx])), jconstraints)
111+
end
112+
jconstraints = substitute_casadi_vars(model, sys, pmap, jconstraints)
113+
114+
for (i, cons) in enumerate(jconstraints)
115+
if cons isa Equation
116+
subject_to!(opti, cons.lhs - cons.rhs==0)
117+
elseif cons.relational_op === Symbolics.geq
118+
subject_to!(model, cons.lhs - cons.rhs0)
119+
else
120+
subject_to!(model, cons.lhs - cons.rhs0)
121+
end
122+
end
61123
end
62124

63-
function add_cost_function!(model)
125+
function add_cost_function!(model::CasADiModel, sys, tspan, pmap)
126+
@unpack opti, U, V, tₛ = model
127+
jcosts = MTK.get_costs(sys)
128+
consolidate = MTK.get_consolidate(sys)
129+
130+
if isnothing(jcosts) || isempty(jcosts)
131+
minimize!(opti, 0)
132+
return
133+
end
134+
stidxmap = Dict([v => i for (i, v) in enumerate(sts)])
135+
pidxmap = Dict([v => i for (i, v) in enumerate(ps)])
64136

137+
for i in 1:length(jcosts)
138+
vars = vars(jcosts[i])
139+
for st in vars
140+
x = operation(st)
141+
t = only(arguments(st))
142+
t isa Union{Num, MTK.Symbolic} && continue
143+
idx = stidxmap[x(iv)]
144+
jcosts[i] = Symbolics.substitute(jcosts[i], Dict(x(t) => U(t)[idx]))
145+
end
146+
end
147+
jcosts = substitute_casadi_vars(model::CasADiModel, sys, pmap, jcosts; auxmap)
148+
jcosts = map(
149+
c -> Symbolics.substitute(c, MTK.() => Symbolics.Integral(iv in tspan)), jcosts)
150+
151+
dt = U.t[2] - U.t[1]
152+
intmap = Dict()
153+
for int in MTK.collect_applied_operators(jcosts, Symbolics.Integral)
154+
op = MTK.operation(int)
155+
arg = only(arguments(MTK.value(int)))
156+
lo, hi = (op.domain.domain.left, op.domain.domain.right)
157+
(lo, hi) !== tspan && error("Non-whole interval bounds for integrals are not currently supported.")
158+
intmap[int] = dt * tₛ * sum(arg)
159+
end
160+
jcosts = map(c -> Symbolics.substitute(c, intmap), jcosts)
161+
minimize!(opti, consolidate(jcosts))
162+
end
163+
164+
function substitute_casadi_vars(model::CasADiModel, sys, pmap, exprs; auxmap = Dict())
165+
@unpack opti, U, V = model
166+
iv = MTK.get_iv(sys)
167+
sts = unknowns(sys)
168+
cts = MTK.unbound_inputs(sys)
169+
170+
x_ops = [MTK.operation(MTK.unwrap(st)) for st in sts]
171+
c_ops = [MTK.operation(MTK.unwrap(ct)) for ct in cts]
172+
173+
exprs = map(c -> Symbolics.fixpoint_sub(c, auxmap), exprs)
174+
exprs = map(c -> Symbolics.fixpoint_sub(c, Dict(pmap)), exprs)
175+
176+
# for variables like x(t)
177+
whole_interval_map = Dict([[v => U.u[i, :] for (i, v) in enumerate(sts)];
178+
[v => V.u[i, :] for (i, v) in enumerate(cts)]])
179+
exprs = map(c -> Symbolics.fixpoint_sub(c, whole_interval_map), exprs)
180+
exprs
65181
end
66182

67183
function add_solve_constraints!(prob, tableau; is_free_t)
68-
A = tableau.A
69-
α = tableau.α
70-
c = tableau.c
71-
model = prob.model
72-
f = prob.f
73-
p = prob.p
184+
@unpack A, α, c = tableau
185+
@unpack model, f, p = prob
186+
@unpack opti, U, V, tₛ = model
74187

75-
opti = model.opti
76-
t = model[:t]
77-
tsteps = supports(t)
78-
tmax = tsteps[end]
79-
pop!(tsteps)
80-
tₛ = is_free_t ? model[:tf] : 1
188+
tsteps = U.t
81189
dt = tsteps[2] - tsteps[1]
82190

83-
U = model.U
84-
V = model.V
85191
nᵤ = length(U)
86192
nᵥ = length(V)
87193

88194
if is_explicit(tableau)
89195
K = Any[]
90-
for τ in tsteps
196+
for k in 1:length(tsteps)-1
91197
for (i, h) in enumerate(c)
92198
ΔU = sum([A[i, j] * K[j] for j in 1:(i - 1)], init = zeros(nᵤ))
93-
Uₙ = [U[i](τ) + ΔU[i] * dt for i in 1:nᵤ]
94-
Vₙ = [V[i](τ) for i in 1:nᵥ]
199+
Uₙ = U.u[:, k] + ΔU*dt
200+
Vₙ = V.u[:, k]
95201
Kₙ = tₛ * f(Uₙ, Vₙ, p, τ + h * dt) # scale the time
96202
push!(K, Kₙ)
97203
end
98204
ΔU = dt * sum([α[i] * K[i] for i in 1:length(α)])
99-
subject_to!(model, U[n](τ) + ΔU[n]==U[n](τ + dt))
100-
empty!(K)
205+
subject_to!(opti, U.u[:, k] + ΔU == U.u[:, k+1])
101206
end
102207
else
208+
ΔU_tot = dt * (K' * α)
209+
for k in 1:length(tsteps)-1
210+
Kᵢ = variable!(opti, length(α), nᵤ)
211+
ΔUs = A * Kᵢ # the stepsize at each stage of the implicit method
212+
for (i, h) in enumerate(c)
213+
ΔU = @view ΔUs[i, :]
214+
Uₙ = U.u[:,k] + ΔU
215+
Vₙ = V.u[:,k]
216+
subject_to!(opti, K[i,:] == tₛ * f(Uₙ, Vₙ, p, τ + h*dt))
217+
end
218+
ΔU_tot = dt*(Kᵢ'*α)
219+
subject_to!(opti, U.u[:, k] + ΔU_tot == U.u[:,k+1])
220+
end
103221
end
104222
end
105223

106-
function DiffEqBase.solve(prob::CasADiDynamicOptProblem, solver::Union{String, Symbol}, ode_solver::Union{String, Symbol}; silent = false)
224+
is_explicit(tableau) = tableau isa DiffEqDevTools.ExplicitRKTableau
225+
226+
"""
227+
solve(prob::CasADiDynamicOptProblem, casadi_solver, ode_solver; plugin_options, solver_options)
228+
229+
`plugin_options` and `solver_options` get propagated to the Opti object in CasADi.
230+
"""
231+
function DiffEqBase.solve(prob::CasADiDynamicOptProblem, solver::Union{String, Symbol}, ode_solver::Union{String, Symbol}; plugin_options::Dict = Dict(), solver_options::Dict = Dict(), silent = false)
107232
model = prob.model
108233
opti = model.opti
109234

110-
solver!(opti, solver)
111-
sol = solve(opti)
112-
DynamicOptSolution(model, sol, input_sol)
235+
solver!(opti, solver, plugin_options, solver_options)
236+
add_casadi_solve_constraints!(prob, tableau)
237+
solver!(cmodel, "$solver", plugin_options, solver_options)
238+
239+
failed = false
240+
try
241+
sol = solve(opti)
242+
value_getter = x -> CasADi.value(sol, x)
243+
catch ErrorException
244+
value_getter = x -> CasADi.debug_value(opti, x)
245+
failed = true
246+
continue
247+
end
248+
249+
ts = value_getter(tₛ) * U.t
250+
U_vals = value_getter(U)
251+
U_vals = [[U_vals[i][j] for i in 1:length(U_vals)] for j in 1:length(ts)]
252+
sol = DiffEqBase.build_solution(prob, ode_solver, ts, U_vals)
253+
254+
input_sol = nothing
255+
if !isempty(V)
256+
V_vals = value_getter(V)
257+
V_vals = [[V_vals[i][j] for i in 1:length(V_vals)] for j in 1:length(ts)]
258+
input_sol = DiffEqBase.build_solution(prob, ode_solver, ts, V_vals)
259+
end
260+
261+
if failed
262+
sol = SciMLBase.solution_new_retcode(sol, SciMLBase.ReturnCode.ConvergenceFailure)
263+
!isnothing(input_sol) && (input_sol = SciMLBase.solution_new_retcode(
264+
input_sol, SciMLBase.ReturnCode.ConvergenceFailure))
265+
end
266+
267+
DynamicOptSolution(cmodel, sol, input_sol)
113268
end
114269

115270
end

test/downstream/dynamic_opt_systems.jl

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,27 @@
1-
function lotkavolterra()
2-
3-
end
1+
function build_lotkavolterra(; with_constraint = false)
2+
@parameters α=1.5 β=1.0 γ=3.0 δ=1.0
3+
@variables x(..) y(..)
4+
t = M.t_nounits
5+
D = M.D_nounits
46

5-
function ()
6-
7+
eqs = [D(x(t)) ~ α * x(t) - β * x(t) * y(t),
8+
D(y(t)) ~ -γ * y(t) + δ * x(t) * y(t)]
9+
10+
tspan = (0.0, 1.0)
11+
parammap ==> 1.5, β => 1.0, γ => 3.0, δ => 1.0]
12+
13+
if with_constraint
14+
constr = [x(0.6) ~ 3.5, x(0.3) ~ 7.0]
15+
guess = [x(t) => 4.0, y(t) => 2.0]
16+
u0map = Pair[]
17+
else
18+
constr = nothing
19+
guess = Pair[]
20+
u0map = [x(t) => 4.0, y(t) => 2.0]
21+
end
22+
23+
@mtkbuild sys = ODESystem(eqs, t; constraints = constr)
24+
sys, u0map, tspan, parammap, guess
725
end
826

927
function rocket_fft()

0 commit comments

Comments
 (0)