|
| 1 | +using ControlSystemsBase, Dates, LinearAlgebra, Plots, StaticArrays |
| 2 | +const V32 = Vector{Float32} |
| 3 | +const M32 = Matrix{Float32} |
| 4 | +const V64 = Vector{Float64} |
| 5 | +const M64 = Matrix{Float64} |
| 6 | +const A32 = Union{Array{Float32}, <:SubArray{Float32, 1, <:Array{Float32}}, <:SubArray{Float32, 2, <:Array{Float32}}} |
| 7 | + |
| 8 | +cd(@__DIR__) |
| 9 | + |
| 10 | +function compile_lib(dir::String) |
| 11 | + print("Compiling library to ", dir, "\n") |
| 12 | + build_dir = joinpath(dir, "build") |
| 13 | + if !isdir(build_dir) |
| 14 | + mkdir(build_dir) |
| 15 | + end |
| 16 | + run(`cmake -S $dir -B $build_dir`) |
| 17 | + run(`cmake --build $build_dir`) |
| 18 | + return true |
| 19 | +end |
| 20 | + |
| 21 | +get_os_extension() = Sys.isapple() ? ".dylib" : Sys.iswindows() ? ".dll" : ".so" |
| 22 | + |
| 23 | + |
| 24 | +tinympc_dir = "/home/fredrikb/repos/tinympc-julia/tinympc/TinyMPC" # Path to the TinyMPC directory (C code) |
| 25 | +compile_lib(tinympc_dir) # Compile the C code into a shared library |
| 26 | + |
| 27 | +tinympc = joinpath(tinympc_dir, "build","src","tinympc","libtinympcShared")*get_os_extension() # Path to the compiled library |
| 28 | +@assert isfile(tinympc) # Check that the library exists |
| 29 | + |
| 30 | +struct MPCController2 |
| 31 | + x::M32 |
| 32 | + u::M32 |
| 33 | + library_path::String |
| 34 | + library::Ptr{Cvoid} |
| 35 | + set_x0_ptr::Ptr{Cvoid} |
| 36 | + set_xref_ptr::Ptr{Cvoid} |
| 37 | + call_tiny_solve_ptr::Ptr{Cvoid} |
| 38 | + get_u_ptr::Ptr{Cvoid} |
| 39 | + get_x_ptr::Ptr{Cvoid} |
| 40 | + t::Vector{Float64} |
| 41 | +end |
| 42 | + |
| 43 | +Base.size(c::MPCController2) = size(c.u), size(c.x) |
| 44 | + |
| 45 | +function MPCController2(sys, Q1::M64, Q2::M64, N::Integer; |
| 46 | + x_min::V64 = fill(-1e6, sys.nx*N), # state constraints |
| 47 | + x_max::V64 = fill(1e6, sys.nx*N), # state constraints |
| 48 | + u_min::V64 = fill(-1e6, sys.nu*(N-1)), # input constraints |
| 49 | + u_max::V64 = fill(1e6, sys.nu*(N-1)), # input constraints |
| 50 | + rho::Float64 = 0.1, |
| 51 | + abs_pri_tol::Float64 = 1.0e-3, # absolute primal tolerance |
| 52 | + abs_dual_tol::Float64 = 1.0e-3, # absolute dual tolerance |
| 53 | + max_iter::Integer = 10000, # maximum number of iterations |
| 54 | + check_termination::Integer = 2, # whether to check termination and period |
| 55 | + output_dir = "generated_code_$(now())", # Path to the generated code |
| 56 | + verbose::Integer = true, |
| 57 | + compile = true, |
| 58 | +) |
| 59 | + |
| 60 | + (; A, B, nx, nu) = sys |
| 61 | + isdiag(Q1) || throw(ArgumentError("Q1 must be diagonal")) |
| 62 | + isdiag(Q2) || throw(ArgumentError("Q2 must be diagonal")) |
| 63 | + size(Q1) == (nx, nx) || throw(ArgumentError("Q1 must have size nx x nx")) |
| 64 | + size(Q2) == (nu, nu) || throw(ArgumentError("Q2 must have size nu x nu")) |
| 65 | + if length(x_min) == nx |
| 66 | + x_min = repeat(x_min, N) |
| 67 | + end |
| 68 | + if length(x_max) == nx |
| 69 | + x_max = repeat(x_max, N) |
| 70 | + end |
| 71 | + if length(u_min) == nu |
| 72 | + u_min = repeat(u_min, N-1) |
| 73 | + end |
| 74 | + if length(u_max) == nu |
| 75 | + u_max = repeat(u_max, N-1) |
| 76 | + end |
| 77 | + length(x_min) == sys.nx*N || throw(ArgumentError("x_min must have length nx*N")) |
| 78 | + length(x_max) == sys.nx*N || throw(ArgumentError("x_max must have length nx*N")) |
| 79 | + length(u_min) == sys.nu*(N-1) || throw(ArgumentError("u_min must have length nu*(N-1)")) |
| 80 | + length(u_max) == sys.nu*(N-1) || throw(ArgumentError("u_max must have length nu*(N-1)")) |
| 81 | + |
| 82 | + if !(eltype(A) <: Float64) |
| 83 | + A, B = convert(Matrix{Float64}, A), convert(Matrix{Float64}, B) |
| 84 | + end |
| 85 | + @ccall tinympc.tiny_codegen(Cint(nx)::Cint, Cint(nu)::Cint, Cint(N)::Cint, A::Ptr{Float64}, B::Ptr{Float64}, diag(Q1)::Ptr{Float64}, diag(Q2)::Ptr{Float64}, x_min::Ptr{Float64}, x_max::Ptr{Float64}, u_min::Ptr{Float64}, u_max::Ptr{Float64}, rho::Float64, abs_pri_tol::Float64, abs_dual_tol::Float64, Cint(max_iter)::Cint, Cint(check_termination)::Cint, Cint(verbose)::Cint, tinympc_dir::Ptr{UInt8}, output_dir::Ptr{UInt8})::Cint |
| 86 | + |
| 87 | + library_path = joinpath(output_dir,"build","tinympc","libtinympcShared")*get_os_extension() |
| 88 | + if compile |
| 89 | + compile_lib(output_dir) |
| 90 | + library = Libc.dlopen(library_path) |
| 91 | + set_x0_ptr = Libc.dlsym(library, :set_x0) |
| 92 | + set_xref_ptr = Libc.dlsym(library, :set_xref) |
| 93 | + call_tiny_solve_ptr = Libc.dlsym(library, :call_tiny_solve) |
| 94 | + get_u_ptr = Libc.dlsym(library, :get_u) |
| 95 | + get_x_ptr = Libc.dlsym(library, :get_x) |
| 96 | + ccall(set_xref_ptr, Cvoid, (Ptr{Float32}, Cint), zeros(Float32, N*nx), Cint(0)) |
| 97 | + else |
| 98 | + library = Ptr{Cvoid}(0) |
| 99 | + set_x0_ptr = Ptr{Cvoid}(0) |
| 100 | + set_xref_ptr = Ptr{Cvoid}(0) |
| 101 | + call_tiny_solve_ptr = Ptr{Cvoid}(0) |
| 102 | + get_u_ptr = Ptr{Cvoid}(0) |
| 103 | + get_x_ptr = Ptr{Cvoid}(0) |
| 104 | + end |
| 105 | + |
| 106 | + MPCController2( |
| 107 | + zeros(Float32, sys.nx, N), |
| 108 | + zeros(Float32, sys.nu, (N-1)), |
| 109 | + library_path, |
| 110 | + library, |
| 111 | + set_x0_ptr, |
| 112 | + set_xref_ptr, |
| 113 | + call_tiny_solve_ptr, |
| 114 | + get_u_ptr, |
| 115 | + get_x_ptr, |
| 116 | + Float64[], |
| 117 | + ) |
| 118 | + |
| 119 | +end |
| 120 | + |
| 121 | +function (controller::MPCController2)(x0::A32, r::Union{Nothing, A32}=nothing; verbose=false) |
| 122 | + (; set_x0_ptr, |
| 123 | + set_xref_ptr, |
| 124 | + call_tiny_solve_ptr, |
| 125 | + get_u_ptr, |
| 126 | + get_x_ptr) = controller |
| 127 | + |
| 128 | + if r !== nothing |
| 129 | + if length(r) == size(controller.x, 1) |
| 130 | + r = repeat(r, size(controller.x, 2)) |
| 131 | + end |
| 132 | + length(r) == length(controller.x) || throw(ArgumentError("r must have the same length as the state vector")) |
| 133 | + ccall(set_xref_ptr, Cvoid, (Ptr{Float32}, Cint), r, Cint(0)) |
| 134 | + end |
| 135 | + |
| 136 | + ccall(set_x0_ptr, Cvoid, (Ptr{Float32}, Cint), x0, Cint(0)) |
| 137 | + |
| 138 | + t = @elapsed ccall(call_tiny_solve_ptr, Cvoid, (Cint,), Cint(verbose)) |
| 139 | + push!(controller.t, t) |
| 140 | + ccall(get_u_ptr, Cvoid, (Ptr{Float32}, Cint), controller.u, 0) |
| 141 | + ccall(get_x_ptr, Cvoid, (Ptr{Float32}, Cint), controller.x, 0) |
| 142 | + |
| 143 | + (; controller.u, controller.x, t) |
| 144 | +end |
| 145 | + |
| 146 | +## |
| 147 | +using RobustAndOptimalControl |
| 148 | +Ts = 0.05 |
| 149 | +N = 100 |
| 150 | +Pc = DemoSystems.double_mass_model(c0=0.01,c1=0.01,c2=0.01, outputs=1) |
| 151 | +P = c2d(Pc, Ts) |
| 152 | +P = add_input_integrator(P, 1; ϵ=1e-3) |
| 153 | +Q1 = Matrix(1.0*I(P.nx)) |
| 154 | +Q2 = Matrix(0.1*I(P.nu)) |
| 155 | + |
| 156 | + |
| 157 | +u_min = fill(-250.0, P.nu) |
| 158 | +u_max = fill(250.0, P.nu) |
| 159 | +x_max = [1000.0, 1000, 1000, 1000, 10] |
| 160 | +x_min = -x_max |
| 161 | + |
| 162 | +## |
| 163 | + |
| 164 | +Lmpc = MPCController2(P, Q1, Q2, N; |
| 165 | + x_min, |
| 166 | + x_max, |
| 167 | + u_min, |
| 168 | + u_max, |
| 169 | + rho = 1.0, |
| 170 | + max_iter = 3500, |
| 171 | + abs_pri_tol = 1.0e-3, |
| 172 | + abs_dual_tol = 1.0e-3, |
| 173 | + check_termination = 2, |
| 174 | + |
| 175 | +) |
| 176 | +L = lqr(P, Q1, Q2) |
| 177 | + |
| 178 | +# x0 = zeros(Float32, P.nx) |
| 179 | +x0 = Float32[20, 0, 0, 0, 0] |
| 180 | +r = reduce(hcat, fill(zeros(Float32, P.nx), N)) |
| 181 | + |
| 182 | +# u, x, t = Lmpc(x0, r, verbose=false) |
| 183 | +# plot((P.C*x)', layout=(2,1)); plot!(u', sp=2) |
| 184 | + |
| 185 | +lqr_fun = (x,t)->clamp.(-L*x, u_min[1], u_max[1]) |
| 186 | +mpc_closure =function (P, nu::Val{NU} = Val(P.nu)) where NU |
| 187 | + x_F32 = zeros(Float32, P.nx) |
| 188 | + u_F32 = zeros(Float32, P.nu) |
| 189 | + let Lmpc = Lmpc, u_min = u_min, u_max = u_max |
| 190 | + function (x,t) |
| 191 | + x_F32 .= x |
| 192 | + res = Lmpc(x_F32; verbose=false) |
| 193 | + @views u_F32 .= clamp.( |
| 194 | + res.u[1:NU, 1], |
| 195 | + u_min[1:NU, 1], |
| 196 | + u_max[1:NU, 1] |
| 197 | + ) |
| 198 | + SVector{NU}(u_F32) |
| 199 | + end |
| 200 | + end |
| 201 | + |
| 202 | +end |
| 203 | + |
| 204 | +mpc_fun = mpc_closure(P) |
| 205 | + |
| 206 | +@time "lqr" res_lqr = lsim(P, lqr_fun, 5; x0); |
| 207 | +@time "mpc" res_mpc = lsim(P, mpc_fun, 5; x0); |
| 208 | + |
| 209 | +cost_lqr = dot(res_lqr.x, Q1, res_lqr.x) + dot(res_lqr.u, Q2, res_lqr.u) |
| 210 | +cost_mpc = dot(res_mpc.x, Q1, res_mpc.x) + dot(res_mpc.u, Q2, res_mpc.u) |
| 211 | + |
| 212 | +maximum_constraint_violation = max( |
| 213 | + maximum(res_mpc.x .- x_max), |
| 214 | + maximum(x_min .- res_mpc.x), |
| 215 | + maximum(res_mpc.u .- u_max), |
| 216 | + maximum(u_min .- res_mpc.u) |
| 217 | +) |
| 218 | + |
| 219 | + |
| 220 | +fig1 = plot(res_lqr, label="LQR $cost_lqr", plotu=true, plotx=false, size=(800, 1200), margin=5Plots.mm) |
| 221 | +plot!(res_mpc, label="MPC $cost_mpc", plotu=true, plotx=false) |
| 222 | +fig2 = scatter(1e3 .* Lmpc.t, title="TimyMPC execution time [ms]", label=false) |
| 223 | +plot(fig1, fig2, layout=(1,2), size=(1200, 1200)) |
| 224 | +display(current()) |
| 225 | + |
0 commit comments