@@ -6,10 +6,8 @@ using UnPack
66using NaNMath
77const MTK = ModelingToolkit
88
9- # NaNMath
109for ff in [acos, log1p, acosh, log2, asin, tan, atanh, cos, log, sin, log10, sqrt]
1110 f = nameof (ff)
12- # These need to be defined so that JuMP can trace through functions built by Symbolics
1311 @eval NaNMath.$ f (x:: CasadiSymbolicObject ) = Base.$ f (x)
1412end
1513
@@ -76,75 +74,47 @@ function MTK.CasADiDynamicOptProblem(sys::ODESystem, u0map, tspan, pmap;
7674 dt = nothing ,
7775 steps = nothing ,
7876 guesses = Dict (), kwargs... )
79- MTK. warn_overdetermined (sys, u0map)
80- _u0map = has_alg_eqs (sys) ? u0map : merge (Dict (u0map), Dict (guesses))
81- f, u0, p = MTK. process_SciMLProblem (ODEInputFunction, sys, _u0map, pmap;
82- t = tspan != = nothing ? tspan[1 ] : tspan, output_type = MX, kwargs... )
83-
84- pmap = Dict {Any, Any} (pmap)
85- steps, is_free_t = MTK. process_tspan (tspan, dt, steps)
86- model = init_model (sys, tspan, steps, u0map, pmap, u0; is_free_t)
87-
88- CasADiDynamicOptProblem (f, u0, tspan, p, model, kwargs... )
77+ process_DynamicOptProblem (CasADiDynamicOptProblem, CasADiModel, sys, u0map, tspan, pmap; dt, steps, guesses, kwargs... )
8978end
9079
9180MTK. generate_internal_model (:: Type{CasADiModel} ) = CasADi. opti ()
92- MTK. generate_state_variable (model, u0, ns, nt)
93- MTK. generate_input_variable (model, c0, nc, nt) = 1
94- MTK. generate_timescale (model, dims) = 1
9581
96- function init_model (sys, tspan, steps, u0map, pmap, u0; is_free_t = false )
97- ctrls = MTK. unbound_inputs (sys)
98- states = unknowns (sys)
99- opti = CasADi. Opti ()
82+ function MTK. generate_state_variable (model:: Opti , u0, ns, nt, tsteps)
83+ U = CasADi. variable! (model, ns, nt)
84+ set_initial! (opti, U, DM (repeat (u0, 1 , steps)))
85+ MXLinearInterpolation (U, tsteps, tsteps[2 ] - tsteps[1 ])
86+ end
87+
88+ function MTK. generate_input_variable (model:: Opti , c0, nc, nt, tsteps)
89+ V = CasADi. variable! (model, nc, nt)
90+ ! isempty (c0) && set_initial! (opti, V, DM (repeat (c0, 1 , steps)))
91+ MXLinearInterpolation (V, tsteps, tsteps[2 ] - tsteps[1 ])
92+ end
10093
94+ function MTK. generate_timescale (model:: Opti , guess, is_free_t)
10195 if is_free_t
102- (ts_sym, te_sym) = tspan
103- MTK. symbolic_type (ts_sym) != = MTK. NotSymbolic () &&
104- error (" Free initial time problems are not currently supported in CasADiDynamicOptProblem." )
105- tₛ = variable! (opti)
106- set_initial! (opti, tₛ, pmap[te_sym])
107- subject_to! (opti, tₛ >= ts_sym)
108- hasbounds (te_sym) && begin
109- lo, hi = getbounds (te_sym)
110- subject_to! (opti, tₛ >= lo)
111- subject_to! (opti, tₛ >= hi)
112- end
113- pmap[te_sym] = tₛ
114- tsteps = LinRange (0 , 1 , steps)
96+ tₛ = variable! (model)
97+ set_initial! (model, tₛ, guess)
98+ subject_to! (model, tₛ >= 0 )
99+ tₛ
115100 else
116- tₛ = MX (1 )
117- tsteps = LinRange (tspan[1 ], tspan[2 ], steps)
101+ MX (1 )
118102 end
103+ end
119104
120- U = CasADi. variable! (opti, length (states), steps)
121- V = CasADi. variable! (opti, length (ctrls), steps)
122- set_initial! (opti, U, DM (repeat (u0, 1 , steps)))
123- c0 = MTK. value .([pmap[c] for c in ctrls])
124- ! isempty (c0) && set_initial! (opti, V, DM (repeat (c0, 1 , steps)))
125-
126- U_interp = MXLinearInterpolation (U, tsteps, tsteps[2 ] - tsteps[1 ])
127- V_interp = MXLinearInterpolation (V, tsteps, tsteps[2 ] - tsteps[1 ])
128- for (i, ct) in enumerate (ctrls)
129- pmap[ct] = V[i, :]
105+ function MTK. add_constraint! (model:: CasADiModel , expr)
106+ @unpack opti = model
107+ if cons isa Equation
108+ subject_to! (opti, expr. lhs - expr. rhs == 0 )
109+ elseif cons. relational_op === Symbolics. geq
110+ subject_to! (opti, expr. lhs - expr. rhs ≥ 0 )
111+ else
112+ subject_to! (opti, expr. lhs - expr. rhs ≤ 0 )
130113 end
131-
132- model = CasADiModel (opti, U_interp, V_interp, tₛ)
133-
134- set_casadi_bounds! (model, sys, pmap)
135- add_cost_function! (model, sys, tspan, pmap; is_free_t)
136- add_user_constraints! (model, sys, tspan, pmap; is_free_t)
137-
138- stidxmap = Dict ([v => i for (i, v) in enumerate (states)])
139- u0map = Dict ([MTK. default_toterm (MTK. value (k)) => v for (k, v) in u0map])
140- u0_idxs = has_alg_eqs (sys) ? collect (1 : length (states)) :
141- [stidxmap[MTK. default_toterm (k)] for (k, v) in u0map]
142- add_initial_constraints! (model, u0, u0_idxs)
143-
144- model
145114end
115+ MTK. set_objective! (model:: CasADiModel , expr) = minimize! (model. opti, MX (expr))
146116
147- function set_casadi_bounds ! (model, sys, pmap)
117+ function MTK . set_variable_bounds ! (model, sys, pmap, tf )
148118 @unpack opti, U, V = model
149119 for (i, u) in enumerate (unknowns (sys))
150120 if MTK. hasbounds (u)
@@ -160,75 +130,53 @@ function set_casadi_bounds!(model, sys, pmap)
160130 subject_to! (opti, V. u[i, :] <= Symbolics. fixpoint_sub (hi, pmap))
161131 end
162132 end
133+ if MTK. symbolic_type (tf) === MTK. ScalarSymbolic () && hasbounds (tf)
134+ lo, hi = MTK. getbounds (tf)
135+ subject_to! (opti, model. tₛ >= lo)
136+ subject_to! (opti, model. tₛ <= hi)
137+ end
163138end
164139
165- function add_initial_constraints! (model:: CasADiModel , u0, u0_idxs)
140+ function MTK . add_initial_constraints! (model:: CasADiModel , u0, u0_idxs)
166141 @unpack opti, U = model
167142 for i in u0_idxs
168143 subject_to! (opti, U. u[i, 1 ] == u0[i])
169144 end
170145end
171146
172- function add_user_constraints! (model:: CasADiModel , sys, tspan, pmap; is_free_t)
147+ function MTK. substitute_model_vars (
148+ model:: CasADiModel , sys, pmap, exprs; auxmap:: Dict = Dict (), is_free_t)
173149 @unpack opti, U, V, tₛ = model
174-
175150 iv = MTK. get_iv (sys)
176- conssys = MTK. get_constraintsystem (sys)
177- jconstraints = isnothing (conssys) ? nothing : MTK. get_constraints (conssys)
178- (isnothing (jconstraints) || isempty (jconstraints)) && return nothing
179-
180- stidxmap = Dict ([v => i for (i, v) in enumerate (unknowns (sys))])
181- ctidxmap = Dict ([v => i for (i, v) in enumerate (MTK. unbound_inputs (sys))])
182- cons_unknowns = map (MTK. default_toterm, unknowns (conssys))
183-
184- auxmap = Dict ([u => MTK. default_toterm (MTK. value (u)) for u in unknowns (conssys)])
185- jconstraints = substitute_casadi_vars (model, sys, pmap, jconstraints; is_free_t, auxmap)
186- # Manually substitute fixed-t variables
187- for (i, cons) in enumerate (jconstraints)
188- consvars = MTK. vars (cons)
189- for st in consvars
190- MTK. iscall (st) || continue
191- x = MTK. operation (st)
192- t = only (MTK. arguments (st))
193- MTK. symbolic_type (t) === MTK. NotSymbolic () || continue
194- if haskey (stidxmap, x (iv))
195- idx = stidxmap[x (iv)]
196- cv = U
197- else
198- idx = ctidxmap[x (iv)]
199- cv = V
200- end
201- cons = Symbolics. substitute (cons, Dict (x (t) => cv (t)[idx]))
202- end
151+ sts = unknowns (sys)
152+ cts = MTK. unbound_inputs (sys)
203153
204- if cons isa Equation
205- subject_to! (opti, cons. lhs - cons. rhs == 0 )
206- elseif cons. relational_op === Symbolics. geq
207- subject_to! (opti, cons. lhs - cons. rhs ≥ 0 )
208- else
209- subject_to! (opti, cons. lhs - cons. rhs ≤ 0 )
210- end
211- end
212- end
154+ x_ops = [MTK. operation (MTK. unwrap (st)) for st in sts]
155+ c_ops = [MTK. operation (MTK. unwrap (ct)) for ct in cts]
213156
214- function add_cost_function! (model:: CasADiModel , sys, tspan, pmap; is_free_t)
215- @unpack opti, U, V, tₛ = model
216- jcosts = copy (MTK. get_costs (sys))
217- consolidate = MTK. get_consolidate (sys)
218- if isnothing (jcosts) || isempty (jcosts)
219- minimize! (opti, MX (0 ))
220- return
157+ exprs = map (c -> Symbolics. fast_substitute (c, auxmap), exprs)
158+ exprs = map (c -> Symbolics. fast_substitute (c, Dict (pmap)), exprs)
159+ # tf means different things in different contexts; a [tf] in a cost function
160+ # should be tₛ, while a x(tf) should translate to x[1]
161+ if is_free_t
162+ free_t_map = Dict ([[x (tₛ) => U. u[i, end ] for (i, x) in enumerate (x_ops)];
163+ [c (tₛ) => V. u[i, end ] for (i, c) in enumerate (c_ops)]])
164+ exprs = map (c -> Symbolics. fast_substitute (c, free_t_map), exprs)
221165 end
222166
223- iv = MTK. get_iv (sys)
224- stidxmap = Dict ([v => i for (i, v) in enumerate (unknowns (sys))])
225- ctidxmap = Dict ([v => i for (i, v) in enumerate (MTK. unbound_inputs (sys))])
226-
227- jcosts = substitute_casadi_vars (model, sys, pmap, jcosts; is_free_t)
228- # Substitute fixed-time variables.
229- for i in 1 : length (jcosts)
230- costvars = MTK. vars (jcosts[i])
231- for st in costvars
167+ exprs = substitute_fixed_t_vars (exprs)
168+
169+ # for variables like x(t)
170+ whole_interval_map = Dict ([[v => U. u[i, :] for (i, v) in enumerate (sts)];
171+ [v => V. u[i, :] for (i, v) in enumerate (cts)]])
172+ exprs = map (c -> Symbolics. fast_substitute (c, whole_interval_map), exprs)
173+ exprs
174+ end
175+
176+ function substitute_fixed_t_vars (exprs)
177+ for i in 1 : length (exprs)
178+ subvars = MTK. vars (exprs[i])
179+ for st in subvars
232180 MTK. iscall (st) || continue
233181 x = operation (st)
234182 t = only (arguments (st))
@@ -240,13 +188,18 @@ function add_cost_function!(model::CasADiModel, sys, tspan, pmap; is_free_t)
240188 idx = ctidxmap[x (iv)]
241189 cv = V
242190 end
243- jcosts [i] = Symbolics. substitute (jcosts [i], Dict (x (t) => cv (t)[idx]))
191+ exprs [i] = Symbolics. fast_substitute (exprs [i], Dict (x (t) => cv (t)[idx]))
244192 end
245193 end
194+ end
195+
196+ MTK. substitute_differentials (model:: CasADiModel , exprs, args... ) = exprs
246197
198+ function MTK. substitute_integral (model:: CasADiModel , exprs)
199+ @unpack U, opti = model
247200 dt = U. t[2 ] - U. t[1 ]
248201 intmap = Dict ()
249- for int in MTK. collect_applied_operators (jcosts , Symbolics. Integral)
202+ for int in MTK. collect_applied_operators (exprs , Symbolics. Integral)
250203 op = MTK. operation (int)
251204 arg = only (arguments (MTK. value (int)))
252205 lo, hi = (op. domain. domain. left, op. domain. domain. right)
@@ -255,39 +208,11 @@ function add_cost_function!(model::CasADiModel, sys, tspan, pmap; is_free_t)
255208 # Approximate integral as sum.
256209 intmap[int] = dt * tₛ * sum (arg)
257210 end
258- jcosts = map (c -> Symbolics. substitute (c, intmap), jcosts)
259- jcosts = MTK. value .(jcosts)
260- minimize! (opti, MX (MTK. value (consolidate (jcosts))))
211+ exprs = map (c -> Symbolics. substitute (c, intmap), exprs)
212+ exprs = MTK. value .(exprs)
261213end
262214
263- function substitute_casadi_vars (
264- model:: CasADiModel , sys, pmap, exprs; auxmap:: Dict = Dict (), is_free_t)
265- @unpack opti, U, V, tₛ = model
266- iv = MTK. get_iv (sys)
267- sts = unknowns (sys)
268- cts = MTK. unbound_inputs (sys)
269-
270- x_ops = [MTK. operation (MTK. unwrap (st)) for st in sts]
271- c_ops = [MTK. operation (MTK. unwrap (ct)) for ct in cts]
272-
273- exprs = map (c -> Symbolics. fixpoint_sub (c, auxmap), exprs)
274- exprs = map (c -> Symbolics. fixpoint_sub (c, Dict (pmap)), exprs)
275- # tf means different things in different contexts; a [tf] in a cost function
276- # should be tₛ, while a x(tf) should translate to x[1]
277- if is_free_t
278- free_t_map = Dict ([[x (tₛ) => U. u[i, end ] for (i, x) in enumerate (x_ops)];
279- [c (tₛ) => V. u[i, end ] for (i, c) in enumerate (c_ops)]])
280- exprs = map (c -> Symbolics. fixpoint_sub (c, free_t_map), exprs)
281- end
282-
283- # for variables like x(t)
284- whole_interval_map = Dict ([[v => U. u[i, :] for (i, v) in enumerate (sts)];
285- [v => V. u[i, :] for (i, v) in enumerate (cts)]])
286- exprs = map (c -> Symbolics. fixpoint_sub (c, whole_interval_map), exprs)
287- exprs
288- end
289-
290- function add_solve_constraints (prob, tableau)
215+ function add_solve_constraints! (prob, tableau)
291216 @unpack A, α, c = tableau
292217 @unpack model, f, p = prob
293218 @unpack opti, U, V, tₛ = model
@@ -332,57 +257,29 @@ function add_solve_constraints(prob, tableau)
332257 solver_opti
333258end
334259
335- """
336- solve(prob::CasADiDynamicOptProblem, casadi_solver, ode_solver; plugin_options, solver_options, silent)
337-
338- `plugin_options` and `solver_options` get propagated to the Opti object in CasADi.
339-
340- NOTE: the solver should be passed in as a string to CasADi. "ipopt"
341- """
342- function DiffEqBase. solve (
343- prob:: CasADiDynamicOptProblem , solver:: Union{String, Symbol} = " ipopt" ,
344- tableau_getter = MTK. constructDefault; plugin_options:: Dict = Dict (),
345- solver_options:: Dict = Dict (), silent = false )
346- @unpack model, u0, p, tspan, f = prob
347- tableau = tableau_getter ()
348- @unpack opti, U, V, tₛ = model
349-
260+ function MTK. prepare_solver ()
350261 opti = add_solve_constraints (prob, tableau)
351- silent && (solver_options[" print_level" ] = 0 )
352262 solver! (opti, " $solver " , plugin_options, solver_options)
263+ end
264+ function MTK. get_U_values ()
265+ U_vals = value_getter (U. u)
266+ size (U_vals, 2 ) == 1 && (U_vals = U_vals' )
267+ U_vals = [[U_vals[i, j] for i in 1 : size (U_vals, 1 )] for j in 1 : length (ts)]
268+ end
269+ function MTK. get_V_values ()
270+ end
271+ function MTK. get_t_values ()
272+ ts = value_getter (tₛ) * U. t
273+ end
353274
354- failed = false
355- value_getter = nothing
356- sol = nothing
275+ function MTK. optimize_model! ()
357276 try
358277 sol = CasADi. solve! (opti)
359278 value_getter = x -> CasADi. value (sol, x)
360279 catch ErrorException
361280 value_getter = x -> CasADi. debug_value (opti, x)
362281 failed = true
363282 end
364-
365- ts = value_getter (tₛ) * U. t
366- U_vals = value_getter (U. u)
367- size (U_vals, 2 ) == 1 && (U_vals = U_vals' )
368- U_vals = [[U_vals[i, j] for i in 1 : size (U_vals, 1 )] for j in 1 : length (ts)]
369- ode_sol = DiffEqBase. build_solution (prob, tableau_getter, ts, U_vals)
370-
371- input_sol = nothing
372- if prod (size (V. u)) != 0
373- V_vals = value_getter (V. u)
374- size (V_vals, 2 ) == 1 && (V_vals = V_vals' )
375- V_vals = [[V_vals[i, j] for i in 1 : size (V_vals, 1 )] for j in 1 : length (ts)]
376- input_sol = DiffEqBase. build_solution (prob, tableau_getter, ts, V_vals)
377- end
378-
379- if failed
380- ode_sol = SciMLBase. solution_new_retcode (
381- ode_sol, SciMLBase. ReturnCode. ConvergenceFailure)
382- ! isnothing (input_sol) && (input_sol = SciMLBase. solution_new_retcode (
383- input_sol, SciMLBase. ReturnCode. ConvergenceFailure))
384- end
385-
386- DynamicOptSolution (model, ode_sol, input_sol)
387283end
284+ MTK. successful_solve () = true
388285end
0 commit comments