Skip to content

Commit 8d268e9

Browse files
committed
refactor: move set_variable_bounds to interface function
1 parent 4ca744a commit 8d268e9

File tree

3 files changed

+51
-74
lines changed

3 files changed

+51
-74
lines changed

ext/MTKCasADiDynamicOptExt.jl

Lines changed: 8 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ struct MXLinearInterpolation
1717
t::Vector{Float64}
1818
dt::Float64
1919
end
20+
Base.getindex(m::MXLinearInterpolation, i...) = length(i) == length(size(m.u)) ? m.u[i...] : m.u[i..., :]
2021

2122
mutable struct CasADiModel
2223
model::Opti
@@ -37,7 +38,7 @@ struct CasADiDynamicOptProblem{uType, tType, isinplace, P, F, K} <:
3738
u0::uType
3839
tspan::tType
3940
p::P
40-
model::CasADiModel
41+
wrapped_model::CasADiModel
4142
kwargs::K
4243

4344
function CasADiDynamicOptProblem(f, u0, tspan, p, model, kwargs...)
@@ -52,10 +53,11 @@ function (M::MXLinearInterpolation)(τ)
5253
Δ = nt - i + 1
5354

5455
(i > length(M.t) || i < 1) && error("Cannot extrapolate past the tspan.")
56+
colons = ntuple(_ -> (:), length(size(M.u)) - 1)
5557
if i < length(M.t)
56-
M.u[:, i] + Δ * (M.u[:, i + 1] - M.u[:, i])
58+
M.u[colons..., i] + Δ*(M.u[colons..., i+1] - M.u[colons..., i])
5759
else
58-
M.u[:, i]
60+
M.u[colons..., i]
5961
end
6062
end
6163

@@ -121,29 +123,6 @@ function MTK.add_constraint!(m::CasADiModel, expr)
121123
end
122124
MTK.set_objective!(m::CasADiModel, expr) = minimize!(m.model, MX(expr))
123125

124-
function MTK.set_variable_bounds!(m::CasADiModel, sys, pmap, tf)
125-
@unpack model, U, tₛ, V = m
126-
for (i, u) in enumerate(unknowns(sys))
127-
if MTK.hasbounds(u)
128-
lo, hi = MTK.getbounds(u)
129-
subject_to!(model, Symbolics.fixpoint_sub(lo, pmap) <= U.u[i, :])
130-
subject_to!(model, U.u[i, :] <= Symbolics.fixpoint_sub(hi, pmap))
131-
end
132-
end
133-
for (i, v) in enumerate(MTK.unbound_inputs(sys))
134-
if MTK.hasbounds(v)
135-
lo, hi = MTK.getbounds(v)
136-
subject_to!(model, Symbolics.fixpoint_sub(lo, pmap) <= V.u[i, :])
137-
subject_to!(model, V.u[i, :] <= Symbolics.fixpoint_sub(hi, pmap))
138-
end
139-
end
140-
if MTK.symbolic_type(tf) === MTK.ScalarSymbolic() && hasbounds(tf)
141-
lo, hi = MTK.getbounds(tf)
142-
subject_to!(model, tₛ >= lo)
143-
subject_to!(model, tₛ <= hi)
144-
end
145-
end
146-
147126
function MTK.add_initial_constraints!(m::CasADiModel, u0, u0_idxs, args...)
148127
@unpack model, U = m
149128
for i in u0_idxs
@@ -219,8 +198,8 @@ end
219198

220199
function add_solve_constraints!(prob::CasADiDynamicOptProblem, tableau)
221200
@unpack A, α, c = tableau
222-
@unpack model, f, p = prob
223-
@unpack model, U, V, tₛ = model
201+
@unpack wrapped_model, f, p = prob
202+
@unpack model, U, V, tₛ = wrapped_model
224203
solver_opti = copy(model)
225204

