Skip to content

Commit c32f08f

Browse files
committed
feat: all problems working
1 parent ed33790 commit c32f08f

File tree

3 files changed

+127
-35
lines changed

3 files changed

+127
-35
lines changed

ext/MTKCasADiDynamicOptExt.jl

Lines changed: 96 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,16 @@ using ModelingToolkit
33
using CasADi
44
using DiffEqBase
55
using UnPack
6+
using NaNMath
67
const MTK = ModelingToolkit
78

9+
# NaNMath
10+
for ff in [acos, log1p, acosh, log2, asin, tan, atanh, cos, log, sin, log10, sqrt]
11+
f = nameof(ff)
12+
# These need to be defined so that JuMP can trace through functions built by Symbolics
13+
@eval NaNMath.$f(x::CasadiSymbolicObject) = Base.$f(x)
14+
end
15+
816
# Default linear interpolation for MX objects, likely to change down the line when we support interpolation with the collocation polynomial.
917
struct MXLinearInterpolation
1018
u::MX
@@ -40,7 +48,11 @@ function (M::MXLinearInterpolation)(τ)
4048
Δ = nt - i + 1
4149

4250
(i > length(M.t) || i < 1) && error("Cannot extrapolate past the tspan.")
43-
M.u[:, i] + Δ*(M.u[:, i + 1] - M.u[:, i])
51+
if i < length(M.t)
52+
M.u[:, i] + Δ*(M.u[:, i + 1] - M.u[:, i])
53+
else
54+
M.u[:, i]
55+
end
4456
end
4557

