Skip to content

Commit 65975d4

Browse files
committed
feat: all problems working
1 parent 341d175 commit 65975d4

File tree

2 files changed

+97
-35
lines changed

2 files changed

+97
-35
lines changed

ext/MTKCasADiDynamicOptExt.jl

Lines changed: 96 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,16 @@ using ModelingToolkit
33
using CasADi
44
using DiffEqBase
55
using UnPack
6+
using NaNMath
67
const MTK = ModelingToolkit
78

9+
# NaNMath
10+
for ff in [acos, log1p, acosh, log2, asin, tan, atanh, cos, log, sin, log10, sqrt]
11+
f = nameof(ff)
12+
# These need to be defined so that JuMP can trace through functions built by Symbolics
13+
@eval NaNMath.$f(x::CasadiSymbolicObject) = Base.$f(x)
14+
end
15+
816
# Default linear interpolation for MX objects, likely to change down the line when we support interpolation with the collocation polynomial.
917
struct MXLinearInterpolation
1018
u::MX
@@ -40,7 +48,11 @@ function (M::MXLinearInterpolation)(τ)
4048
Δ = nt - i + 1
4149

4250
(i > length(M.t) || i < 1) && error("Cannot extrapolate past the tspan.")
43-
M.u[:, i] + Δ*(M.u[:, i + 1] - M.u[:, i])
51+
if i < length(M.t)
52+
M.u[:, i] + Δ*(M.u[:, i + 1] - M.u[:, i])
53+
else
54+
M.u[:, i]
55+
end
4456
end
4557