226205
tsteps = U.t
@@ -281,7 +260,7 @@ function MTK.prepare_and_optimize!(prob::CasADiDynamicOptProblem, solver::CasADi
281260
CasADi.solve!(solver_opti)
282261
catch ErrorException
283262
end
284-
prob.model.solver_opti = solver_opti
263+
prob.wrapped_model.solver_opti = solver_opti
285264
end
286265

287266
function MTK.get_U_values(model::CasADiModel)

ext/MTKInfiniteOptExt.jl

Lines changed: 13 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ struct JuMPDynamicOptProblem{uType, tType, isinplace, P, F, K} <:
2323
u0::uType
2424
tspan::tType
2525
p::P
26-
model::InfiniteOptModel
26+
wrapped_model::InfiniteOptModel
2727
kwargs::K
2828

2929
function JuMPDynamicOptProblem(f, u0, tspan, p, model, kwargs...)
@@ -38,7 +38,7 @@ struct InfiniteOptDynamicOptProblem{uType, tType, isinplace, P, F, K} <:
3838
u0::uType
3939
tspan::tType
4040
p::P
41-
model::InfiniteOptModel
41+
wrapped_model::InfiniteOptModel
4242
kwargs::K
4343

4444
function InfiniteOptDynamicOptProblem(f, u0, tspan, p, model, kwargs...)
@@ -110,34 +110,10 @@ function MTK.InfiniteOptDynamicOptProblem(sys::System, u0map, tspan, pmap;
110110
steps = nothing,
111111
guesses = Dict(), kwargs...)
112112
prob = MTK.process_DynamicOptProblem(InfiniteOptDynamicOptProblem, InfiniteOptModel, sys, u0map, tspan, pmap; dt, steps, guesses, kwargs...)
113-
MTK.add_equational_constraints!(prob.model, sys, pmap, tspan)
113+
MTK.add_equational_constraints!(prob.wrapped_model, sys, pmap, tspan)
114114
prob
115115
end
116116

117-
function MTK.set_variable_bounds!(model::InfiniteOptModel, sys, pmap, tf)
118-
for (i, u) in enumerate(unknowns(sys))
119-
if MTK.hasbounds(u)
120-
lo, hi = MTK.getbounds(u)
121-
set_lower_bound(model.U[i], Symbolics.fixpoint_sub(lo, pmap))
122-
set_upper_bound(model.U[i], Symbolics.fixpoint_sub(hi, pmap))
123-
end
124-
end
125-
126-
for (i, v) in enumerate(MTK.unbound_inputs(sys))
127-
if MTK.hasbounds(v)
128-
lo, hi = MTK.getbounds(v)
129-
set_lower_bound(model.V[i], Symbolics.fixpoint_sub(lo, pmap))
130-
set_upper_bound(model.V[i], Symbolics.fixpoint_sub(hi, pmap))
131-
end
132-
end
133-
134-
if MTK.symbolic_type(tf) === MTK.ScalarSymbolic() && hasbounds(tf)
135-
lo, hi = MTK.getbounds(tf)
136-
set_lower_bound(model.tₛ, lo)
137-
set_upper_bound(model.tₛ, hi)
138-
end
139-
end
140-
141117
function MTK.substitute_integral(model, exprs, tspan)
142118
intmap = Dict()
143119
for int in MTK.collect_applied_operators(exprs, Symbolics.Integral)
@@ -191,14 +167,12 @@ end
191167

192168
function add_solve_constraints!(prob::JuMPDynamicOptProblem, tableau)
193169
@unpack A, α, c = tableau
194-
@unpack model, f, p = prob
195-
t = model.model[:t]
170+
@unpack wrapped_model, f, p = prob
171+
@unpack tₛ, U, V, model = wrapped_model
172+
t = model[:t]
196173
tsteps = supports(t)
197174
dt = tsteps[2] - tsteps[1]
198175

199-
tₛ = model.tₛ
200-
U = model.U
201-
V = model.V
202176
nᵤ = length(U)
203177
nᵥ = length(V)
204178
if MTK.is_explicit(tableau)
@@ -212,22 +186,22 @@ function add_solve_constraints!(prob::JuMPDynamicOptProblem, tableau)
212186
push!(K, Kₙ)
213187
end
214188
ΔU = dt * sum([α[i] * K[i] for i in 1:length(α)])
215-
@constraint(model.model, [n = 1:nᵤ], U[n](τ) + ΔU[n]==U[n](τ + dt),
189+
@constraint(model, [n = 1:nᵤ], U[n](τ) + ΔU[n]==U[n](τ + dt),
216190
base_name="solve_time_")
217191
empty!(K)
218192
end
219193
else
220-
@variable(model.model, K[1:length(α), 1:nᵤ], Infinite(t))
194+
K = @variable(model, K[1:length(α), 1:nᵤ], Infinite(model[:t]))
221195
ΔUs = A * K
222196
ΔU_tot = dt * (K' * α)
223197
for τ in tsteps[1:end-1]
224198
for (i, h) in enumerate(c)
225199
ΔU = @view ΔUs[i, :]
226-
Uₙ = U + ΔU * h * dt
227-
@constraint(model.model, [j = 1:nᵤ], K[i, j]==(tₛ * f(Uₙ, V, p, τ + h * dt)[j]),
200+
Uₙ = U + ΔU * dt
201+
@constraint(model, [j = 1:nᵤ], K[i, j]==(tₛ * f(Uₙ, V, p, τ + h * dt)[j]),
228202
DomainRestrictions(t => τ), base_name="solve_K$i()")
229203
end
230-
@constraint(model.model, [n = 1:nᵤ], U[n](τ) + ΔU_tot[n]==U[n](min+ dt, tsteps[end])),
204+
@constraint(model, [n = 1:nᵤ], U[n](τ) + ΔU_tot[n]==U[n](min+ dt, tsteps[end])),
231205
DomainRestrictions(t => τ), base_name="solve_U()")
232206
end
233207
end
@@ -256,7 +230,7 @@ end
256230
MTK.InfiniteOptCollocation(solver, derivative_method = InfiniteOpt.FiniteDifference(InfiniteOpt.Backward())) = InfiniteOptCollocation(solver, derivative_method)
257231

258232
function MTK.prepare_and_optimize!(prob::JuMPDynamicOptProblem, solver::JuMPCollocation; verbose = false, kwargs...)
259-
model = prob.model.model
233+
model = prob.wrapped_model.model
260234
verbose || set_silent(model)
261235
# Unregister current solver constraints
262236
for con in all_constraints(model)
@@ -278,7 +252,7 @@ function MTK.prepare_and_optimize!(prob::JuMPDynamicOptProblem, solver::JuMPColl
278252
end
279253

280254
function MTK.prepare_and_optimize!(prob::InfiniteOptDynamicOptProblem, solver::InfiniteOptCollocation; verbose = false, kwargs...)
281-
model = prob.model.model
255+
model = prob.wrapped_model.model
282256
verbose || set_silent(model)
283257
set_derivative_method(model[:t], solver.derivative_method)
284258
set_optimizer(model, solver.solver)

src/systems/optimal_control_interface.jl

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,30 @@ function generate_timescale! end
207207
function set_variable_bounds! end
208208
function add_initial_constraints! end
209209
function add_constraint! end
210+
211+
function set_variable_bounds!(m, sys, pmap, tf)
212+
@unpack model, U, V, tₛ = m
213+
for (i, u) in enumerate(unknowns(sys))
214+
if hasbounds(u)
215+
lo, hi = getbounds(u)
216+
add_constraint!(m, U[i] Symbolics.fixpoint_sub(lo, pmap))
217+
add_constraint!(m, U[i] Symbolics.fixpoint_sub(hi, pmap))
218+
end
219+
end
220+
for (i, v) in enumerate(unbound_inputs(sys))
221+
if hasbounds(v)
222+
lo, hi = getbounds(v)
223+
add_constraint!(m, V[i] Symbolics.fixpoint_sub(lo, pmap))
224+
add_constraint!(m, V[i] Symbolics.fixpoint_sub(hi, pmap))
225+
end
226+
end
227+
if symbolic_type(tf) === ScalarSymbolic() && hasbounds(tf)
228+
lo, hi = getbounds(tf)
229+
set_lower_bound(tₛ, Symbolics.fixpoint_sub(lo, pmap))
230+
set_upper_bound(tₛ, Symbolics.fixpoint_sub(hi, pmap))
231+
end
232+
end
233+
210234
is_free_final(model) = model.is_free_final
211235

212236
function add_cost_function!(model, sys, tspan, pmap)
@@ -304,19 +328,19 @@ function successful_solve end
304328
function DiffEqBase.solve(prob::AbstractDynamicOptProblem, solver::AbstractCollocation; verbose = false, kwargs...)
305329
prepare_and_optimize!(prob, solver; verbose, kwargs...)
306330

307-
ts = get_t_values(prob.model)
308-
Us = get_U_values(prob.model)
309-
Vs = get_V_values(prob.model)
310-
is_free_final(prob.model) && (ts .+ prob.tspan[1])
331+
ts = get_t_values(prob.wrapped_model)
332+
Us = get_U_values(prob.wrapped_model)
333+
Vs = get_V_values(prob.wrapped_model)
334+
is_free_final(prob.wrapped_model) && (ts .+ prob.tspan[1])
311335

312336
ode_sol = DiffEqBase.build_solution(prob, solver, ts, Us)
313337
input_sol = isnothing(Vs) ? nothing : DiffEqBase.build_solution(prob, solver, ts, Vs)
314338

315-
if !successful_solve(prob.model)
339+
if !successful_solve(prob.wrapped_model)
316340
ode_sol = SciMLBase.solution_new_retcode(
317341
ode_sol, SciMLBase.ReturnCode.ConvergenceFailure)
318342
!isnothing(input_sol) && (input_sol = SciMLBase.solution_new_retcode(
319343
input_sol, SciMLBase.ReturnCode.ConvergenceFailure))
320344
end
321-
DynamicOptSolution(prob.model.model, ode_sol, input_sol)
345+
DynamicOptSolution(prob.wrapped_model.model, ode_sol, input_sol)
322346
end

0 commit comments

Comments
 (0)