@@ -6,10 +6,8 @@ using UnPack
6
6
using NaNMath
7
7
const MTK = ModelingToolkit
8
8
9
- # NaNMath
10
9
for ff in [acos, log1p, acosh, log2, asin, tan, atanh, cos, log, sin, log10, sqrt]
11
10
f = nameof (ff)
12
- # These need to be defined so that JuMP can trace through functions built by Symbolics
13
11
@eval NaNMath.$ f (x:: CasadiSymbolicObject ) = Base.$ f (x)
14
12
end
15
13
@@ -76,78 +74,47 @@ function MTK.CasADiDynamicOptProblem(sys::System, u0map, tspan, pmap;
76
74
dt = nothing ,
77
75
steps = nothing ,
78
76
guesses = Dict (), kwargs... )
79
- MTK. warn_overdetermined (sys, u0map)
80
- _u0map = has_alg_eqs (sys) ? MTK. to_varmap (u0map, unknowns (sys)) :
81
- merge (Dict (u0map), Dict (guesses))
82
- pmap = MTK. to_varmap (pmap, parameters (sys))
83
- f, u0, p = MTK. process_SciMLProblem (ODEInputFunction, sys, merge (_u0map, pmap);
84
- t = tspan != = nothing ? tspan[1 ] : tspan, output_type = MX, kwargs... )
85
-
86
- pmap = MTK. recursive_unwrap (MTK. AnyDict (pmap))
87
- MTK. evaluate_varmap! (pmap, keys (pmap))
88
- steps, is_free_t = MTK. process_tspan (tspan, dt, steps)
89
- model = init_model (sys, tspan, steps, u0map, pmap, u0; is_free_t)
90
-
91
- CasADiDynamicOptProblem (f, u0, tspan, p, model, kwargs... )
77
+ process_DynamicOptProblem (CasADiDynamicOptProblem, CasADiModel, sys, u0map, tspan, pmap; dt, steps, guesses, kwargs... )
92
78
end
93
79
94
80
MTK. generate_internal_model (:: Type{CasADiModel} ) = CasADi. opti ()
95
- MTK. generate_state_variable (model, u0, ns, nt)
96
- MTK. generate_input_variable (model, c0, nc, nt) = 1
97
- MTK. generate_timescale (model, dims) = 1
98
81
99
- function init_model (sys, tspan, steps, u0map, pmap, u0; is_free_t = false )
100
- ctrls = MTK. unbound_inputs (sys)
101
- states = unknowns (sys)
102
- opti = CasADi. Opti ()
82
+ function MTK. generate_state_variable (model:: Opti , u0, ns, nt, tsteps)
83
+ U = CasADi. variable! (model, ns, nt)
84
+ set_initial! (opti, U, DM (repeat (u0, 1 , steps)))
85
+ MXLinearInterpolation (U, tsteps, tsteps[2 ] - tsteps[1 ])
86
+ end
103
87
88
+ function MTK. generate_input_variable (model:: Opti , c0, nc, nt, tsteps)
89
+ V = CasADi. variable! (model, nc, nt)
90
+ ! isempty (c0) && set_initial! (opti, V, DM (repeat (c0, 1 , steps)))
91
+ MXLinearInterpolation (V, tsteps, tsteps[2 ] - tsteps[1 ])
92
+ end
93
+
94
+ function MTK. generate_timescale (model:: Opti , guess, is_free_t)
104
95
if is_free_t
105
- (ts_sym, te_sym) = tspan
106
- MTK. symbolic_type (ts_sym) != = MTK. NotSymbolic () &&
107
- error (" Free initial time problems are not currently supported in CasADiDynamicOptProblem." )
108
- tₛ = variable! (opti)
109
- set_initial! (opti, tₛ, pmap[te_sym])
110
- subject_to! (opti, tₛ >= ts_sym)
111
- hasbounds (te_sym) && begin
112
- lo, hi = getbounds (te_sym)
113
- subject_to! (opti, tₛ >= lo)
114
- subject_to! (opti, tₛ >= hi)
115
- end
116
- pmap[te_sym] = tₛ
117
- tsteps = LinRange (0 , 1 , steps)
96
+ tₛ = variable! (model)
97
+ set_initial! (model, tₛ, guess)
98
+ subject_to! (model, tₛ >= 0 )
99
+ tₛ
118
100
else
119
- tₛ = MX (1 )
120
- tsteps = LinRange (tspan[1 ], tspan[2 ], steps)
101
+ MX (1 )
121
102
end
103
+ end
122
104
123
- U = CasADi. variable! (opti, length (states), steps)
124
- V = CasADi. variable! (opti, length (ctrls), steps)
125
- set_initial! (opti, U, DM (repeat (u0, 1 , steps)))
126
- c0 = MTK. value .([pmap[c] for c in ctrls])
127
- ! isempty (c0) && set_initial! (opti, V, DM (repeat (c0, 1 , steps)))
128
-
129
- U_interp = MXLinearInterpolation (U, tsteps, tsteps[2 ] - tsteps[1 ])
130
- V_interp = MXLinearInterpolation (V, tsteps, tsteps[2 ] - tsteps[1 ])
131
- for (i, ct) in enumerate (ctrls)
132
- pmap[ct] = V[i, :]
105
+ function MTK. add_constraint! (model:: CasADiModel , expr)
106
+ @unpack opti = model
107
+ if cons isa Equation
108
+ subject_to! (opti, expr. lhs - expr. rhs == 0 )
109
+ elseif cons. relational_op === Symbolics. geq
110
+ subject_to! (opti, expr. lhs - expr. rhs ≥ 0 )
111
+ else
112
+ subject_to! (opti, expr. lhs - expr. rhs ≤ 0 )
133
113
end
134
-
135
- model = CasADiModel (opti, U_interp, V_interp, tₛ)
136
-
137
- set_casadi_bounds! (model, sys, pmap)
138
- add_cost_function! (model, sys, tspan, pmap; is_free_t)
139
- add_user_constraints! (model, sys, tspan, pmap; is_free_t)
140
-
141
- stidxmap = Dict ([v => i for (i, v) in enumerate (states)])
142
- u0map = Dict ([MTK. default_toterm (MTK. value (k)) => v for (k, v) in u0map])
143
- u0_idxs = has_alg_eqs (sys) ? collect (1 : length (states)) :
144
- [stidxmap[MTK. default_toterm (k)] for (k, v) in u0map]
145
- add_initial_constraints! (model, u0, u0_idxs)
146
-
147
- model
148
114
end
115
+ MTK. set_objective! (model:: CasADiModel , expr) = minimize! (model. opti, MX (expr))
149
116
150
- function set_casadi_bounds ! (model, sys, pmap)
117
+ function MTK . set_variable_bounds ! (model, sys, pmap, tf )
151
118
@unpack opti, U, V = model
152
119
for (i, u) in enumerate (unknowns (sys))
153
120
if MTK. hasbounds (u)
@@ -163,36 +130,56 @@ function set_casadi_bounds!(model, sys, pmap)
163
130
subject_to! (opti, V. u[i, :] <= Symbolics. fast_substitute (hi, pmap))
164
131
end
165
132
end
133
+ if MTK. symbolic_type (tf) === MTK. ScalarSymbolic () && hasbounds (tf)
134
+ lo, hi = MTK. getbounds (tf)
135
+ subject_to! (opti, model. tₛ >= lo)
136
+ subject_to! (opti, model. tₛ <= hi)
137
+ end
166
138
end
167
139
168
- function add_initial_constraints! (model:: CasADiModel , u0, u0_idxs)
140
+ function MTK . add_initial_constraints! (model:: CasADiModel , u0, u0_idxs)
169
141
@unpack opti, U = model
170
142
for i in u0_idxs
171
143
subject_to! (opti, U. u[i, 1 ] == u0[i])
172
144
end
173
145
end
174
146
175
- function add_user_constraints! (model:: CasADiModel , sys, tspan, pmap; is_free_t)
147
+ function MTK. substitute_model_vars (
148
+ model:: CasADiModel , sys, pmap, exprs; auxmap:: Dict = Dict (), is_free_t)
176
149
@unpack opti, U, V, tₛ = model
177
-
178
150
iv = MTK. get_iv (sys)
179
- jconstraints = MTK. get_constraints (sys)
180
- (isnothing (jconstraints) || isempty (jconstraints)) && return nothing
181
-
182
- stidxmap = Dict ([v => i for (i, v) in enumerate (unknowns (sys))])
183
- ctidxmap = Dict ([v => i for (i, v) in enumerate (MTK. unbound_inputs (sys))])
184
- cons_dvs, cons_ps = MTK. process_constraint_system (
185
- jconstraints, Set (unknowns (sys)), parameters (sys), iv; validate = false )
186
-
187
- auxmap = Dict ([u => MTK. default_toterm (MTK. value (u)) for u in cons_dvs])
188
- jconstraints = substitute_casadi_vars (model, sys, pmap, jconstraints; is_free_t, auxmap)
189
- # Manually substitute fixed-t variables
190
- for (i, cons) in enumerate (jconstraints)
191
- consvars = MTK. vars (cons)
192
- for st in consvars
151
+ sts = unknowns (sys)
152
+ cts = MTK. unbound_inputs (sys)
153
+
154
+ x_ops = [MTK. operation (MTK. unwrap (st)) for st in sts]
155
+ c_ops = [MTK. operation (MTK. unwrap (ct)) for ct in cts]
156
+
157
+ exprs = map (c -> Symbolics. fast_substitute (c, auxmap), exprs)
158
+ exprs = map (c -> Symbolics. fast_substitute (c, Dict (pmap)), exprs)
159
+ # tf means different things in different contexts; a [tf] in a cost function
160
+ # should be tₛ, while a x(tf) should translate to x[1]
161
+ if is_free_t
162
+ free_t_map = Dict ([[x (tₛ) => U. u[i, end ] for (i, x) in enumerate (x_ops)];
163
+ [c (tₛ) => V. u[i, end ] for (i, c) in enumerate (c_ops)]])
164
+ exprs = map (c -> Symbolics. fast_substitute (c, free_t_map), exprs)
165
+ end
166
+
167
+ exprs = substitute_fixed_t_vars (exprs)
168
+
169
+ # for variables like x(t)
170
+ whole_interval_map = Dict ([[v => U. u[i, :] for (i, v) in enumerate (sts)];
171
+ [v => V. u[i, :] for (i, v) in enumerate (cts)]])
172
+ exprs = map (c -> Symbolics. fast_substitute (c, whole_interval_map), exprs)
173
+ exprs
174
+ end
175
+
176
+ function substitute_fixed_t_vars (exprs)
177
+ for i in 1 : length (exprs)
178
+ subvars = MTK. vars (exprs[i])
179
+ for st in subvars
193
180
MTK. iscall (st) || continue
194
- x = MTK . operation (st)
195
- t = only (MTK . arguments (st))
181
+ x = operation (st)
182
+ t = only (arguments (st))
196
183
MTK. symbolic_type (t) === MTK. NotSymbolic () || continue
197
184
if haskey (stidxmap, x (iv))
198
185
idx = stidxmap[x (iv)]
@@ -201,52 +188,19 @@ function add_user_constraints!(model::CasADiModel, sys, tspan, pmap; is_free_t)
201
188
idx = ctidxmap[x (iv)]
202
189
cv = V
203
190
end
204
- cons = Symbolics. substitute (cons, Dict (x (t) => cv (t)[idx]))
205
- end
206
-
207
- if cons isa Equation
208
- subject_to! (opti, cons. lhs - cons. rhs == 0 )
209
- elseif cons. relational_op === Symbolics. geq
210
- subject_to! (opti, cons. lhs - cons. rhs ≥ 0 )
211
- else
212
- subject_to! (opti, cons. lhs - cons. rhs ≤ 0 )
191
+ exprs[i] = Symbolics. fast_substitute (exprs[i], Dict (x (t) => cv (t)[idx]))
213
192
end
193
+ jcosts = Symbolics. substitute (jcosts, Dict (x (t) => cv (t)[idx]))
214
194
end
215
195
end
216
196
217
- function add_cost_function! (model:: CasADiModel , sys, tspan, pmap; is_free_t)
218
- @unpack opti, U, V, tₛ = model
219
- jcosts = cost (sys)
220
- if Symbolics. _iszero (jcosts)
221
- minimize! (opti, MX (0 ))
222
- return
223
- end
224
-
225
- iv = MTK. get_iv (sys)
226
- stidxmap = Dict ([v => i for (i, v) in enumerate (unknowns (sys))])
227
- ctidxmap = Dict ([v => i for (i, v) in enumerate (MTK. unbound_inputs (sys))])
228
-
229
- jcosts = substitute_casadi_vars (model, sys, pmap, [jcosts]; is_free_t)[1 ]
230
- # Substitute fixed-time variables.
231
- costvars = MTK. vars (jcosts)
232
- for st in costvars
233
- MTK. iscall (st) || continue
234
- x = operation (st)
235
- t = only (arguments (st))
236
- MTK. symbolic_type (t) === MTK. NotSymbolic () || continue
237
- if haskey (stidxmap, x (iv))
238
- idx = stidxmap[x (iv)]
239
- cv = U
240
- else
241
- idx = ctidxmap[x (iv)]
242
- cv = V
243
- end
244
- jcosts = Symbolics. substitute (jcosts, Dict (x (t) => cv (t)[idx]))
245
- end
197
+ MTK. substitute_differentials (model:: CasADiModel , exprs, args... ) = exprs
246
198
199
+ function MTK. substitute_integral (model:: CasADiModel , exprs)
200
+ @unpack U, opti = model
247
201
dt = U. t[2 ] - U. t[1 ]
248
202
intmap = Dict ()
249
- for int in MTK. collect_applied_operators (jcosts , Symbolics. Integral)
203
+ for int in MTK. collect_applied_operators (exprs , Symbolics. Integral)
250
204
op = MTK. operation (int)
251
205
arg = only (arguments (MTK. value (int)))
252
206
lo, hi = (op. domain. domain. left, op. domain. domain. right)
@@ -255,39 +209,11 @@ function add_cost_function!(model::CasADiModel, sys, tspan, pmap; is_free_t)
255
209
# Approximate integral as sum.
256
210
intmap[int] = dt * tₛ * sum (arg)
257
211
end
258
- jcosts = Symbolics. substitute (jcosts, intmap)
259
- jcosts = MTK. value (jcosts)
260
- minimize! (opti, MX (jcosts))
261
- end
262
-
263
- function substitute_casadi_vars (
264
- model:: CasADiModel , sys, pmap, exprs; auxmap:: Dict = Dict (), is_free_t)
265
- @unpack opti, U, V, tₛ = model
266
- iv = MTK. get_iv (sys)
267
- sts = unknowns (sys)
268
- cts = MTK. unbound_inputs (sys)
269
-
270
- x_ops = [MTK. operation (MTK. unwrap (st)) for st in sts]
271
- c_ops = [MTK. operation (MTK. unwrap (ct)) for ct in cts]
272
-
273
- exprs = map (c -> Symbolics. fast_substitute (c, auxmap), exprs)
274
- exprs = map (c -> Symbolics. fast_substitute (c, Dict (pmap)), exprs)
275
- # tf means different things in different contexts; a [tf] in a cost function
276
- # should be tₛ, while a x(tf) should translate to x[1]
277
- if is_free_t
278
- free_t_map = Dict ([[x (tₛ) => U. u[i, end ] for (i, x) in enumerate (x_ops)];
279
- [c (tₛ) => V. u[i, end ] for (i, c) in enumerate (c_ops)]])
280
- exprs = map (c -> Symbolics. fast_substitute (c, free_t_map), exprs)
281
- end
282
-
283
- # for variables like x(t)
284
- whole_interval_map = Dict ([[v => U. u[i, :] for (i, v) in enumerate (sts)];
285
- [v => V. u[i, :] for (i, v) in enumerate (cts)]])
286
- exprs = map (c -> Symbolics. fast_substitute (c, whole_interval_map), exprs)
287
- exprs
212
+ exprs = map (c -> Symbolics. substitute (c, intmap), exprs)
213
+ exprs = MTK. value .(exprs)
288
214
end
289
215
290
- function add_solve_constraints (prob, tableau)
216
+ function add_solve_constraints! (prob, tableau)
291
217
@unpack A, α, c = tableau
292
218
@unpack model, f, p = prob
293
219
@unpack opti, U, V, tₛ = model
@@ -332,57 +258,29 @@ function add_solve_constraints(prob, tableau)
332
258
solver_opti
333
259
end
334
260
335
- """
336
- solve(prob::CasADiDynamicOptProblem, casadi_solver, ode_solver; plugin_options, solver_options, silent)
337
-
338
- `plugin_options` and `solver_options` get propagated to the Opti object in CasADi.
339
-
340
- NOTE: the solver should be passed in as a string to CasADi. "ipopt"
341
- """
342
- function DiffEqBase. solve (
343
- prob:: CasADiDynamicOptProblem , solver:: Union{String, Symbol} = " ipopt" ,
344
- tableau_getter = MTK. constructDefault; plugin_options:: Dict = Dict (),
345
- solver_options:: Dict = Dict (), silent = false )
346
- @unpack model, u0, p, tspan, f = prob
347
- tableau = tableau_getter ()
348
- @unpack opti, U, V, tₛ = model
349
-
261
+ function MTK. prepare_solver ()
350
262
opti = add_solve_constraints (prob, tableau)
351
- silent && (solver_options[" print_level" ] = 0 )
352
263
solver! (opti, " $solver " , plugin_options, solver_options)
264
+ end
265
+ function MTK. get_U_values ()
266
+ U_vals = value_getter (U. u)
267
+ size (U_vals, 2 ) == 1 && (U_vals = U_vals' )
268
+ U_vals = [[U_vals[i, j] for i in 1 : size (U_vals, 1 )] for j in 1 : length (ts)]
269
+ end
270
+ function MTK. get_V_values ()
271
+ end
272
+ function MTK. get_t_values ()
273
+ ts = value_getter (tₛ) * U. t
274
+ end
353
275
354
- failed = false
355
- value_getter = nothing
356
- sol = nothing
276
+ function MTK. optimize_model! ()
357
277
try
358
278
sol = CasADi. solve! (opti)
359
279
value_getter = x -> CasADi. value (sol, x)
360
280
catch ErrorException
361
281
value_getter = x -> CasADi. debug_value (opti, x)
362
282
failed = true
363
283
end
364
-
365
- ts = value_getter (tₛ) * U. t
366
- U_vals = value_getter (U. u)
367
- size (U_vals, 2 ) == 1 && (U_vals = U_vals' )
368
- U_vals = [[U_vals[i, j] for i in 1 : size (U_vals, 1 )] for j in 1 : length (ts)]
369
- ode_sol = DiffEqBase. build_solution (prob, tableau_getter, ts, U_vals)
370
-
371
- input_sol = nothing
372
- if prod (size (V. u)) != 0
373
- V_vals = value_getter (V. u)
374
- size (V_vals, 2 ) == 1 && (V_vals = V_vals' )
375
- V_vals = [[V_vals[i, j] for i in 1 : size (V_vals, 1 )] for j in 1 : length (ts)]
376
- input_sol = DiffEqBase. build_solution (prob, tableau_getter, ts, V_vals)
377
- end
378
-
379
- if failed
380
- ode_sol = SciMLBase. solution_new_retcode (
381
- ode_sol, SciMLBase. ReturnCode. ConvergenceFailure)
382
- ! isnothing (input_sol) && (input_sol = SciMLBase. solution_new_retcode (
383
- input_sol, SciMLBase. ReturnCode. ConvergenceFailure))
384
- end
385
-
386
- DynamicOptSolution (model, ode_sol, input_sol)
387
284
end
285
+ MTK. successful_solve () = true
388
286
end
0 commit comments