@@ -18,12 +18,17 @@ struct MXLinearInterpolation
18
18
dt:: Float64
19
19
end
20
20
21
- struct CasADiModel
22
- opti :: Opti
21
+ mutable struct CasADiModel
22
+ model :: Opti
23
23
U:: MXLinearInterpolation
24
24
V:: MXLinearInterpolation
25
25
tₛ:: MX
26
26
is_free_final:: Bool
27
+ solver_opti:: Union{Nothing, Opti}
28
+
29
+ function CasADiModel (opti, U, V, tₛ, is_free_final, solver_opti = nothing )
30
+ new (opti, U, V, tₛ, is_free_final, solver_opti)
31
+ end
27
32
end
28
33
29
34
struct CasADiDynamicOptProblem{uType, tType, isinplace, P, F, K} < :
@@ -74,24 +79,27 @@ function MTK.CasADiDynamicOptProblem(sys::System, u0map, tspan, pmap;
74
79
dt = nothing ,
75
80
steps = nothing ,
76
81
guesses = Dict (), kwargs... )
77
- process_DynamicOptProblem (CasADiDynamicOptProblem, CasADiModel, sys, u0map, tspan, pmap; dt, steps, guesses, kwargs... )
82
+ MTK . process_DynamicOptProblem (CasADiDynamicOptProblem, CasADiModel, sys, u0map, tspan, pmap; dt, steps, guesses, kwargs... )
78
83
end
79
84
80
- MTK. generate_internal_model (:: Type{CasADiModel} ) = CasADi. opti ()
85
+ MTK. generate_internal_model (:: Type{CasADiModel} ) = CasADi. Opti ()
86
+ MTK. generate_time_variable! (opti:: Opti , args... ) = nothing
81
87
82
- function MTK. generate_state_variable (model:: Opti , u0, ns, nt, tsteps)
88
+ function MTK. generate_state_variable! (model:: Opti , u0, ns, tsteps)
89
+ nt = length (tsteps)
83
90
U = CasADi. variable! (model, ns, nt)
84
- set_initial! (opti , U, DM (repeat (u0, 1 , steps )))
91
+ set_initial! (model , U, DM (repeat (u0, 1 , nt )))
85
92
MXLinearInterpolation (U, tsteps, tsteps[2 ] - tsteps[1 ])
86
93
end
87
94
88
- function MTK. generate_input_variable (model:: Opti , c0, nc, nt, tsteps)
95
+ function MTK. generate_input_variable! (model:: Opti , c0, nc, tsteps)
96
+ nt = length (tsteps)
89
97
V = CasADi. variable! (model, nc, nt)
90
- ! isempty (c0) && set_initial! (opti , V, DM (repeat (c0, 1 , steps )))
98
+ ! isempty (c0) && set_initial! (model , V, DM (repeat (c0, 1 , nt )))
91
99
MXLinearInterpolation (V, tsteps, tsteps[2 ] - tsteps[1 ])
92
100
end
93
101
94
- function MTK. generate_timescale (model:: Opti , guess, is_free_t)
102
+ function MTK. generate_timescale! (model:: Opti , guess, is_free_t)
95
103
if is_free_t
96
104
tₛ = variable! (model)
97
105
set_initial! (model, tₛ, guess)
@@ -102,78 +110,73 @@ function MTK.generate_timescale(model::Opti, guess, is_free_t)
102
110
end
103
111
end
104
112
105
- function MTK. add_constraint! (model:: CasADiModel , expr)
106
- @unpack opti = model
107
- if cons isa Equation
108
- subject_to! (opti, expr. lhs - expr. rhs == 0 )
109
- elseif cons. relational_op === Symbolics. geq
110
- subject_to! (opti, expr. lhs - expr. rhs ≥ 0 )
113
+ function MTK. add_constraint! (m:: CasADiModel , expr)
114
+ if expr isa Equation
115
+ subject_to! (m. model, expr. lhs - expr. rhs == 0 )
116
+ elseif expr. relational_op === Symbolics. geq
117
+ subject_to! (m. model, expr. lhs - expr. rhs ≥ 0 )
111
118
else
112
- subject_to! (opti , expr. lhs - expr. rhs ≤ 0 )
119
+ subject_to! (m . model , expr. lhs - expr. rhs ≤ 0 )
113
120
end
114
121
end
115
- MTK. set_objective! (model :: CasADiModel , expr) = minimize! (model . opti , MX (expr))
122
+ MTK. set_objective! (m :: CasADiModel , expr) = minimize! (m . model , MX (expr))
116
123
117
- function MTK. set_variable_bounds! (model , sys, pmap, tf)
118
- @unpack opti , U, V = model
124
+ function MTK. set_variable_bounds! (m :: CasADiModel , sys, pmap, tf)
125
+ @unpack model , U, tₛ, V = m
119
126
for (i, u) in enumerate (unknowns (sys))
120
127
if MTK. hasbounds (u)
121
128
lo, hi = MTK. getbounds (u)
122
- subject_to! (opti , Symbolics. fast_substitute (lo, pmap) <= U. u[i, :])
123
- subject_to! (opti , U. u[i, :] <= Symbolics. fast_substitute (hi, pmap))
129
+ subject_to! (model , Symbolics. fixpoint_sub (lo, pmap) <= U. u[i, :])
130
+ subject_to! (model , U. u[i, :] <= Symbolics. fixpoint_sub (hi, pmap))
124
131
end
125
132
end
126
133
for (i, v) in enumerate (MTK. unbound_inputs (sys))
127
134
if MTK. hasbounds (v)
128
135
lo, hi = MTK. getbounds (v)
129
- subject_to! (opti , Symbolics. fast_substitute (lo, pmap) <= V. u[i, :])
130
- subject_to! (opti , V. u[i, :] <= Symbolics. fast_substitute (hi, pmap))
136
+ subject_to! (model , Symbolics. fixpoint_sub (lo, pmap) <= V. u[i, :])
137
+ subject_to! (model , V. u[i, :] <= Symbolics. fixpoint_sub (hi, pmap))
131
138
end
132
139
end
133
140
if MTK. symbolic_type (tf) === MTK. ScalarSymbolic () && hasbounds (tf)
134
141
lo, hi = MTK. getbounds (tf)
135
- subject_to! (opti, model . tₛ >= lo)
136
- subject_to! (opti, model . tₛ <= hi)
142
+ subject_to! (model, tₛ >= lo)
143
+ subject_to! (model, tₛ <= hi)
137
144
end
138
145
end
139
146
140
- function MTK. add_initial_constraints! (model :: CasADiModel , u0, u0_idxs)
141
- @unpack opti , U = model
147
+ function MTK. add_initial_constraints! (m :: CasADiModel , u0, u0_idxs, args ... )
148
+ @unpack model , U = m
142
149
for i in u0_idxs
143
- subject_to! (opti , U. u[i, 1 ] == u0[i])
150
+ subject_to! (model , U. u[i, 1 ] == u0[i])
144
151
end
145
152
end
146
153
147
- function MTK. substitute_model_vars (
148
- model:: CasADiModel , sys, pmap, exprs; auxmap:: Dict = Dict (), is_free_t)
149
- @unpack opti, U, V, tₛ = model
154
+ function MTK. substitute_model_vars (m:: CasADiModel , sys, exprs, tspan)
155
+ @unpack model, U, V, tₛ = m
150
156
iv = MTK. get_iv (sys)
151
157
sts = unknowns (sys)
152
158
cts = MTK. unbound_inputs (sys)
153
-
154
159
x_ops = [MTK. operation (MTK. unwrap (st)) for st in sts]
155
160
c_ops = [MTK. operation (MTK. unwrap (ct)) for ct in cts]
156
-
157
- exprs = map (c -> Symbolics. fast_substitute (c, auxmap), exprs)
158
- exprs = map (c -> Symbolics. fast_substitute (c, Dict (pmap)), exprs)
159
- # tf means different things in different contexts; a [tf] in a cost function
160
- # should be tₛ, while a x(tf) should translate to x[1]
161
- if is_free_t
162
- free_t_map = Dict ([[x (tₛ) => U. u[i, end ] for (i, x) in enumerate (x_ops)];
163
- [c (tₛ) => V. u[i, end ] for (i, c) in enumerate (c_ops)]])
161
+ (ti, tf) = tspan
162
+ if MTK. is_free_final (m)
163
+ _tf = tₛ + ti
164
+ exprs = map (c -> Symbolics. fast_substitute (c, Dict (tf => _tf)), exprs)
165
+ free_t_map = Dict ([[x (_tf) => U. u[i, end ] for (i, x) in enumerate (x_ops)];
166
+ [c (_tf) => V. u[i, end ] for (i, c) in enumerate (c_ops)]])
164
167
exprs = map (c -> Symbolics. fast_substitute (c, free_t_map), exprs)
165
168
end
166
169
167
- exprs = substitute_fixed_t_vars (exprs)
168
-
169
- # for variables like x(t)
170
+ exprs = substitute_fixed_t_vars (m, sys, exprs)
170
171
whole_interval_map = Dict ([[v => U. u[i, :] for (i, v) in enumerate (sts)];
171
172
[v => V. u[i, :] for (i, v) in enumerate (cts)]])
172
173
exprs = map (c -> Symbolics. fast_substitute (c, whole_interval_map), exprs)
173
- exprs
174
174
end
175
175
176
- function substitute_fixed_t_vars (exprs)
176
+ function substitute_fixed_t_vars (model:: CasADiModel , sys, exprs)
177
+ stidxmap = Dict ([v => i for (i, v) in enumerate (unknowns (sys))])
178
+ ctidxmap = Dict ([v => i for (i, v) in enumerate (MTK. unbound_inputs (sys))])
179
+ iv = MTK. get_iv (sys)
177
180
for i in 1 : length (exprs)
178
181
subvars = MTK. vars (exprs[i])
179
182
for st in subvars
@@ -183,27 +186,28 @@ function substitute_fixed_t_vars(exprs)
183
186
MTK. symbolic_type (t) === MTK. NotSymbolic () || continue
184
187
if haskey (stidxmap, x (iv))
185
188
idx = stidxmap[x (iv)]
186
- cv = U
189
+ cv = model . U
187
190
else
188
191
idx = ctidxmap[x (iv)]
189
- cv = V
192
+ cv = model . V
190
193
end
191
194
exprs[i] = Symbolics. fast_substitute (exprs[i], Dict (x (t) => cv (t)[idx]))
192
195
end
193
196
jcosts = Symbolics. substitute (jcosts, Dict (x (t) => cv (t)[idx]))
194
197
end
198
+ exprs
195
199
end
196
200
197
- MTK. substitute_differentials (model:: CasADiModel , exprs, args ... ) = exprs
201
+ MTK. substitute_differentials (model:: CasADiModel , sys, eqs ) = exprs
198
202
199
- function MTK. substitute_integral (model :: CasADiModel , exprs)
200
- @unpack U, opti = model
203
+ function MTK. substitute_integral (m :: CasADiModel , exprs, tspan )
204
+ @unpack U, model, tₛ = m
201
205
dt = U. t[2 ] - U. t[1 ]
202
206
intmap = Dict ()
203
207
for int in MTK. collect_applied_operators (exprs, Symbolics. Integral)
204
208
op = MTK. operation (int)
205
209
arg = only (arguments (MTK. value (int)))
206
- lo, hi = ( op. domain. domain. left, op. domain. domain. right)
210
+ lo, hi = MTK . value .(( op. domain. domain. left, op. domain. domain. right) )
207
211
! isequal ((lo, hi), tspan) &&
208
212
error (" Non-whole interval bounds for integrals are not currently supported for CasADiDynamicOptProblem." )
209
213
# Approximate integral as sum.
@@ -213,11 +217,11 @@ function MTK.substitute_integral(model::CasADiModel, exprs)
213
217
exprs = MTK. value .(exprs)
214
218
end
215
219
216
- function add_solve_constraints! (prob, tableau)
220
+ function add_solve_constraints! (prob:: CasADiDynamicOptProblem , tableau)
217
221
@unpack A, α, c = tableau
218
222
@unpack model, f, p = prob
219
- @unpack opti , U, V, tₛ = model
220
- solver_opti = copy (opti )
223
+ @unpack model , U, V, tₛ = model
224
+ solver_opti = copy (model )
221
225
222
226
tsteps = U. t
223
227
dt = tsteps[2 ] - tsteps[1 ]
@@ -258,29 +262,56 @@ function add_solve_constraints!(prob, tableau)
258
262
solver_opti
259
263
end
260
264
261
- function MTK. prepare_solver ()
262
- opti = add_solve_constraints (prob, tableau)
263
- solver! (opti, " $solver " , plugin_options, solver_options)
265
+ """
266
+ CasADi Collocation solver.
267
+ - solver: an optimization solver such as Ipopt. Should be given as a string or symbol in all lowercase, e.g. "ipopt"
268
+ - tableau: An ODE RK tableau. Load a tableau by calling a function like `constructRK4` and may be found at https://docs.sciml.ai/DiffEqDevDocs/stable/internals/tableaus/. If this argument is not passed in, the solver will default to Radau second order.
269
+ """
270
+ struct CasADiCollocation <: AbstractCollocation
271
+ solver:: Union{String, Symbol}
272
+ tableau:: DiffEqBase.ODERKTableau
273
+ end
274
+ MTK. CasADiCollocation (solver, tableau = MTK. constructDefault ()) = CasADiCollocation (solver, tableau)
275
+
276
+ function MTK. prepare_and_optimize! (prob:: CasADiDynamicOptProblem , solver:: CasADiCollocation ; verbose = false , solver_options = Dict (), plugin_options = Dict (), kwargs... )
277
+ solver_opti = add_solve_constraints! (prob, solver. tableau)
278
+ verbose || (solver_options[" print_level" ] = 0 )
279
+ solver! (solver_opti, " $(solver. solver) " , plugin_options, solver_options)
280
+ try
281
+ CasADi. solve! (solver_opti)
282
+ catch ErrorException
283
+ end
284
+ prob. model. solver_opti = solver_opti
264
285
end
265
- function MTK. get_U_values ()
266
- U_vals = value_getter (U. u)
286
+
287
+ function MTK. get_U_values (model:: CasADiModel )
288
+ value_getter = MTK. successful_solve (model) ? CasADi. debug_value : CasADi. value
289
+ (nu, nt) = size (model. U. u)
290
+ U_vals = value_getter (model. solver_opti, model. U. u)
267
291
size (U_vals, 2 ) == 1 && (U_vals = U_vals' )
268
- U_vals = [[U_vals[i, j] for i in 1 : size (U_vals, 1 ) ] for j in 1 : length (ts) ]
292
+ U_vals = [[U_vals[i, j] for i in 1 : nu ] for j in 1 : nt ]
269
293
end
270
- function MTK. get_V_values ()
294
+
295
+ function MTK. get_V_values (model:: CasADiModel )
296
+ value_getter = MTK. successful_solve (model) ? CasADi. debug_value : CasADi. value
297
+ (nu, nt) = size (model. V. u)
298
+ if nu* nt != 0
299
+ V_vals = value_getter (model. solver_opti, model. V. u)
300
+ size (V_vals, 2 ) == 1 && (V_vals = V_vals' )
301
+ V_vals = [[V_vals[i, j] for i in 1 : nu] for j in 1 : nt]
302
+ else
303
+ nothing
304
+ end
271
305
end
272
- function MTK. get_t_values ()
273
- ts = value_getter (tₛ) * U. t
306
+
307
+ function MTK. get_t_values (model:: CasADiModel )
308
+ value_getter = MTK. successful_solve (model) ? CasADi. debug_value : CasADi. value
309
+ ts = value_getter (model. solver_opti, model. tₛ) .* model. U. t
274
310
end
275
311
276
- function MTK. optimize_model! ()
277
- try
278
- sol = CasADi. solve! (opti)
279
- value_getter = x -> CasADi. value (sol, x)
280
- catch ErrorException
281
- value_getter = x -> CasADi. debug_value (opti, x)
282
- failed = true
283
- end
312
+ function MTK. successful_solve (m:: CasADiModel )
313
+ isnothing (m. solver_opti) && return false
314
+ retcode = CasADi. return_status (m. solver_opti)
315
+ retcode == " Solve_Succeeded" || retcode == " Solved_To_Acceptable_Level"
284
316
end
285
- MTK. successful_solve () = true
286
317
end
0 commit comments