@@ -3,8 +3,16 @@ using ModelingToolkit
33using CasADi
44using DiffEqBase
55using UnPack
6+ using NaNMath
67const MTK = ModelingToolkit
78
9+ # NaNMath
10+ for ff in [acos, log1p, acosh, log2, asin, tan, atanh, cos, log, sin, log10, sqrt]
11+ f = nameof (ff)
12+ # These need to be defined so that JuMP can trace through functions built by Symbolics
13+ @eval NaNMath.$ f (x:: CasadiSymbolicObject ) = Base.$ f (x)
14+ end
15+
816# Default linear interpolation for MX objects, likely to change down the line when we support interpolation with the collocation polynomial.
917struct MXLinearInterpolation
1018 u:: MX
@@ -40,7 +48,11 @@ function (M::MXLinearInterpolation)(τ)
4048 Δ = nt - i + 1
4149
4250 (i > length (M. t) || i < 1 ) && error (" Cannot extrapolate past the tspan." )
43- M. u[:, i] + Δ* (M. u[:, i + 1 ] - M. u[:, i])
51+ if i < length (M. t)
52+ M. u[:, i] + Δ* (M. u[:, i + 1 ] - M. u[:, i])
53+ else
54+ M. u[:, i]
55+ end
4456end
4557
4658"""
@@ -83,8 +95,16 @@ function init_model(sys, tspan, steps, u0map, pmap, u0; is_free_t = false)
8395 if is_free_t
8496 (ts_sym, te_sym) = tspan
8597 MTK. symbolic_type (ts_sym) != = MTK. NotSymbolic () &&
86- error (" Free initial time problems are not currently supported." )
98+ error (" Free initial time problems are not currently supported in CasADiDynamicOptProblem ." )
8799 tₛ = variable! (opti)
100+ set_initial! (opti, tₛ, pmap[te_sym])
101+ subject_to! (opti, tₛ >= ts_sym)
102+ hasbounds (te_sym) && begin
103+ lo, hi = getbounds (te_sym)
104+ subject_to! (opti, tₛ >= lo)
105+ subject_to! (opti, tₛ >= hi)
106+ end
107+ pmap[te_sym] = tₛ
88108 tsteps = LinRange (0 , 1 , steps)
89109 else
90110 tₛ = MX (1 )
@@ -93,14 +113,21 @@ function init_model(sys, tspan, steps, u0map, pmap, u0; is_free_t = false)
93113
94114 U = CasADi. variable! (opti, length (states), steps)
95115 V = CasADi. variable! (opti, length (ctrls), steps)
116+ set_initial! (opti, U, DM (repeat (u0, 1 , steps)))
117+ c0 = MTK. value .([pmap[c] for c in ctrls])
118+ set_initial! (opti, V, DM (repeat (c0, 1 , steps)))
119+
96120 U_interp = MXLinearInterpolation (U, tsteps, tsteps[2 ]- tsteps[1 ])
97121 V_interp = MXLinearInterpolation (V, tsteps, tsteps[2 ]- tsteps[1 ])
122+ for (i, ct) in enumerate (ctrls)
123+ pmap[ct] = V[i, :]
124+ end
98125
99126 model = CasADiModel (opti, U_interp, V_interp, tₛ)
100127
101128 set_casadi_bounds! (model, sys, pmap)
102- add_cost_function! (model, sys, ( tspan[ 1 ], tspan[ 2 ]), pmap)
103- add_user_constraints! (model, sys, pmap; is_free_t)
129+ add_cost_function! (model, sys, tspan, pmap; is_free_t )
130+ add_user_constraints! (model, sys, tspan, pmap; is_free_t)
104131
105132 stidxmap = Dict ([v => i for (i, v) in enumerate (states)])
106133 u0map = Dict ([MTK. default_toterm (MTK. value (k)) => v for (k, v) in u0map])
@@ -116,13 +143,15 @@ function set_casadi_bounds!(model, sys, pmap)
116143 for (i, u) in enumerate (unknowns (sys))
117144 if MTK. hasbounds (u)
118145 lo, hi = MTK. getbounds (u)
119- subject_to! (opti, lo <= U[i, :] <= hi)
146+ subject_to! (opti, Symbolics. fixpoint_sub (lo, pmap) <= U. u[i, :])
147+ subject_to! (opti, U. u[i, :] <= Symbolics. fixpoint_sub (hi, pmap))
120148 end
121149 end
122150 for (i, v) in enumerate (MTK. unbound_inputs (sys))
123151 if MTK. hasbounds (v)
124152 lo, hi = MTK. getbounds (v)
125- subject_to! (opti, lo <= V[i, :] <= hi)
153+ subject_to! (opti, Symbolics. fixpoint_sub (lo, pmap) <= V. u[i, :])
154+ subject_to! (opti, V. u[i, :] <= Symbolics. fixpoint_sub (hi, pmap))
126155 end
127156 end
128157end
@@ -134,7 +163,7 @@ function add_initial_constraints!(model::CasADiModel, u0, u0_idxs)
134163 end
135164end
136165
137- function add_user_constraints! (model:: CasADiModel , sys, pmap; is_free_t = false )
166+ function add_user_constraints! (model:: CasADiModel , sys, tspan, pmap; is_free_t)
138167 @unpack opti, U, V, tₛ = model
139168
140169 iv = MTK. get_iv (sys)
@@ -143,18 +172,29 @@ function add_user_constraints!(model::CasADiModel, sys, pmap; is_free_t = false)
143172 (isnothing (jconstraints) || isempty (jconstraints)) && return nothing
144173
145174 stidxmap = Dict ([v => i for (i, v) in enumerate (unknowns (sys))])
175+ ctidxmap = Dict ([v => i for (i, v) in enumerate (MTK. unbound_inputs (sys))])
146176 cons_unknowns = map (MTK. default_toterm, unknowns (conssys))
147- for st in cons_unknowns
148- x = MTK. operation (st)
149- t = only (MTK. arguments (st))
150- idx = stidxmap[x (iv)]
151- @show t
152- MTK. symbolic_type (t) === MTK. NotSymbolic () || continue
153- jconstraints = map (c -> Symbolics. substitute (c, Dict (x (t) => U (t)[idx])), jconstraints)
154- end
155- jconstraints = substitute_casadi_vars (model, sys, pmap, jconstraints)
156177
178+ auxmap = Dict ([u => MTK. default_toterm (MTK. value (u)) for u in unknowns (conssys)])
179+ jconstraints = substitute_casadi_vars (model, sys, pmap, jconstraints; is_free_t, auxmap)
180+ # Manually substitute fixed-t variables
157181 for (i, cons) in enumerate (jconstraints)
182+ consvars = MTK. vars (cons)
183+ for st in consvars
184+ MTK. iscall (st) || continue
185+ x = MTK. operation (st)
186+ t = only (MTK. arguments (st))
187+ MTK. symbolic_type (t) === MTK. NotSymbolic () || continue
188+ if haskey (stidxmap, x (iv))
189+ idx = stidxmap[x (iv)]
190+ cv = U
191+ else
192+ idx = ctidxmap[x (iv)]
193+ cv = V
194+ end
195+ cons = Symbolics. substitute (cons, Dict (x (t) => cv (t)[idx]))
196+ end
197+
158198 if cons isa Equation
159199 subject_to! (opti, cons. lhs - cons. rhs== 0 )
160200 elseif cons. relational_op === Symbolics. geq
@@ -165,45 +205,56 @@ function add_user_constraints!(model::CasADiModel, sys, pmap; is_free_t = false)
165205 end
166206end
167207
168- function add_cost_function! (model:: CasADiModel , sys, tspan, pmap)
208+ function add_cost_function! (model:: CasADiModel , sys, tspan, pmap; is_free_t )
169209 @unpack opti, U, V, tₛ = model
170- jcosts = MTK. get_costs (sys)
210+ jcosts = copy ( MTK. get_costs (sys) )
171211 consolidate = MTK. get_consolidate (sys)
172-
173212 if isnothing (jcosts) || isempty (jcosts)
174213 minimize! (opti, MX (0 ))
175214 return
176215 end
177- stidxmap = Dict ([v => i for (i, v) in enumerate (sts)])
178- pidxmap = Dict ([v => i for (i, v) in enumerate (ps)])
179216
217+ iv = MTK. get_iv (sys)
218+ stidxmap = Dict ([v => i for (i, v) in enumerate (unknowns (sys))])
219+ ctidxmap = Dict ([v => i for (i, v) in enumerate (MTK. unbound_inputs (sys))])
220+
221+ jcosts = substitute_casadi_vars (model, sys, pmap, jcosts; is_free_t)
222+ # Substitute fixed-time variables.
180223 for i in 1 : length (jcosts)
181- vars = vars (jcosts[i])
182- for st in vars
224+ costvars = MTK. vars (jcosts[i])
225+ for st in costvars
226+ MTK. iscall (st) || continue
183227 x = operation (st)
184228 t = only (arguments (st))
185- t isa Union{Num, MTK. Symbolic} && continue
186- idx = stidxmap[x (iv)]
187- jcosts[i] = Symbolics. substitute (jcosts[i], Dict (x (t) => U (t)[idx]))
229+ MTK. symbolic_type (t) === MTK. NotSymbolic () || continue
230+ if haskey (stidxmap, x (iv))
231+ idx = stidxmap[x (iv)]
232+ cv = U
233+ else
234+ idx = ctidxmap[x (iv)]
235+ cv = V
236+ end
237+ jcosts[i] = Symbolics. substitute (jcosts[i], Dict (x (t) => cv (t)[idx]))
188238 end
189239 end
190- jcosts = substitute_casadi_vars (model:: CasADiModel , sys, pmap, jcosts; auxmap)
191240
192241 dt = U. t[2 ] - U. t[1 ]
193242 intmap = Dict ()
194243 for int in MTK. collect_applied_operators (jcosts, Symbolics. Integral)
195244 op = MTK. operation (int)
196245 arg = only (arguments (MTK. value (int)))
197246 lo, hi = (op. domain. domain. left, op. domain. domain. right)
198- (lo, hi) != = tspan && error (" Non-whole interval bounds for integrals are not currently supported." )
247+ ! isequal ((lo, hi), tspan) && error (" Non-whole interval bounds for integrals are not currently supported for CasADiDynamicOptProblem." )
248+ # Approximate integral as sum.
199249 intmap[int] = dt * tₛ * sum (arg)
200250 end
201251 jcosts = map (c -> Symbolics. substitute (c, intmap), jcosts)
202- minimize! (opti, consolidate (jcosts))
252+ jcosts = MTK. value .(jcosts)
253+ minimize! (opti, MX (MTK. value (consolidate (jcosts))))
203254end
204255
205- function substitute_casadi_vars (model:: CasADiModel , sys, pmap, exprs; auxmap = Dict ())
206- @unpack opti, U, V = model
256+ function substitute_casadi_vars (model:: CasADiModel , sys, pmap, exprs; auxmap:: Dict = Dict (), is_free_t )
257+ @unpack opti, U, V, tₛ = model
207258 iv = MTK. get_iv (sys)
208259 sts = unknowns (sys)
209260 cts = MTK. unbound_inputs (sys)
@@ -213,6 +264,13 @@ function substitute_casadi_vars(model::CasADiModel, sys, pmap, exprs; auxmap = D
213264
214265 exprs = map (c -> Symbolics. fixpoint_sub (c, auxmap), exprs)
215266 exprs = map (c -> Symbolics. fixpoint_sub (c, Dict (pmap)), exprs)
267+ # tf means different things in different contexts; a [tf] in a cost function
268+ # should be tₛ, while a x(tf) should translate to x[1]
269+ if is_free_t
270+ free_t_map = Dict ([[x (tₛ) => U. u[i, end ] for (i, x) in enumerate (x_ops)];
271+ [c (tₛ) => V. u[i, end ] for (i, c) in enumerate (c_ops)]])
272+ exprs = map (c -> Symbolics. fixpoint_sub (c, free_t_map), exprs)
273+ end
216274
217275 # for variables like x(t)
218276 whole_interval_map = Dict ([[v => U. u[i, :] for (i, v) in enumerate (sts)];
@@ -221,7 +279,7 @@ function substitute_casadi_vars(model::CasADiModel, sys, pmap, exprs; auxmap = D
221279 exprs
222280end
223281
224- function add_solve_constraints (prob, tableau; is_free_t = false )
282+ function add_solve_constraints (prob, tableau)
225283 @unpack A, α, c = tableau
226284 @unpack model, f, p = prob
227285 @unpack opti, U, V, tₛ = model
@@ -283,6 +341,7 @@ function DiffEqBase.solve(prob::CasADiDynamicOptProblem, solver::Union{String, S
283341
284342 failed = false
285343 value_getter = nothing
344+ sol = nothing
286345 try
287346 sol = CasADi. solve! (opti)
288347 value_getter = x -> CasADi. value (sol, x)
@@ -293,22 +352,24 @@ function DiffEqBase.solve(prob::CasADiDynamicOptProblem, solver::Union{String, S
293352
294353 ts = value_getter (tₛ) * U. t
295354 U_vals = value_getter (U. u)
355+ size (U_vals, 2 ) == 1 && (U_vals = U_vals' )
296356 U_vals = [[U_vals[i, j] for i in 1 : size (U_vals, 1 )] for j in 1 : length (ts)]
297- sol = DiffEqBase. build_solution (prob, tableau_getter, ts, U_vals)
357+ ode_sol = DiffEqBase. build_solution (prob, tableau_getter, ts, U_vals)
298358
299359 input_sol = nothing
300360 if prod (size (V. u)) != 0
301361 V_vals = value_getter (V. u)
362+ size (V_vals, 2 ) == 1 && (V_vals = V_vals' )
302363 V_vals = [[V_vals[i, j] for i in 1 : size (V_vals, 1 )] for j in 1 : length (ts)]
303364 input_sol = DiffEqBase. build_solution (prob, tableau_getter, ts, V_vals)
304365 end
305366
306367 if failed
307- sol = SciMLBase. solution_new_retcode (sol , SciMLBase. ReturnCode. ConvergenceFailure)
368+ ode_sol = SciMLBase. solution_new_retcode (ode_sol , SciMLBase. ReturnCode. ConvergenceFailure)
308369 ! isnothing (input_sol) && (input_sol = SciMLBase. solution_new_retcode (
309370 input_sol, SciMLBase. ReturnCode. ConvergenceFailure))
310371 end
311372
312- DynamicOptSolution (model, sol , input_sol)
373+ DynamicOptSolution (model, ode_sol , input_sol)
313374end
314375end
0 commit comments