Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ AbstractTrees = "0.3, 0.4"
ArrayInterface = "6, 7"
BifurcationKit = "0.4"
BlockArrays = "1.1"
BoundaryValueDiffEqAscher = "1.1.0"
BoundaryValueDiffEqMIRK = "1.4.0"
BoundaryValueDiffEqAscher = "1.6.0"
BoundaryValueDiffEqMIRK = "1.8.0"
CasADi = "1.0.6"
ChainRulesCore = "1"
Combinatorics = "1"
Expand Down
66 changes: 36 additions & 30 deletions ext/MTKCasADiDynamicOptExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,14 @@ struct CasADiDynamicOptProblem{uType, tType, isinplace, P, F, K} <:
end
end

function (M::MXLinearInterpolation)(τ)
function (M::MXLinearInterpolation)(τ)
nt = (τ - M.t[1]) / M.dt
i = 1 + floor(Int, nt)
Δ = nt - i + 1

(i > length(M.t) || i < 1) && error("Cannot extrapolate past the tspan.")
if i < length(M.t)
M.u[:, i] + Δ*(M.u[:, i + 1] - M.u[:, i])
M.u[:, i] + Δ * (M.u[:, i + 1] - M.u[:, i])
else
M.u[:, i]
end
Expand All @@ -74,7 +74,7 @@ The constraints are:
function MTK.CasADiDynamicOptProblem(sys::ODESystem, u0map, tspan, pmap;
dt = nothing,
steps = nothing,
guesses = Dict(), kwargs...)
guesses = Dict(), kwargs...)
MTK.warn_overdetermined(sys, u0map)
_u0map = has_alg_eqs(sys) ? u0map : merge(Dict(u0map), Dict(guesses))
f, u0, p = MTK.process_SciMLProblem(ODEInputFunction, sys, _u0map, pmap;
Expand Down Expand Up @@ -104,21 +104,21 @@ function init_model(sys, tspan, steps, u0map, pmap, u0; is_free_t = false)
subject_to!(opti, tₛ >= lo)
subject_to!(opti, tₛ >= hi)
end
pmap[te_sym] = tₛ
pmap[te_sym] = tₛ
tsteps = LinRange(0, 1, steps)
else
tₛ = MX(1)
tsteps = LinRange(tspan[1], tspan[2], steps)
end

U = CasADi.variable!(opti, length(states), steps)
V = CasADi.variable!(opti, length(ctrls), steps)
set_initial!(opti, U, DM(repeat(u0, 1, steps)))
c0 = MTK.value.([pmap[c] for c in ctrls])
!isempty(c0) && set_initial!(opti, V, DM(repeat(c0, 1, steps)))

U_interp = MXLinearInterpolation(U, tsteps, tsteps[2]-tsteps[1])
V_interp = MXLinearInterpolation(V, tsteps, tsteps[2]-tsteps[1])
U_interp = MXLinearInterpolation(U, tsteps, tsteps[2] - tsteps[1])
V_interp = MXLinearInterpolation(V, tsteps, tsteps[2] - tsteps[1])
for (i, ct) in enumerate(ctrls)
pmap[ct] = V[i, :]
end
Expand Down Expand Up @@ -185,8 +185,8 @@ function add_user_constraints!(model::CasADiModel, sys, tspan, pmap; is_free_t)
x = MTK.operation(st)
t = only(MTK.arguments(st))
MTK.symbolic_type(t) === MTK.NotSymbolic() || continue
if haskey(stidxmap, x(iv))
idx = stidxmap[x(iv)]
if haskey(stidxmap, x(iv))
idx = stidxmap[x(iv)]
cv = U
else
idx = ctidxmap[x(iv)]
Expand All @@ -196,11 +196,11 @@ function add_user_constraints!(model::CasADiModel, sys, tspan, pmap; is_free_t)
end

if cons isa Equation
subject_to!(opti, cons.lhs - cons.rhs==0)
subject_to!(opti, cons.lhs - cons.rhs == 0)
elseif cons.relational_op === Symbolics.geq
subject_to!(opti, cons.lhs - cons.rhs0)
subject_to!(opti, cons.lhs - cons.rhs0)
else
subject_to!(opti, cons.lhs - cons.rhs0)
subject_to!(opti, cons.lhs - cons.rhs0)
end
end
end
Expand All @@ -227,8 +227,8 @@ function add_cost_function!(model::CasADiModel, sys, tspan, pmap; is_free_t)
x = operation(st)
t = only(arguments(st))
MTK.symbolic_type(t) === MTK.NotSymbolic() || continue
if haskey(stidxmap, x(iv))
idx = stidxmap[x(iv)]
if haskey(stidxmap, x(iv))
idx = stidxmap[x(iv)]
cv = U
else
idx = ctidxmap[x(iv)]
Expand All @@ -244,7 +244,8 @@ function add_cost_function!(model::CasADiModel, sys, tspan, pmap; is_free_t)
op = MTK.operation(int)
arg = only(arguments(MTK.value(int)))
lo, hi = (op.domain.domain.left, op.domain.domain.right)
!isequal((lo, hi), tspan) && error("Non-whole interval bounds for integrals are not currently supported for CasADiDynamicOptProblem.")
!isequal((lo, hi), tspan) &&
error("Non-whole interval bounds for integrals are not currently supported for CasADiDynamicOptProblem.")
# Approximate integral as sum.
intmap[int] = dt * tₛ * sum(arg)
end
Expand All @@ -253,7 +254,8 @@ function add_cost_function!(model::CasADiModel, sys, tspan, pmap; is_free_t)
minimize!(opti, MX(MTK.value(consolidate(jcosts))))
end

function substitute_casadi_vars(model::CasADiModel, sys, pmap, exprs; auxmap::Dict = Dict(), is_free_t)
function substitute_casadi_vars(
model::CasADiModel, sys, pmap, exprs; auxmap::Dict = Dict(), is_free_t)
@unpack opti, U, V, tₛ = model
iv = MTK.get_iv(sys)
sts = unknowns(sys)
Expand Down Expand Up @@ -281,44 +283,44 @@ end

function add_solve_constraints(prob, tableau)
@unpack A, α, c = tableau
@unpack model, f, p = prob
@unpack model, f, p = prob
@unpack opti, U, V, tₛ = model
solver_opti = copy(opti)

tsteps = U.t
tsteps = U.t
dt = tsteps[2] - tsteps[1]

nᵤ = size(U.u, 1)
nᵥ = size(V.u, 1)

if MTK.is_explicit(tableau)
K = MX[]
for k in 1:length(tsteps)-1
for k in 1:(length(tsteps) - 1)
τ = tsteps[k]
for (i, h) in enumerate(c)
ΔU = sum([A[i, j] * K[j] for j in 1:(i - 1)], init = MX(zeros(nᵤ)))
Uₙ = U.u[:, k] + ΔU*dt
Uₙ = U.u[:, k] + ΔU * dt
Vₙ = V.u[:, k]
Kₙ = tₛ * f(Uₙ, Vₙ, p, τ + h * dt) # scale the time
push!(K, Kₙ)
end
ΔU = dt * sum([α[i] * K[i] for i in 1:length(α)])
subject_to!(solver_opti, U.u[:, k] + ΔU == U.u[:, k+1])
subject_to!(solver_opti, U.u[:, k] + ΔU == U.u[:, k + 1])
empty!(K)
end
else
for k in 1:length(tsteps)-1
for k in 1:(length(tsteps) - 1)
τ = tsteps[k]
Kᵢ = variable!(solver_opti, nᵤ, length(α))
ΔUs = A * Kᵢ' # the stepsize at each stage of the implicit method
for (i, h) in enumerate(c)
ΔU = ΔUs[i,:]'
Uₙ = U.u[:,k] + ΔU*dt
Vₙ = V.u[:,k]
subject_to!(solver_opti, Kᵢ[:,i] == tₛ * f(Uₙ, Vₙ, p, τ + h*dt))
ΔU = ΔUs[i, :]'
Uₙ = U.u[:, k] + ΔU * dt
Vₙ = V.u[:, k]
subject_to!(solver_opti, Kᵢ[:, i] == tₛ * f(Uₙ, Vₙ, p, τ + h * dt))
end
ΔU_tot = dt*(Kᵢ*α)
subject_to!(solver_opti, U.u[:, k] + ΔU_tot == U.u[:,k+1])
ΔU_tot = dt * (Kᵢ * α)
subject_to!(solver_opti, U.u[:, k] + ΔU_tot == U.u[:, k + 1])
end
end
solver_opti
Expand All @@ -331,7 +333,10 @@ end

NOTE: the solver should be passed in as a string to CasADi. "ipopt"
"""
function DiffEqBase.solve(prob::CasADiDynamicOptProblem, solver::Union{String, Symbol} = "ipopt", tableau_getter = MTK.constructDefault; plugin_options::Dict = Dict(), solver_options::Dict = Dict(), silent = false)
function DiffEqBase.solve(
prob::CasADiDynamicOptProblem, solver::Union{String, Symbol} = "ipopt",
tableau_getter = MTK.constructDefault; plugin_options::Dict = Dict(),
solver_options::Dict = Dict(), silent = false)
@unpack model, u0, p, tspan, f = prob
tableau = tableau_getter()
@unpack opti, U, V, tₛ = model
Expand Down Expand Up @@ -366,7 +371,8 @@ function DiffEqBase.solve(prob::CasADiDynamicOptProblem, solver::Union{String, S
end

if failed
ode_sol = SciMLBase.solution_new_retcode(ode_sol, SciMLBase.ReturnCode.ConvergenceFailure)
ode_sol = SciMLBase.solution_new_retcode(
ode_sol, SciMLBase.ReturnCode.ConvergenceFailure)
!isnothing(input_sol) && (input_sol = SciMLBase.solution_new_retcode(
input_sol, SciMLBase.ReturnCode.ConvergenceFailure))
end
Expand Down
3 changes: 2 additions & 1 deletion src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,8 @@ export AnalysisPoint, get_sensitivity_function, get_comp_sensitivity_function,
function FMIComponent end

include("systems/optimal_control_interface.jl")
export AbstractDynamicOptProblem, JuMPDynamicOptProblem, InfiniteOptDynamicOptProblem, CasADiDynamicOptProblem
export AbstractDynamicOptProblem, JuMPDynamicOptProblem, InfiniteOptDynamicOptProblem,
CasADiDynamicOptProblem
export DynamicOptSolution

end # module
2 changes: 1 addition & 1 deletion src/systems/optimal_control_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ function constructDefault(T::Type = Float64)
A = map(T, A)
α = map(T, α)
c = map(T, c)

DiffEqBase.ImplicitRKTableau(A, c, α, 5)
end

Expand Down
8 changes: 4 additions & 4 deletions test/bvproblem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ daesolvers = [Ascher2, Ascher4, Ascher6]

for solver in solvers
sol = solve(bvp, solver(), dt = 0.01)
@test_broken isapprox(sol.u[end], osol.u[end]; atol = 0.01)
@test_broken sol.u[1] == [1.0, 2.0]
@test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
@test sol.u[1] == [1.0, 2.0]
end

# Test out of place
Expand All @@ -39,8 +39,8 @@ daesolvers = [Ascher2, Ascher4, Ascher6]

for solver in solvers
sol = solve(bvp2, solver(), dt = 0.01)
@test_broken isapprox(sol.u[end], osol.u[end]; atol = 0.01)
@test_broken sol.u[1] == [1.0, 2.0]
@test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
@test sol.u[1] == [1.0, 2.0]
end
end

Expand Down
8 changes: 5 additions & 3 deletions test/extensions/dynamic_optimization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ const M = ModelingToolkit
@test jsol.sol(0.6)[1] ≈ 3.5
@test jsol.sol(0.3)[1] ≈ 7.0

cprob = CasADiDynamicOptProblem(lksys, u0map, tspan, parammap; guesses = guess, dt = 0.01)
cprob = CasADiDynamicOptProblem(
lksys, u0map, tspan, parammap; guesses = guess, dt = 0.01)
csol = solve(cprob, "ipopt", constructTsitouras5, silent = true)
@test csol.sol(0.6)[1] ≈ 3.5
@test csol.sol(0.3)[1] ≈ 7.0
Expand All @@ -87,7 +88,8 @@ const M = ModelingToolkit
jsol = solve(jprob, Ipopt.Optimizer, constructRadauIA3, silent = true) # 12.190 s, 9.68 GiB
@test all(u -> u > [1, 1], jsol.sol.u)

cprob = CasADiDynamicOptProblem(lksys, u0map, tspan, parammap; guesses = guess, dt = 0.01)
cprob = CasADiDynamicOptProblem(
lksys, u0map, tspan, parammap; guesses = guess, dt = 0.01)
csol = solve(cprob, "ipopt", constructRadauIA3, silent = true)
@test all(u -> u > [1, 1], csol.sol.u)
end
Expand Down Expand Up @@ -220,7 +222,7 @@ end
jprob = JuMPDynamicOptProblem(rocket, u0map, (ts, te), pmap; dt = 0.001, cse = false)
jsol = solve(jprob, Ipopt.Optimizer, constructRadauIIA5, silent = true)
@test jsol.sol.u[end][1] > 1.012

cprob = CasADiDynamicOptProblem(rocket, u0map, (ts, te), pmap; dt = 0.001, cse = false)
csol = solve(cprob, "ipopt"; silent = true)
@test csol.sol.u[end][1] > 1.012
Expand Down
Loading