Skip to content

Commit 2ccdbba

Browse files
committed
refactor: add definitions of shared functions to optimal_control_interface.jl
1 parent 66cc813 commit 2ccdbba

File tree

2 files changed

+152
-0
lines changed

2 files changed

+152
-0
lines changed

ext/MTKCasADiDynamicOptExt.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,11 @@ function MTK.CasADiDynamicOptProblem(sys::System, u0map, tspan, pmap;
9090
CasADiDynamicOptProblem(f, u0, tspan, p, model, kwargs...)
9191
end
9292

93+
MTK.generate_U(model, dims) = 1
94+
MTK.generate_V(model, dims) = 1
95+
MTK.generate_timescale(model, dims) = 1
96+
MTK.generate_internal_model(::Type{CasADiModel}) = CasADi.opti()
97+
9398
function init_model(sys, tspan, steps, u0map, pmap, u0; is_free_t = false)
9499
ctrls = MTK.unbound_inputs(sys)
95100
states = unknowns(sys)

src/systems/optimal_control_interface.jl

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
abstract type AbstractDynamicOptProblem{uType, tType, isinplace} <:
22
SciMLBase.AbstractODEProblem{uType, tType, isinplace} end
33

4+
abstract type AbstractCollocation end
5+
46
struct DynamicOptSolution
57
model::Any
68
sol::ODESolution
@@ -148,3 +150,148 @@ function process_tspan(tspan, dt, steps)
148150
return length(tspan[1]:dt:tspan[2]), false
149151
end
150152
end
153+
154+
function process_DynamicOptProblem(prob_type::AbstractDynamicOptProblem, model_type, sys::ODESystem, u0map, tspan, pmap;
155+
dt = nothing,
156+
steps = nothing,
157+
guesses = Dict(), kwargs...)
158+
159+
MTK.warn_overdetermined(sys, u0map)
160+
_u0map = has_alg_eqs(sys) ? u0map : merge(Dict(u0map), Dict(guesses))
161+
f, u0, p = MTK.process_SciMLProblem(ODEInputFunction, sys, _u0map, pmap;
162+
t = tspan !== nothing ? tspan[1] : tspan, kwargs...)
163+
164+
stidxmap = Dict([v => i for (i, v) in enumerate(states)])
165+
u0map = Dict([MTK.default_toterm(MTK.value(k)) => v for (k, v) in u0map])
166+
u0_idxs = has_alg_eqs(sys) ? collect(1:length(states)) :
167+
[stidxmap[MTK.default_toterm(k)] for (k, v) in u0map]
168+
pmap = Dict{Any, Any}(pmap)
169+
steps, is_free_t = MTK.process_tspan(tspan, dt, steps)
170+
171+
ctrls = MTK.unbound_inputs(sys)
172+
states = unknowns(sys)
173+
174+
model = generate_internal_model(model_type)
175+
U = generate_U(model, u0)
176+
V = generate_V()
177+
tₛ = generate_timescale()
178+
fullmodel = model_type(model, U, V, tₛ)
179+
180+
set_variable_bounds!(fullmodel, sys, pmap)
181+
add_cost_function!(fullmodel, sys, tspan, pmap; is_free_t)
182+
add_user_constraints!(fullmodel, sys, tspan, pmap; is_free_t)
183+
add_initial_constraints!(fullmodel, u0, u0_idxs)
184+
185+
prob_type(f, u0, tspan, p, fullmodel, kwargs...)
186+
end
187+
188+
function add_cost_function!()
189+
jcosts = copy(MTK.get_costs(sys))
190+
consolidate = MTK.get_consolidate(sys)
191+
if isnothing(jcosts) || isempty(jcosts)
192+
minimize!(opti, MX(0))
193+
return
194+
end
195+
196+
jcosts = substitute_model_vars(model, sys, pmap, jcosts; is_free_t)
197+
jcosts = substitute_free_final_vars(model, sys, pmap, jcosts; is_free_t)
198+
jcosts = substitute_fixed_t_vars(model, sys, pmap, jcosts; is_free_t)
199+
jcosts = substitute_integral()
200+
end
201+
202+
function add_user_constraints!()
203+
conssys = MTK.get_constraintsystem(sys)
204+
jconstraints = isnothing(conssys) ? nothing : MTK.get_constraints(conssys)
205+
(isnothing(jconstraints) || isempty(jconstraints)) && return nothing
206+
207+
auxmap = Dict([u => MTK.default_toterm(MTK.value(u)) for u in unknowns(conssys)])
208+
jconstraints = substitute_model_vars(model, sys, pmap, jconstraints; auxmap, is_free_t)
209+
210+
for c in jconstraints
211+
if cons isa Equation
212+
add_constraint!()
213+
elseif cons.relational_op === Symbolics.geq
214+
add_constraint!()
215+
else
216+
add_constraint!()
217+
end
218+
end
219+
end
220+
221+
function generate_U end
222+
function generate_V end
223+
function generate_timescale end
224+
225+
function add_initial_constraints! end
226+
function add_constraint! end
227+
228+
function add_collocation_solve_constraints!(prob, tableau)
229+
nᵤ = size(U.u, 1)
230+
nᵥ = size(V.u, 1)
231+
232+
if is_explicit(tableau)
233+
K = MX[]
234+
for k in 1:(length(tsteps) - 1)
235+
τ = tsteps[k]
236+
for (i, h) in enumerate(c)
237+
ΔU = sum([A[i, j] * K[j] for j in 1:(i - 1)], init = MX(zeros(nᵤ)))
238+
Uₙ = U.u[:, k] + ΔU * dt
239+
Vₙ = V.u[:, k]
240+
Kₙ = tₛ * f(Uₙ, Vₙ, p, τ + h * dt) # scale the time
241+
push!(K, Kₙ)
242+
end
243+
ΔU = dt * sum([α[i] * K[i] for i in 1:length(α)])
244+
subject_to!(solver_opti, U.u[:, k] + ΔU == U.u[:, k + 1])
245+
empty!(K)
246+
end
247+
else
248+
for k in 1:(length(tsteps) - 1)
249+
τ = tsteps[k]
250+
# Kᵢ = generate_K()
251+
Kᵢ = variable!(solver_opti, nᵤ, length(α))
252+
ΔUs = A * Kᵢ' # the stepsize at each stage of the implicit method
253+
for (i, h) in enumerate(c)
254+
ΔU = ΔUs[i, :]'
255+
Uₙ = U.u[:, k] + ΔU * dt
256+
Vₙ = V.u[:, k]
257+
subject_to!(solver_opti, Kᵢ[:, i] == tₛ * f(Uₙ, Vₙ, p, τ + h * dt))
258+
end
259+
ΔU_tot = dt * (Kᵢ * α)
260+
subject_to!(solver_opti, U.u[:, k] + ΔU_tot == U.u[:, k + 1])
261+
end
262+
end
263+
end
264+
265+
function add_equational_solve_constraints!()
266+
diff_eqs = substitute_differentials()
267+
add_constraint!()
268+
269+
alg_eqs = substitute_model_vars()
270+
add_constraint!()
271+
end
272+
273+
"""
274+
Add the solve constraints, set the solver (Ipopt, e.g.)
275+
"""
276+
function prepare_solver end
277+
278+
function DiffEqBase.solve(prob::AbstractDynamicOptProblem, solver::AbstractCollocation)
279+
#add_solve_constraints!(prob, solver)
280+
solver = prepare_solver(prob, solver)
281+
sol = solve_prob(prob, solver)
282+
283+
ts = get_t_values(sol)
284+
Us = get_U_values(sol)
285+
Vs = get_V_values(sol)
286+
287+
ode_sol = DiffEqBase.build_solution(prob, solver, ts, Us)
288+
input_sol = DiffEqBase.build_solution(prob, solver, ts, Vs)
289+
290+
if successful_solve(model)
291+
ode_sol = SciMLBase.solution_new_retcode(
292+
ode_sol, SciMLBase.ReturnCode.ConvergenceFailure)
293+
!isnothing(input_sol) && (input_sol = SciMLBase.solution_new_retcode(
294+
input_sol, SciMLBase.ReturnCode.ConvergenceFailure))
295+
end
296+
DynamicOptSolution(model, ode_sol, input_sol)
297+
end

0 commit comments

Comments
 (0)