4658
"""
@@ -83,8 +95,16 @@ function init_model(sys, tspan, steps, u0map, pmap, u0; is_free_t = false)
8395
if is_free_t
8496
(ts_sym, te_sym) = tspan
8597
MTK.symbolic_type(ts_sym) !== MTK.NotSymbolic() &&
86-
error("Free initial time problems are not currently supported.")
98+
error("Free initial time problems are not currently supported in CasADiDynamicOptProblem.")
8799
tₛ = variable!(opti)
100+
set_initial!(opti, tₛ, pmap[te_sym])
101+
subject_to!(opti, tₛ >= ts_sym)
102+
hasbounds(te_sym) && begin
103+
lo, hi = getbounds(te_sym)
104+
subject_to!(opti, tₛ >= lo)
105+
subject_to!(opti, tₛ >= hi)
106+
end
107+
pmap[te_sym] = tₛ
88108
tsteps = LinRange(0, 1, steps)
89109
else
90110
tₛ = MX(1)
@@ -93,14 +113,21 @@ function init_model(sys, tspan, steps, u0map, pmap, u0; is_free_t = false)
93113

94114
U = CasADi.variable!(opti, length(states), steps)
95115
V = CasADi.variable!(opti, length(ctrls), steps)
116+
set_initial!(opti, U, DM(repeat(u0, 1, steps)))
117+
c0 = MTK.value.([pmap[c] for c in ctrls])
118+
set_initial!(opti, V, DM(repeat(c0, 1, steps)))
119+
96120
U_interp = MXLinearInterpolation(U, tsteps, tsteps[2]-tsteps[1])
97121
V_interp = MXLinearInterpolation(V, tsteps, tsteps[2]-tsteps[1])
122+
for (i, ct) in enumerate(ctrls)
123+
pmap[ct] = V[i, :]
124+
end
98125

99126
model = CasADiModel(opti, U_interp, V_interp, tₛ)
100127

101128
set_casadi_bounds!(model, sys, pmap)
102-
add_cost_function!(model, sys, (tspan[1], tspan[2]), pmap)
103-
add_user_constraints!(model, sys, pmap; is_free_t)
129+
add_cost_function!(model, sys, tspan, pmap; is_free_t)
130+
add_user_constraints!(model, sys, tspan, pmap; is_free_t)
104131

105132
stidxmap = Dict([v => i for (i, v) in enumerate(states)])
106133
u0map = Dict([MTK.default_toterm(MTK.value(k)) => v for (k, v) in u0map])
@@ -116,13 +143,15 @@ function set_casadi_bounds!(model, sys, pmap)
116143
for (i, u) in enumerate(unknowns(sys))
117144
if MTK.hasbounds(u)
118145
lo, hi = MTK.getbounds(u)
119-
subject_to!(opti, lo <= U[i, :] <= hi)
146+
subject_to!(opti, Symbolics.fixpoint_sub(lo, pmap) <= U.u[i, :])
147+
subject_to!(opti, U.u[i, :] <= Symbolics.fixpoint_sub(hi, pmap))
120148
end
121149
end
122150
for (i, v) in enumerate(MTK.unbound_inputs(sys))
123151
if MTK.hasbounds(v)
124152
lo, hi = MTK.getbounds(v)
125-
subject_to!(opti, lo <= V[i, :] <= hi)
153+
subject_to!(opti, Symbolics.fixpoint_sub(lo, pmap) <= V.u[i, :])
154+
subject_to!(opti, V.u[i, :] <= Symbolics.fixpoint_sub(hi, pmap))
126155
end
127156
end
128157
end
@@ -134,7 +163,7 @@ function add_initial_constraints!(model::CasADiModel, u0, u0_idxs)
134163
end
135164
end
136165

137-
function add_user_constraints!(model::CasADiModel, sys, pmap; is_free_t = false)
166+
function add_user_constraints!(model::CasADiModel, sys, tspan, pmap; is_free_t)
138167
@unpack opti, U, V, tₛ = model
139168

140169
iv = MTK.get_iv(sys)
@@ -143,18 +172,29 @@ function add_user_constraints!(model::CasADiModel, sys, pmap; is_free_t = false)
143172
(isnothing(jconstraints) || isempty(jconstraints)) && return nothing
144173

145174
stidxmap = Dict([v => i for (i, v) in enumerate(unknowns(sys))])
175+
ctidxmap = Dict([v => i for (i, v) in enumerate(MTK.unbound_inputs(sys))])
146176
cons_unknowns = map(MTK.default_toterm, unknowns(conssys))
147-
for st in cons_unknowns
148-
x = MTK.operation(st)
149-
t = only(MTK.arguments(st))
150-
idx = stidxmap[x(iv)]
151-
@show t
152-
MTK.symbolic_type(t) === MTK.NotSymbolic() || continue
153-
jconstraints = map(c -> Symbolics.substitute(c, Dict(x(t) => U(t)[idx])), jconstraints)
154-
end
155-
jconstraints = substitute_casadi_vars(model, sys, pmap, jconstraints)
156177

178+
auxmap = Dict([u => MTK.default_toterm(MTK.value(u)) for u in unknowns(conssys)])
179+
jconstraints = substitute_casadi_vars(model, sys, pmap, jconstraints; is_free_t, auxmap)
180+
# Manually substitute fixed-t variables
157181
for (i, cons) in enumerate(jconstraints)
182+
consvars = MTK.vars(cons)
183+
for st in consvars
184+
MTK.iscall(st) || continue
185+
x = MTK.operation(st)
186+
t = only(MTK.arguments(st))
187+
MTK.symbolic_type(t) === MTK.NotSymbolic() || continue
188+
if haskey(stidxmap, x(iv))
189+
idx = stidxmap[x(iv)]
190+
cv = U
191+
else
192+
idx = ctidxmap[x(iv)]
193+
cv = V
194+
end
195+
cons = Symbolics.substitute(cons, Dict(x(t) => cv(t)[idx]))
196+
end
197+
158198
if cons isa Equation
159199
subject_to!(opti, cons.lhs - cons.rhs==0)
160200
elseif cons.relational_op === Symbolics.geq
@@ -165,45 +205,56 @@ function add_user_constraints!(model::CasADiModel, sys, pmap; is_free_t = false)
165205
end
166206
end
167207

168-
function add_cost_function!(model::CasADiModel, sys, tspan, pmap)
208+
function add_cost_function!(model::CasADiModel, sys, tspan, pmap; is_free_t)
169209
@unpack opti, U, V, tₛ = model
170-
jcosts = MTK.get_costs(sys)
210+
jcosts = copy(MTK.get_costs(sys))
171211
consolidate = MTK.get_consolidate(sys)
172-
173212
if isnothing(jcosts) || isempty(jcosts)
174213
minimize!(opti, MX(0))
175214
return
176215
end
177-
stidxmap = Dict([v => i for (i, v) in enumerate(sts)])
178-
pidxmap = Dict([v => i for (i, v) in enumerate(ps)])
179216

217+
iv = MTK.get_iv(sys)
218+
stidxmap = Dict([v => i for (i, v) in enumerate(unknowns(sys))])
219+
ctidxmap = Dict([v => i for (i, v) in enumerate(MTK.unbound_inputs(sys))])
220+
221+
jcosts = substitute_casadi_vars(model, sys, pmap, jcosts; is_free_t)
222+
# Substitute fixed-time variables.
180223
for i in 1:length(jcosts)
181-
vars = vars(jcosts[i])
182-
for st in vars
224+
costvars = MTK.vars(jcosts[i])
225+
for st in costvars
226+
MTK.iscall(st) || continue
183227
x = operation(st)
184228
t = only(arguments(st))
185-
t isa Union{Num, MTK.Symbolic} && continue
186-
idx = stidxmap[x(iv)]
187-
jcosts[i] = Symbolics.substitute(jcosts[i], Dict(x(t) => U(t)[idx]))
229+
MTK.symbolic_type(t) === MTK.NotSymbolic() || continue
230+
if haskey(stidxmap, x(iv))
231+
idx = stidxmap[x(iv)]
232+
cv = U
233+
else
234+
idx = ctidxmap[x(iv)]
235+
cv = V
236+
end
237+
jcosts[i] = Symbolics.substitute(jcosts[i], Dict(x(t) => cv(t)[idx]))
188238
end
189239
end
190-
jcosts = substitute_casadi_vars(model::CasADiModel, sys, pmap, jcosts; auxmap)
191240

192241
dt = U.t[2] - U.t[1]
193242
intmap = Dict()
194243
for int in MTK.collect_applied_operators(jcosts, Symbolics.Integral)
195244
op = MTK.operation(int)
196245
arg = only(arguments(MTK.value(int)))
197246
lo, hi = (op.domain.domain.left, op.domain.domain.right)
198-
(lo, hi) !== tspan && error("Non-whole interval bounds for integrals are not currently supported.")
247+
!isequal((lo, hi), tspan) && error("Non-whole interval bounds for integrals are not currently supported for CasADiDynamicOptProblem.")
248+
# Approximate integral as sum.
199249
intmap[int] = dt * tₛ * sum(arg)
200250
end
201251
jcosts = map(c -> Symbolics.substitute(c, intmap), jcosts)
202-
minimize!(opti, consolidate(jcosts))
252+
jcosts = MTK.value.(jcosts)
253+
minimize!(opti, MX(MTK.value(consolidate(jcosts))))
203254
end
204255

205-
function substitute_casadi_vars(model::CasADiModel, sys, pmap, exprs; auxmap = Dict())
206-
@unpack opti, U, V = model
256+
function substitute_casadi_vars(model::CasADiModel, sys, pmap, exprs; auxmap::Dict = Dict(), is_free_t)
257+
@unpack opti, U, V, tₛ = model
207258
iv = MTK.get_iv(sys)
208259
sts = unknowns(sys)
209260
cts = MTK.unbound_inputs(sys)
@@ -213,6 +264,13 @@ function substitute_casadi_vars(model::CasADiModel, sys, pmap, exprs; auxmap = D
213264

214265
exprs = map(c -> Symbolics.fixpoint_sub(c, auxmap), exprs)
215266
exprs = map(c -> Symbolics.fixpoint_sub(c, Dict(pmap)), exprs)
267+
# tf means different things in different contexts; a [tf] in a cost function
268+
# should be tₛ, while a x(tf) should translate to x[1]
269+
if is_free_t
270+
free_t_map = Dict([[x(tₛ) => U.u[i, end] for (i, x) in enumerate(x_ops)];
271+
[c(tₛ) => V.u[i, end] for (i, c) in enumerate(c_ops)]])
272+
exprs = map(c -> Symbolics.fixpoint_sub(c, free_t_map), exprs)
273+
end
216274

217275
# for variables like x(t)
218276
whole_interval_map = Dict([[v => U.u[i, :] for (i, v) in enumerate(sts)];
@@ -221,7 +279,7 @@ function substitute_casadi_vars(model::CasADiModel, sys, pmap, exprs; auxmap = D
221279
exprs
222280
end
223281

224-
function add_solve_constraints(prob, tableau; is_free_t = false)
282+
function add_solve_constraints(prob, tableau)
225283
@unpack A, α, c = tableau
226284
@unpack model, f, p = prob
227285
@unpack opti, U, V, tₛ = model
@@ -283,6 +341,7 @@ function DiffEqBase.solve(prob::CasADiDynamicOptProblem, solver::Union{String, S
283341

284342
failed = false
285343
value_getter = nothing
344+
sol = nothing
286345
try
287346
sol = CasADi.solve!(opti)
288347
value_getter = x -> CasADi.value(sol, x)
@@ -293,22 +352,24 @@ function DiffEqBase.solve(prob::CasADiDynamicOptProblem, solver::Union{String, S
293352

294353
ts = value_getter(tₛ) * U.t
295354
U_vals = value_getter(U.u)
355+
size(U_vals, 2) == 1 && (U_vals = U_vals')
296356
U_vals = [[U_vals[i, j] for i in 1:size(U_vals, 1)] for j in 1:length(ts)]
297-
sol = DiffEqBase.build_solution(prob, tableau_getter, ts, U_vals)
357+
ode_sol = DiffEqBase.build_solution(prob, tableau_getter, ts, U_vals)
298358

299359
input_sol = nothing
300360
if prod(size(V.u)) != 0
301361
V_vals = value_getter(V.u)
362+
size(V_vals, 2) == 1 && (V_vals = V_vals')
302363
V_vals = [[V_vals[i, j] for i in 1:size(V_vals, 1)] for j in 1:length(ts)]
303364
input_sol = DiffEqBase.build_solution(prob, tableau_getter, ts, V_vals)
304365
end
305366

306367
if failed
307-
sol = SciMLBase.solution_new_retcode(sol, SciMLBase.ReturnCode.ConvergenceFailure)
368+
ode_sol = SciMLBase.solution_new_retcode(ode_sol, SciMLBase.ReturnCode.ConvergenceFailure)
308369
!isnothing(input_sol) && (input_sol = SciMLBase.solution_new_retcode(
309370
input_sol, SciMLBase.ReturnCode.ConvergenceFailure))
310371
end
311372

312-
DynamicOptSolution(model, sol, input_sol)
373+
DynamicOptSolution(model, ode_sol, input_sol)
313374
end
314375
end

ext/MTKInfiniteOptExt.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ function add_jump_cost_function!(model::InfiniteModel, sys, tspan, pmap; is_free
173173
return
174174
end
175175
jcosts = substitute_jump_vars(model, sys, pmap, jcosts; is_free_t)
176+
@show jcosts
176177
tₛ = is_free_t ? model[:tf] : 1
177178

178179
# Substitute integral

0 commit comments

Comments
 (0)