@@ -58,7 +58,7 @@ function MTK.JuMPControlProblem(sys::ODESystem, u0map, tspan, pmap;
5858 f, u0, p = MTK. process_SciMLProblem (ControlFunction, sys, _u0map, pmap;
5959 t = tspan != = nothing ? tspan[1 ] : tspan, kwargs... )
6060
61- model = init_model (sys, tspan[1 ]: dt: tspan[2 ], u0map, u0)
61+ model = init_model (sys, tspan[1 ]: dt: tspan[2 ], u0map, pmap, u0)
6262 JuMPControlProblem (f, u0, tspan, p, model, kwargs... )
6363end
6464
@@ -80,22 +80,23 @@ function MTK.InfiniteOptControlProblem(sys::ODESystem, u0map, tspan, pmap;
8080 f, u0, p = MTK. process_SciMLProblem (ControlFunction, sys, _u0map, pmap;
8181 t = tspan != = nothing ? tspan[1 ] : tspan, kwargs... )
8282
83- model = init_model (sys, tspan[1 ]: dt: tspan[2 ], u0map, u0)
83+ model = init_model (sys, tspan[1 ]: dt: tspan[2 ], u0map, pmap, u0)
8484 add_infopt_solve_constraints! (model, sys, pmap)
8585 InfiniteOptControlProblem (f, u0, tspan, p, model, kwargs... )
8686end
8787
88- function init_model (sys, tsteps, u0map, u0)
88+ function init_model (sys, tsteps, u0map, pmap, u0)
8989 ctrls = MTK. unbound_inputs (sys)
9090 states = unknowns (sys)
9191 model = InfiniteModel ()
92+
9293 @infinite_parameter (model, t in [tsteps[1 ], tsteps[end ]], num_supports= length (tsteps))
9394 @variable (model, U[i = 1 : length (states)], Infinite (t))
9495 @variable (model, V[1 : length (ctrls)], Infinite (t))
9596
9697 set_bounds! (model, sys)
97- add_jump_cost_function! (model, sys)
98- add_user_constraints! (model, sys)
98+ add_jump_cost_function! (model, sys, (tsteps[ 1 ], tsteps[ 2 ]), pmap )
99+ add_user_constraints! (model, sys, pmap )
99100
100101 stidxmap = Dict ([v => i for (i, v) in enumerate (states)])
101102 u0_idxs = has_alg_eqs (sys) ? collect (1 : length (states)) :
@@ -120,63 +121,35 @@ function set_bounds!(model, sys)
120121 end
121122end
122123
123- function add_jump_cost_function! (model:: InfiniteModel , sys)
124+ function add_jump_cost_function! (model:: InfiniteModel , sys, tspan, pmap )
124125 jcosts = MTK. get_costs (sys)
125126 consolidate = MTK. get_consolidate (sys)
126127 if isnothing (jcosts) || isempty (jcosts)
127128 @objective (model, Min, 0 )
128129 return
129130 end
130- iv = MTK. get_iv (sys)
131-
132- stidxmap = Dict ([v => i for (i, v) in enumerate (unknowns (sys))])
133- cidxmap = Dict ([v => i for (i, v) in enumerate (MTK. unbound_inputs (sys))])
134-
135- for st in unknowns (sys)
136- x = operation (st)
137- t = only (arguments (st))
138- idx = stidxmap[x (iv)]
139- subval = isequal (t, iv) ? model[:U ][idx] : model[:U ][idx](t)
140- jcosts = map (c -> Symbolics. substitute (c, Dict (x (t) => subval)), jcosts)
141- end
131+ jcosts = substitute_jump_vars (model, sys, pmap, jcosts)
142132
143- for ct in MTK. unbound_inputs (sys)
144- p = operation (ct)
145- t = only (arguments (ct))
146- idx = cidxmap[p (iv)]
147- subval = isequal (t, iv) ? model[:V ][idx] : model[:V ][idx](t)
148- jcosts = map (c -> Symbolics. substitute (c, Dict (p (t) => subval)), jcosts)
133+ # Substitute integral
134+ iv = MTK. get_iv (sys)
135+ jcosts = map (c -> Symbolics. substitute (c, ∫ => Symbolics. Integral (iv in tspan)), jcosts)
136+ intmap = Dict ()
137+
138+ for int in MTK. collect_applied_operators (jcosts, Symbolics. Integral)
139+ arg = only (arguments (MTK. value (int)))
140+ lower_bound, upper_bound = (int. domain. domain. left, int. domain. domain. right)
141+ intmap[int] = InfiniteOpt.∫ (arg, iv; lower_bound, upper_bound)
149142 end
150-
143+ jcosts = map (c -> Symbolics . substitute (c, intmap), jcosts)
151144 @objective (model, Min, consolidate (jcosts))
152145end
153146
154- function add_user_constraints! (model:: InfiniteModel , sys)
147+ function add_user_constraints! (model:: InfiniteModel , sys, pmap )
155148 conssys = MTK. get_constraintsystem (sys)
156149 jconstraints = isnothing (conssys) ? nothing : MTK. get_constraints (conssys)
157150 (isnothing (jconstraints) || isempty (jconstraints)) && return nothing
158151
159- iv = MTK. get_iv (sys)
160- stidxmap = Dict ([v => i for (i, v) in enumerate (unknowns (sys))])
161- cidxmap = Dict ([v => i for (i, v) in enumerate (MTK. unbound_inputs (sys))])
162-
163- for st in unknowns (conssys)
164- x = operation (st)
165- t = only (arguments (st))
166- idx = stidxmap[x (iv)]
167- subval = isequal (t, iv) ? model[:U ][idx] : model[:U ][idx](t)
168- jconstraints = map (c -> Symbolics. substitute (c, Dict (x (t) => subval)), jconstraints)
169- end
170-
171- for ct in MTK. unbound_inputs (sys)
172- p = operation (ct)
173- t = only (arguments (ct))
174- idx = cidxmap[p (iv)]
175- subval = isequal (t, iv) ? model[:V ][idx] : model[:V ][idx](t)
176- jconstraints = map (
177- c -> Symbolics. substitute (jconstraints, Dict (p (t) => subval)), jconstriants)
178- end
179-
152+ jconstraints = substitute_jump_vars (model, sys, pmap, jconstraints)
180153 for (i, cons) in enumerate (jconstraints)
181154 if cons isa Equation
182155 @constraint (model, cons. lhs - cons. rhs== 0 , base_name= " user[$i ]" )
@@ -193,31 +166,41 @@ function add_initial_constraints!(model::InfiniteModel, u0, u0_idxs, ts)
193166 @constraint (model, initial[i in u0_idxs], U[i](ts)== u0[i])
194167end
195168
196- is_explicit (tableau) = tableau isa DiffEqDevTools. ExplicitRKTableau
197-
198- function add_infopt_solve_constraints! (model:: InfiniteModel , sys, pmap)
169+ function substitute_jump_vars (model, sys, pmap, exprs)
199170 iv = MTK. get_iv (sys)
200- t = model[:t ]
171+ sts = unknowns (sys)
172+ cts = MTK. unbound_inputs (sys)
201173 U = model[:U ]
202174 V = model[:V ]
175+ # for variables like x(t)
176+ whole_interval_map = Dict ([[v => U[i] for (i, v) in enumerate (sts)]; [v => V[i] for (i, v) in enumerate (cts)]])
177+ exprs = map (c -> Symbolics. substitute (c, whole_interval_map), exprs)
178+
179+ # for variables like x(1.0)
180+ x_ops = [MTK. operation (MTK. unwrap (st)) for st in sts]
181+ c_ops = [MTK. operation (MTK. unwrap (ct)) for ct in cts]
182+ fixed_t_map = Dict ([[x_ops[i] => U[i] for i in 1 : length (U)]; [c_ops[i] => V[i] for i in 1 : length (V)]])
183+ exprs = map (c -> Symbolics. substitute (c, fixed_t_map), exprs)
184+
185+ exprs = map (c -> Symbolics. substitute (c, Dict (pmap)), exprs)
186+ exprs
187+ end
203188
204- stmap = Dict ([v => U[i] for (i, v) in enumerate (unknowns (sys))])
205- ctrlmap = Dict ([v => V[i] for (i, v) in enumerate (MTK. unbound_inputs (sys))])
206- submap = merge (stmap, ctrlmap, Dict (pmap))
189+ is_explicit (tableau) = tableau isa DiffEqDevTools. ExplicitRKTableau
207190
191+ function add_infopt_solve_constraints! (model:: InfiniteModel , sys, pmap)
208192 # Differential equations
209- diff_eqs = diff_equations (sys)
210- D = Differential (iv)
193+ U = model[:U ]
194+ t = model[:t ]
195+ D = Differential (MTK. get_iv (sys))
211196 diffsubmap = Dict ([D (U[i]) => ∂ (U[i], t) for i in 1 : length (U)])
212- for u in unknowns (sys)
213- diff_eqs = map (e -> Symbolics. substitute (e, submap), diff_eqs)
214- diff_eqs = map (e -> Symbolics. substitute (e, diffsubmap), diff_eqs)
215- end
197+
198+ diff_eqs = substitute_jump_vars (model, sys, pmap, diff_equations (sys))
199+ diff_eqs = map (e -> Symbolics. substitute (e, diffsubmap), diff_eqs)
216200 @constraint (model, D[i = 1 : length (diff_eqs)], diff_eqs[i]. lhs== diff_eqs[i]. rhs)
217201
218202 # Algebraic equations
219- alg_eqs = alg_equations (sys)
220- alg_eqs = map (e -> Symbolics. substitute (e, submap), alg_eqs)
203+ alg_eqs = substitute_jump_vars (model, sys, pmap, alg_equations (sys))
221204 @constraint (model, A[i = 1 : length (alg_eqs)], alg_eqs[i]. lhs== alg_eqs[i]. rhs)
222205end
223206
306289`derivative_method` kwarg refers to the method used by InfiniteOpt to compute derivatives. The list of possible options can be found at https://infiniteopt.github.io/InfiniteOpt.jl/stable/guide/derivative/. Defaults to FiniteDifference(Backward()).
307290"""
308291function DiffEqBase. solve (prob:: InfiniteOptControlProblem , jump_solver;
309- derivative_method = InfiniteOpt. FiniteDifference (Backward ()))
292+ derivative_method = InfiniteOpt. FiniteDifference (Backward ()), silent = false )
293+ model = prob. model
310294 silent && set_silent (model)
311- set_derivative_method (prob . model[:t ], derivative_method)
295+ set_derivative_method (model[:t ], derivative_method)
312296 _solve (prob, jump_solver, derivative_method)
313297end
314298
0 commit comments