Skip to content

Commit 3bdfc46

Browse files
committed
refactor: add definitions of shared functions to optimal_control_interface.jl
1 parent 60c95b9 commit 3bdfc46

File tree

2 files changed

+152
-0
lines changed

2 files changed

+152
-0
lines changed

ext/MTKCasADiDynamicOptExt.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,11 @@ function MTK.CasADiDynamicOptProblem(sys::ODESystem, u0map, tspan, pmap;
8787
CasADiDynamicOptProblem(f, u0, tspan, p, model, kwargs...)
8888
end
8989

90+
MTK.generate_U(model, dims) = 1
91+
MTK.generate_V(model, dims) = 1
92+
MTK.generate_timescale(model, dims) = 1
93+
MTK.generate_internal_model(::Type{CasADiModel}) = CasADi.opti()
94+
9095
function init_model(sys, tspan, steps, u0map, pmap, u0; is_free_t = false)
9196
ctrls = MTK.unbound_inputs(sys)
9297
states = unknowns(sys)

src/systems/optimal_control_interface.jl

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
abstract type AbstractDynamicOptProblem{uType, tType, isinplace} <:
22
SciMLBase.AbstractODEProblem{uType, tType, isinplace} end
33

4+
abstract type AbstractCollocation end
5+
46
struct DynamicOptSolution
57
model::Any
68
sol::ODESolution
@@ -181,3 +183,148 @@ function process_tspan(tspan, dt, steps)
181183
return length(tspan[1]:dt:tspan[2]), false
182184
end
183185
end
186+
187+
function process_DynamicOptProblem(prob_type::AbstractDynamicOptProblem, model_type, sys::ODESystem, u0map, tspan, pmap;
188+
dt = nothing,
189+
steps = nothing,
190+
guesses = Dict(), kwargs...)
191+
192+
MTK.warn_overdetermined(sys, u0map)
193+
_u0map = has_alg_eqs(sys) ? u0map : merge(Dict(u0map), Dict(guesses))
194+
f, u0, p = MTK.process_SciMLProblem(ODEInputFunction, sys, _u0map, pmap;
195+
t = tspan !== nothing ? tspan[1] : tspan, kwargs...)
196+
197+
stidxmap = Dict([v => i for (i, v) in enumerate(states)])
198+
u0map = Dict([MTK.default_toterm(MTK.value(k)) => v for (k, v) in u0map])
199+
u0_idxs = has_alg_eqs(sys) ? collect(1:length(states)) :
200+
[stidxmap[MTK.default_toterm(k)] for (k, v) in u0map]
201+
pmap = Dict{Any, Any}(pmap)
202+
steps, is_free_t = MTK.process_tspan(tspan, dt, steps)
203+
204+
ctrls = MTK.unbound_inputs(sys)
205+
states = unknowns(sys)
206+
207+
model = generate_internal_model(model_type)
208+
U = generate_U(model, u0)
209+
V = generate_V()
210+
tₛ = generate_timescale()
211+
fullmodel = model_type(model, U, V, tₛ)
212+
213+
set_variable_bounds!(fullmodel, sys, pmap)
214+
add_cost_function!(fullmodel, sys, tspan, pmap; is_free_t)
215+
add_user_constraints!(fullmodel, sys, tspan, pmap; is_free_t)
216+
add_initial_constraints!(fullmodel, u0, u0_idxs)
217+
218+
prob_type(f, u0, tspan, p, fullmodel, kwargs...)
219+
end
220+
221+
function add_cost_function!()
222+
jcosts = copy(MTK.get_costs(sys))
223+
consolidate = MTK.get_consolidate(sys)
224+
if isnothing(jcosts) || isempty(jcosts)
225+
minimize!(opti, MX(0))
226+
return
227+
end
228+
229+
jcosts = substitute_model_vars(model, sys, pmap, jcosts; is_free_t)
230+
jcosts = substitute_free_final_vars(model, sys, pmap, jcosts; is_free_t)
231+
jcosts = substitute_fixed_t_vars(model, sys, pmap, jcosts; is_free_t)
232+
jcosts = substitute_integral()
233+
end
234+
235+
function add_user_constraints!()
236+
conssys = MTK.get_constraintsystem(sys)
237+
jconstraints = isnothing(conssys) ? nothing : MTK.get_constraints(conssys)
238+
(isnothing(jconstraints) || isempty(jconstraints)) && return nothing
239+
240+
auxmap = Dict([u => MTK.default_toterm(MTK.value(u)) for u in unknowns(conssys)])
241+
jconstraints = substitute_model_vars(model, sys, pmap, jconstraints; auxmap, is_free_t)
242+
243+
for c in jconstraints
244+
if cons isa Equation
245+
add_constraint!()
246+
elseif cons.relational_op === Symbolics.geq
247+
add_constraint!()
248+
else
249+
add_constraint!()
250+
end
251+
end
252+
end
253+
254+
function generate_U end
255+
function generate_V end
256+
function generate_timescale end
257+
258+
function add_initial_constraints! end
259+
function add_constraint! end
260+
261+
function add_collocation_solve_constraints!(prob, tableau)
262+
nᵤ = size(U.u, 1)
263+
nᵥ = size(V.u, 1)
264+
265+
if is_explicit(tableau)
266+
K = MX[]
267+
for k in 1:(length(tsteps) - 1)
268+
τ = tsteps[k]
269+
for (i, h) in enumerate(c)
270+
ΔU = sum([A[i, j] * K[j] for j in 1:(i - 1)], init = MX(zeros(nᵤ)))
271+
Uₙ = U.u[:, k] + ΔU * dt
272+
Vₙ = V.u[:, k]
273+
Kₙ = tₛ * f(Uₙ, Vₙ, p, τ + h * dt) # scale the time
274+
push!(K, Kₙ)
275+
end
276+
ΔU = dt * sum([α[i] * K[i] for i in 1:length(α)])
277+
subject_to!(solver_opti, U.u[:, k] + ΔU == U.u[:, k + 1])
278+
empty!(K)
279+
end
280+
else
281+
for k in 1:(length(tsteps) - 1)
282+
τ = tsteps[k]
283+
# Kᵢ = generate_K()
284+
Kᵢ = variable!(solver_opti, nᵤ, length(α))
285+
ΔUs = A * Kᵢ' # the stepsize at each stage of the implicit method
286+
for (i, h) in enumerate(c)
287+
ΔU = ΔUs[i, :]'
288+
Uₙ = U.u[:, k] + ΔU * dt
289+
Vₙ = V.u[:, k]
290+
subject_to!(solver_opti, Kᵢ[:, i] == tₛ * f(Uₙ, Vₙ, p, τ + h * dt))
291+
end
292+
ΔU_tot = dt * (Kᵢ * α)
293+
subject_to!(solver_opti, U.u[:, k] + ΔU_tot == U.u[:, k + 1])
294+
end
295+
end
296+
end
297+
298+
function add_equational_solve_constraints!()
299+
diff_eqs = substitute_differentials()
300+
add_constraint!()
301+
302+
alg_eqs = substitute_model_vars()
303+
add_constraint!()
304+
end
305+
306+
"""
307+
Add the solve constraints, set the solver (Ipopt, e.g.)
308+
"""
309+
function prepare_solver end
310+
311+
function DiffEqBase.solve(prob::AbstractDynamicOptProblem, solver::AbstractCollocation)
312+
#add_solve_constraints!(prob, solver)
313+
solver = prepare_solver(prob, solver)
314+
sol = solve_prob(prob, solver)
315+
316+
ts = get_t_values(sol)
317+
Us = get_U_values(sol)
318+
Vs = get_V_values(sol)
319+
320+
ode_sol = DiffEqBase.build_solution(prob, solver, ts, Us)
321+
input_sol = DiffEqBase.build_solution(prob, solver, ts, Vs)
322+
323+
if successful_solve(model)
324+
ode_sol = SciMLBase.solution_new_retcode(
325+
ode_sol, SciMLBase.ReturnCode.ConvergenceFailure)
326+
!isnothing(input_sol) && (input_sol = SciMLBase.solution_new_retcode(
327+
input_sol, SciMLBase.ReturnCode.ConvergenceFailure))
328+
end
329+
DynamicOptSolution(model, ode_sol, input_sol)
330+
end

0 commit comments

Comments
 (0)