4658
"""
@@ -83,8 +95,16 @@ function init_model(sys, tspan, steps, u0map, pmap, u0; is_free_t = false)
8395
if is_free_t
8496
(ts_sym, te_sym) = tspan
8597
MTK.symbolic_type(ts_sym) !== MTK.NotSymbolic() &&
86-
error("Free initial time problems are not currently supported.")
98+
error("Free initial time problems are not currently supported in CasADiDynamicOptProblem.")
8799
tₛ = variable!(opti)
100+
set_initial!(opti, tₛ, pmap[te_sym])
101+
subject_to!(opti, tₛ >= ts_sym)
102+
hasbounds(te_sym) && begin
103+
lo, hi = getbounds(te_sym)
104+
subject_to!(opti, tₛ >= lo)
105+
subject_to!(opti, tₛ >= hi)
106+
end
107+
pmap[te_sym] = tₛ
88108
tsteps = LinRange(0, 1, steps)
89109
else
90110
tₛ = MX(1)
@@ -93,14 +113,21 @@ function init_model(sys, tspan, steps, u0map, pmap, u0; is_free_t = false)
93113

94114
U = CasADi.variable!(opti, length(states), steps)
95115
V = CasADi.variable!(opti, length(ctrls), steps)
116+
set_initial!(opti, U, DM(repeat(u0, 1, steps)))
117+
c0 = MTK.value.([pmap[c] for c in ctrls])
118+
set_initial!(opti, V, DM(repeat(c0, 1, steps)))
119+
96120
U_interp = MXLinearInterpolation(U, tsteps, tsteps[2]-tsteps[1])
97121
V_interp = MXLinearInterpolation(V, tsteps, tsteps[2]-tsteps[1])
122+
for (i, ct) in enumerate(ctrls)
123+
pmap[ct] = V[i, :]
124+
end
98125

99126
model = CasADiModel(opti, U_interp, V_interp, tₛ)
100127

101128
set_casadi_bounds!(model, sys, pmap)
102-
add_cost_function!(model, sys, (tspan[1], tspan[2]), pmap)
103-
add_user_constraints!(model, sys, pmap; is_free_t)
129+
add_cost_function!(model, sys, tspan, pmap; is_free_t)
130+
add_user_constraints!(model, sys, tspan, pmap; is_free_t)
104131

105132
stidxmap = Dict([v => i for (i, v) in enumerate(states)])
106133
u0map = Dict([MTK.default_toterm(MTK.value(k)) => v for (k, v) in u0map])
@@ -116,13 +143,15 @@ function set_casadi_bounds!(model, sys, pmap)
116143
for (i, u) in enumerate(unknowns(sys))
117144
if MTK.hasbounds(u)
118145
lo, hi = MTK.getbounds(u)
119-
subject_to!(opti, lo <= U[i, :] <= hi)
146+
subject_to!(opti, Symbolics.fixpoint_sub(lo, pmap) <= U.u[i, :])
147+
subject_to!(opti, U.u[i, :] <= Symbolics.fixpoint_sub(hi, pmap))
120148
end
121149
end
122150
for (i, v) in enumerate(MTK.unbound_inputs(sys))
123151
if MTK.hasbounds(v)
124152
lo, hi = MTK.getbounds(v)
125-
subject_to!(opti, lo <= V[i, :] <= hi)
153+
subject_to!(opti, Symbolics.fixpoint_sub(lo, pmap) <= V.u[i, :])
154+
subject_to!(opti, V.u[i, :] <= Symbolics.fixpoint_sub(hi, pmap))
126155
end
127156
end
128157
end
@@ -134,7 +163,7 @@ function add_initial_constraints!(model::CasADiModel, u0, u0_idxs)
134163
end
135164
end
136165

137-
function add_user_constraints!(model::CasADiModel, sys, pmap; is_free_t = false)
166+
function add_user_constraints!(model::CasADiModel, sys, tspan, pmap; is_free_t)
138167
@unpack opti, U, V, tₛ = model
139168

140169
iv = MTK.get_iv(sys)
@@ -143,18 +172,29 @@ function add_user_constraints!(model::CasADiModel, sys, pmap; is_free_t = false)
143172
(isnothing(jconstraints) || isempty(jconstraints)) && return nothing
144173

145174
stidxmap = Dict([v => i for (i, v) in enumerate(unknowns(sys))])
175+
ctidxmap = Dict([v => i for (i, v) in enumerate(MTK.unbound_inputs(sys))])
146176
cons_unknowns = map(MTK.default_toterm, unknowns(conssys))
147-
for st in cons_unknowns
148-
x = MTK.operation(st)
149-
t = only(MTK.arguments(st))
150-
idx = stidxmap[x(iv)]
151-
@show t
152-
MTK.symbolic_type(t) === MTK.NotSymbolic() || continue
153-
jconstraints = map(c -> Symbolics.substitute(c, Dict(x(t) => U(t)[idx])), jconstraints)
154-
end
155-
jconstraints = substitute_casadi_vars(model, sys, pmap, jconstraints)
156177

178+
auxmap = Dict([u => MTK.default_toterm(MTK.value(u)) for u in unknowns(conssys)])
179+
jconstraints = substitute_casadi_vars(model, sys, pmap, jconstraints; is_free_t, auxmap)
180+
# Manually substitute fixed-t variables
157181
for (i, cons) in enumerate(jconstraints)
182+
consvars = MTK.vars(cons)
183+
for st in consvars
184+
MTK.iscall(st) || continue
185+
x = MTK.operation(st)
186+
t = only(MTK.arguments(st))
187+
MTK.symbolic_type(t) === MTK.NotSymbolic() || continue
188+
if haskey(stidxmap, x(iv))
189+
idx = stidxmap[x(iv)]
190+
cv = U
191+
else
192+
idx = ctidxmap[x(iv)]
193+
cv = V
194+
end
195+
cons = Symbolics.substitute(cons, Dict(x(t) => cv(t)[idx]))
196+
end
197+
158198
if cons isa Equation
159199
subject_to!(opti, cons.lhs - cons.rhs==0)
160200
elseif cons.relational_op === Symbolics.geq
@@ -165,45 +205,56 @@ function add_user_constraints!(model::CasADiModel, sys, pmap; is_free_t = false)
165205
end
166206
end
167207

168-
function add_cost_function!(model::CasADiModel, sys, tspan, pmap)
208+
function add_cost_function!(model::CasADiModel, sys, tspan, pmap; is_free_t)
169209
@unpack opti, U, V, tₛ = model
170-
jcosts = MTK.get_costs(sys)
210+
jcosts = copy(MTK.get_costs(sys))
171211
consolidate = MTK.get_consolidate(sys)
172-
173212
if isnothing(jcosts) || isempty(jcosts)
174213
minimize!(opti, MX(0))
175214
return
176215
end
177-
stidxmap = Dict([v => i for (i, v) in enumerate(sts)])
178-
pidxmap = Dict([v => i for (i, v) in enumerate(ps)])
179216

217+
iv = MTK.get_iv(sys)
218+
stidxmap = Dict([v => i for (i, v) in enumerate(unknowns(sys))])
219+
ctidxmap = Dict([v => i for (i, v) in enumerate(MTK.unbound_inputs(sys))])
220+
221+
jcosts = substitute_casadi_vars(model, sys, pmap, jcosts; is_free_t)
222+
# Substitute fixed-time variables.
180223
for i in 1:length(jcosts)
181-
vars = vars(jcosts[i])
182-
for st in vars
224+
costvars = MTK.vars(jcosts[i])
225+
for st in costvars
226+
MTK.iscall(st) || continue
183227
x = operation(st)
184228
t = only(arguments(st))
185-
t isa Union{Num, MTK.Symbolic} && continue
186-
idx = stidxmap[x(iv)]
187-
jcosts[i] = Symbolics.substitute(jcosts[i], Dict(x(t) => U(t)[idx]))
229+
MTK.symbolic_type(t) === MTK.NotSymbolic() || continue
230+
if haskey(stidxmap, x(iv))
231+
idx = stidxmap[x(iv)]
232+
cv = U
233+
else
234+
idx = ctidxmap[x(iv)]
235+
cv = V
236+
end
237+
jcosts[i] = Symbolics.substitute(jcosts[i], Dict(x(t) => cv(t)[idx]))
188238
end
189239
end
190-
jcosts = substitute_casadi_vars(model::CasADiModel, sys, pmap, jcosts; auxmap)
191240

192241
dt = U.t[2] - U.t[1]
193242
intmap = Dict()
194243
for int in MTK.collect_applied_operators(jcosts, Symbolics.Integral)
195244
op = MTK.operation(int)
196245
arg = only(arguments(MTK.value(int)))
197246
lo, hi = (op.domain.domain.left, op.domain.domain.right)
198-
(lo, hi) !== tspan && error("Non-whole interval bounds for integrals are not currently supported.")
247+
!isequal((lo, hi), tspan) && error("Non-whole interval bounds for integrals are not currently supported for CasADiDynamicOptProblem.")
248+
# Approximate integral as sum.
199249
intmap[int] = dt * tₛ * sum(arg)
200250
end
201251
jcosts = map(c -> Symbolics.substitute(c, intmap), jcosts)
202-
minimize!(opti, consolidate(jcosts))
252+
jcosts = MTK.value.(jcosts)
253+
minimize!(opti, MX(MTK.value(consolidate(jcosts))))
203254
end
204255

205-
function substitute_casadi_vars(model::CasADiModel, sys, pmap, exprs; auxmap = Dict())
206-
@unpack opti, U, V = model
256+
function substitute_casadi_vars(model::CasADiModel, sys, pmap, exprs; auxmap::Dict = Dict(), is_free_t)
257+
@unpack opti, U, V, tₛ = model
207258
iv = MTK.get_iv(sys)
208259
sts = unknowns(sys)
209260
cts = MTK.unbound_inputs(sys)
@@ -213,6 +264,13 @@ function substitute_casadi_vars(model::CasADiModel, sys, pmap, exprs; auxmap = D
213264

214265
exprs = map(c -> Symbolics.fixpoint_sub(c, auxmap), exprs)
215266
exprs = map(c -> Symbolics.fixpoint_sub(c, Dict(pmap)), exprs)
267+
# tf means different things in different contexts; a [tf] in a cost function
268+
# should be tₛ, while a x(tf) should translate to x[1]
269+
if is_free_t
270+
free_t_map = Dict([[x(tₛ) => U.u[i, end] for (i, x) in enumerate(x_ops)];
271+
[c(tₛ) => V.u[i, end] for (i, c) in enumerate(c_ops)]])
272+
exprs = map(c -> Symbolics.fixpoint_sub(c, free_t_map), exprs)
273+
end
216274

217275
# for variables like x(t)
218276
whole_interval_map = Dict([[v => U.u[i, :] for (i, v) in enumerate(sts)];
@@ -221,7 +279,7 @@ function substitute_casadi_vars(model::CasADiModel, sys, pmap, exprs; auxmap = D
221279
exprs
222280
end
223281

224-
function add_solve_constraints(prob, tableau; is_free_t = false)
282+
function add_solve_constraints(prob, tableau)
225283
@unpack A, α, c = tableau
226284
@unpack model, f, p = prob
227285
@unpack opti, U, V, tₛ = model
@@ -283,6 +341,7 @@ function DiffEqBase.solve(prob::CasADiDynamicOptProblem, solver::Union{String, S
283341

284342
failed = false
285343
value_getter = nothing
344+
sol = nothing
286345
try
287346
sol = CasADi.solve!(opti)
288347
value_getter = x -> CasADi.value(sol, x)
@@ -293,22 +352,24 @@ function DiffEqBase.solve(prob::CasADiDynamicOptProblem, solver::Union{String, S
293352

294353
ts = value_getter(tₛ) * U.t
295354
U_vals = value_getter(U.u)
355+
size(U_vals, 2) == 1 && (U_vals = U_vals')
296356
U_vals = [[U_vals[i, j] for i in 1:size(U_vals, 1)] for j in 1:length(ts)]
297-
sol = DiffEqBase.build_solution(prob, tableau_getter, ts, U_vals)
357+
ode_sol = DiffEqBase.build_solution(prob, tableau_getter, ts, U_vals)
298358

299359
input_sol = nothing
300360
if prod(size(V.u)) != 0
301361
V_vals = value_getter(V.u)
362+
size(V_vals, 2) == 1 && (V_vals = V_vals')
302363
V_vals = [[V_vals[i, j] for i in 1:size(V_vals, 1)] for j in 1:length(ts)]
303364
input_sol = DiffEqBase.build_solution(prob, tableau_getter, ts, V_vals)
304365
end
305366

306367
if failed
307-
sol = SciMLBase.solution_new_retcode(sol, SciMLBase.ReturnCode.ConvergenceFailure)
368+
ode_sol = SciMLBase.solution_new_retcode(ode_sol, SciMLBase.ReturnCode.ConvergenceFailure)
308369
!isnothing(input_sol) && (input_sol = SciMLBase.solution_new_retcode(
309370
input_sol, SciMLBase.ReturnCode.ConvergenceFailure))
310371
end
311372

312-
DynamicOptSolution(model, sol, input_sol)
373+
DynamicOptSolution(model, ode_sol, input_sol)
313374
end
314375
end

ext/MTKInfiniteOptExt.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ function add_jump_cost_function!(model::InfiniteModel, sys, tspan, pmap; is_free
173173
return
174174
end
175175
jcosts = substitute_jump_vars(model, sys, pmap, jcosts; is_free_t)
176+
@show jcosts
176177
tₛ = is_free_t ? model[:tf] : 1
177178

178179
# Substitute integral

test/downstream/jump_control.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,13 +136,22 @@ end
136136
@test is_bangbang(jsol.input_sol, [-1.0], [1.0])
137137
# Test reached final position.
138138
@test (jsol.sol.u[end][2], 0.25, rtol = 1e-5)
139+
140+
cprob = CasADiDynamicOptProblem(block, u0map, tspan, parammap; dt = 0.01)
141+
csol = solve(cprob, "ipopt", constructVerner8, silent = true)
142+
# Linear systems have bang-bang controls
143+
@test is_bangbang(csol.input_sol, [-1.0], [1.0])
144+
# Test reached final position.
145+
@test (csol.sol.u[end][2], 0.25, rtol = 1e-5)
146+
139147
# Test dynamics
140148
@parameters (u_interp::ConstantInterpolation)(..)
141149
@mtkbuild block_ode = ODESystem([D(x(t)) ~ v(t), D(v(t)) ~ u_interp(t)], t)
142150
spline = ctrl_to_spline(jsol.input_sol, ConstantInterpolation)
143151
oprob = ODEProblem(block_ode, u0map, tspan, [u_interp => spline])
144152
osol = solve(oprob, Vern8(), dt = 0.01, adaptive = false)
145153
@test (jsol.sol.u, osol.u, rtol = 0.05)
154+
@test (csol.sol.u, osol.u, rtol = 0.05)
146155

147156
iprob = InfiniteOptDynamicOptProblem(block, u0map, tspan, parammap; dt = 0.01)
148157
isol = solve(iprob, Ipopt.Optimizer; silent = true)
@@ -174,6 +183,9 @@ end
174183
iprob = InfiniteOptDynamicOptProblem(beesys, u0map, tspan, pmap, dt = 0.01)
175184
isol = solve(iprob, Ipopt.Optimizer; silent = true)
176185
@test is_bangbang(isol.input_sol, [0.0], [1.0])
186+
cprob = CasADiDynamicOptProblem(beesys, u0map, tspan, pmap; dt = 0.01)
187+
csol = solve(cprob, "ipopt", constructTsitouras5, silent = true)
188+
@test is_bangbang(csol.input_sol, [0.0], [1.0])
177189

178190
@parameters (α_interp::LinearInterpolation)(..)
179191
eqs = [D(w(t)) ~ -μ * w(t) + b * s * α_interp(t) * w(t),
@@ -186,6 +198,7 @@ end
186198
Dict(α_interp => ctrl_to_spline(jsol.input_sol, LinearInterpolation))))
187199
osol = solve(oprob, Tsit5(); dt = 0.01, adaptive = false)
188200
@test (osol.u, jsol.sol.u, rtol = 0.01)
201+
@test (osol.u, csol.sol.u, rtol = 0.01)
189202
osol2 = solve(oprob, ImplicitEuler(); dt = 0.01, adaptive = false)
190203
@test (osol2.u, isol.sol.u, rtol = 0.01)
191204
end
@@ -216,6 +229,10 @@ end
216229
jprob = JuMPDynamicOptProblem(rocket, u0map, (ts, te), pmap; dt = 0.001, cse = false)
217230
jsol = solve(jprob, Ipopt.Optimizer, constructRadauIIA5, silent = true)
218231
@test jsol.sol.u[end][1] > 1.012
232+
233+
cprob = CasADiDynamicOptProblem(rocket, u0map, (ts, te), pmap; dt = 0.001, cse = false)
234+
csol = solve(cprob, "ipopt"; silent = true)
235+
@test csol.sol.u[end][1] > 1.012
219236

220237
iprob = InfiniteOptDynamicOptProblem(
221238
rocket, u0map, (ts, te), pmap; dt = 0.001)
@@ -232,6 +249,7 @@ end
232249
oprob = ODEProblem(rocket_ode, u0map, (ts, te), merge(Dict(pmap), interpmap))
233250
osol = solve(oprob, RadauIIA5(); adaptive = false, dt = 0.001)
234251
@test (jsol.sol.u, osol.u, rtol = 0.02)
252+
@test (csol.sol.u, osol.u, rtol = 0.02)
235253

236254
interpmap1 = Dict(T_interp => ctrl_to_spline(isol.input_sol, CubicSpline))
237255
oprob1 = ODEProblem(rocket_ode, u0map, (ts, te), merge(Dict(pmap), interpmap1))
@@ -258,6 +276,10 @@ end
258276
jsol = solve(jprob, Ipopt.Optimizer, constructTsitouras5, silent = true)
259277
@test isapprox(jsol.sol.t[end], 10.0, rtol = 1e-3)
260278

279+
cprob = CasADiDynamicOptProblem(rocket, u0map, (0, tf), pmap; steps = 201)
280+
csol = solve(cprob, "ipopt", constructTsitouras5, silent = true)
281+
@test isapprox(csol.sol.t[end], 10.0, rtol = 1e-3)
282+
261283
iprob = InfiniteOptDynamicOptProblem(rocket, u0map, (0, tf), pmap; steps = 200)
262284
isol = solve(iprob, Ipopt.Optimizer, silent = true)
263285
@test isapprox(isol.sol.t[end], 10.0, rtol = 1e-3)
@@ -279,6 +301,10 @@ end
279301
jsol = solve(jprob, Ipopt.Optimizer, constructVerner8, silent = true)
280302
@test isapprox(jsol.sol.t[end], 2.0, atol = 1e-5)
281303

304+
cprob = CasADiDynamicOptProblem(block, u0map, (0, tf), parammap; steps = 51)
305+
csol = solve(cprob, "ipopt", constructVerner8, silent = true)
306+
@test isapprox(csol.sol.t[end], 2.0, atol = 1e-5)
307+
282308
iprob = InfiniteOptDynamicOptProblem(block, u0map, tspan, parammap; steps = 51)
283309
isol = solve(iprob, Ipopt.Optimizer, silent = true)
284310
@test isapprox(isol.sol.t[end], 2.0, atol = 1e-5)
@@ -317,6 +343,10 @@ end
317343
jsol = solve(jprob, Ipopt.Optimizer, constructRK4, silent = true)
318344
@test jsol.sol.u[end] [π, 0, 0, 0]
319345

346+
cprob = CasADiDynamicOptProblem(cartpole, u0map, tspan, pmap; dt = 0.04)
347+
csol = solve(cprob, "ipopt", constructRK4, silent = true)
348+
@test csol.sol.u[end] [π, 0, 0, 0]
349+
320350
iprob = InfiniteOptDynamicOptProblem(cartpole, u0map, tspan, pmap; dt = 0.04)
321351
isol = solve(iprob, Ipopt.Optimizer, silent = true)
322352
@test isol.sol.u[end] [π, 0, 0, 0]

0 commit comments

Comments
 (0)