@@ -3,18 +3,18 @@ using ModelingToolkit
33using JuMP, InfiniteOpt
44using DiffEqDevTools, DiffEqBase
55
6- struct JuMPProblem {uType, tType, isinplace, P, F, K} < :
6+ struct JuMPControlProblem {uType, tType, isinplace, P, F, K} < :
77 AbstractODEProblem{uType, tType, isinplace}
88 f:: F
99 u0:: uType
10- tspan
10+ tspan:: tType
1111 p
1212 model
1313 kwargs
1414end
1515
1616"""
17- JuMPProblem (sys::ODESystem, u0, tspan, p; dt)
17+ JuMPControlProblem (sys::ODESystem, u0, tspan, p; dt)
1818
1919Convert an ODESystem representing an optimal control system into a JuMP model
2020for solving using optimization. Must provide `dt` for determining the length
@@ -28,7 +28,7 @@ The constraints are:
2828- The set of user constraints passed to the ODESystem via `constraints`
2929- The solver constraints that encode the time-stepping used by the solver
3030"""
31- function JuMPProblem (sys:: ODESystem , u0map, tspan, pmap; dt = error (" dt must be provided for JuMPProblem." ), solver = :Tsit5 )
31+ function JuMPControlProblem (sys:: ODESystem , u0map, tspan, pmap; dt = error (" dt must be provided for JuMPProblem." ), solver = :Tsit5 )
3232 ts = tspan[1 ]
3333 te = tspan[2 ]
3434 steps = ts: dt: te
@@ -54,7 +54,7 @@ function JuMPProblem(sys::ODESystem, u0map, tspan, pmap; dt = error("dt must be
5454 add_user_constraints! (model, sys)
5555 add_solve_constraints! (model)
5656
57- JuMPProblem {iip} (f, u0, tspan, p, model; kwargs... )
57+ JuMPControlProblem {iip} (f, u0, tspan, p, model; kwargs... )
5858end
5959
6060function add_jump_cost_function! (model, sys)
@@ -118,20 +118,67 @@ function add_user_constraints!(model, sys, u0map)
118118 # Add initial constraints.
119119end
120120
121- function add_solve_constraints! (model, tsteps, solver)
122- tableau = fetch_tableau (solver)
123-
124- for (i, t) in collect (enumerate (tsteps))
121+ function add_solve_constraints! (prob, talbeau, f, tsteps)
122+ A = tableau. A
123+ α = tableau. α
124+ c = tableau. c
125+ model = prob. model
126+ p = prob. p
127+ dt = step (tsteps)
128+
129+ if is_explicit (tableau)
130+ K = Any[]
131+ for t in tsteps
132+ for (i, h) in enumerate (c)
133+ ΔU = sum ([A[i, j] * K[j] for j in 1 : i- 1 ])
134+ Kₙ = f (U + ΔU* dt, p, t + h* dt)
135+ push! (K, Kₙ)
136+ end
137+ @constraint (model, U (t) + dot (α, K) == U (t + dt))
138+ empty! (K)
139+ end
140+ else
141+ @variable (model, K[1 : length (a)], Infinite (t), start = tsteps[1 ])
142+ for t in tsteps
143+ ΔUs = A * K (t)
144+ for (i, h) in enumerate (c)
145+ ΔU = ΔUs[i]
146+ @constraint (model, K[i](t) == f (U + ΔU* dt, p, t + h* dt))
147+ end
148+ @constraint (model, U (t) + dot (α, K (t)) == U (t + dt))
149+ end
125150 end
126151end
127152
153+ is_explicit (tableau) = tableau isa DiffEqDevTools. ExplicitRKTableau
154+
128155"""
129- Solve JuMPProblem. Takes in a symbol representing the solver.
130156"""
131- function solve (prob:: JuMPProblem , solver_sym:: Symbol )
157+ struct JuMPControlSolution
158+ model
159+ sol:: ODESolution
160+ end
161+
162+ """
163+ Solve JuMPProblem. Takes in a symbol representing the solver. Acceptable solvers may be found at https://docs.sciml.ai/DiffEqDevDocs/stable/internals/tableaus/.
164+ Note that the symbol may be different than the typical
165+ name of the solver, e.g. :Tsitouras5 rather than Tsit5.
166+ """
167+ function solve (prob:: JuMPProblem , jump_solver, ode_solver:: Symbol )
132168 model = prob. model
169+ f = prob. f
133170 tableau_getter = Symbol (:construct , solver)
134171 @eval tableau = $ tableau_getter ()
135- add_solve_constraints! (model, tableau)
172+ ts = prob. tspan[1 ]: dt: prob. tspan[2 ]
173+ add_solve_constraints! (model, ts, tableau, f)
174+
175+ set_optimizer (model, solver)
176+ optimize! (model)
177+
178+ if is_solved_and_feasible (model)
179+ sol = DiffEqBase. build_solution (prob, ode_solver, ts, value (U))
180+ JuMPControlSolution (model, sol)
181+ end
136182end
183+
137184end
0 commit comments