Skip to content

Commit e74066e

Browse files
committed
WIP: interface to TinyMPC
1 parent e95213e commit e74066e

File tree

1 file changed

+225
-0
lines changed

1 file changed

+225
-0
lines changed

src/mpc.jl

Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
using ControlSystemsBase, Dates, LinearAlgebra, Plots, StaticArrays
2+
const V32 = Vector{Float32}
3+
const M32 = Matrix{Float32}
4+
const V64 = Vector{Float64}
5+
const M64 = Matrix{Float64}
6+
const A32 = Union{Array{Float32}, <:SubArray{Float32, 1, <:Array{Float32}}, <:SubArray{Float32, 2, <:Array{Float32}}}
7+
8+
cd(@__DIR__)
9+
10+
function compile_lib(dir::String)
11+
print("Compiling library to ", dir, "\n")
12+
build_dir = joinpath(dir, "build")
13+
if !isdir(build_dir)
14+
mkdir(build_dir)
15+
end
16+
run(`cmake -S $dir -B $build_dir`)
17+
run(`cmake --build $build_dir`)
18+
return true
19+
end
20+
21+
get_os_extension() = Sys.isapple() ? ".dylib" : Sys.iswindows() ? ".dll" : ".so"
22+
23+
24+
tinympc_dir = "/home/fredrikb/repos/tinympc-julia/tinympc/TinyMPC" # Path to the TinyMPC directory (C code)
25+
compile_lib(tinympc_dir) # Compile the C code into a shared library
26+
27+
tinympc = joinpath(tinympc_dir, "build","src","tinympc","libtinympcShared")*get_os_extension() # Path to the compiled library
28+
@assert isfile(tinympc) # Check that the library exists
29+
30+
struct MPCController2
31+
x::M32
32+
u::M32
33+
library_path::String
34+
library::Ptr{Cvoid}
35+
set_x0_ptr::Ptr{Cvoid}
36+
set_xref_ptr::Ptr{Cvoid}
37+
call_tiny_solve_ptr::Ptr{Cvoid}
38+
get_u_ptr::Ptr{Cvoid}
39+
get_x_ptr::Ptr{Cvoid}
40+
t::Vector{Float64}
41+
end
42+
43+
Base.size(c::MPCController2) = size(c.u), size(c.x)
44+
45+
function MPCController2(sys, Q1::M64, Q2::M64, N::Integer;
46+
x_min::V64 = fill(-1e6, sys.nx*N), # state constraints
47+
x_max::V64 = fill(1e6, sys.nx*N), # state constraints
48+
u_min::V64 = fill(-1e6, sys.nu*(N-1)), # input constraints
49+
u_max::V64 = fill(1e6, sys.nu*(N-1)), # input constraints
50+
rho::Float64 = 0.1,
51+
abs_pri_tol::Float64 = 1.0e-3, # absolute primal tolerance
52+
abs_dual_tol::Float64 = 1.0e-3, # absolute dual tolerance
53+
max_iter::Integer = 10000, # maximum number of iterations
54+
check_termination::Integer = 2, # whether to check termination and period
55+
output_dir = "generated_code_$(now())", # Path to the generated code
56+
verbose::Integer = true,
57+
compile = true,
58+
)
59+
60+
(; A, B, nx, nu) = sys
61+
isdiag(Q1) || throw(ArgumentError("Q1 must be diagonal"))
62+
isdiag(Q2) || throw(ArgumentError("Q2 must be diagonal"))
63+
size(Q1) == (nx, nx) || throw(ArgumentError("Q1 must have size nx x nx"))
64+
size(Q2) == (nu, nu) || throw(ArgumentError("Q2 must have size nu x nu"))
65+
if length(x_min) == nx
66+
x_min = repeat(x_min, N)
67+
end
68+
if length(x_max) == nx
69+
x_max = repeat(x_max, N)
70+
end
71+
if length(u_min) == nu
72+
u_min = repeat(u_min, N-1)
73+
end
74+
if length(u_max) == nu
75+
u_max = repeat(u_max, N-1)
76+
end
77+
length(x_min) == sys.nx*N || throw(ArgumentError("x_min must have length nx*N"))
78+
length(x_max) == sys.nx*N || throw(ArgumentError("x_max must have length nx*N"))
79+
length(u_min) == sys.nu*(N-1) || throw(ArgumentError("u_min must have length nu*(N-1)"))
80+
length(u_max) == sys.nu*(N-1) || throw(ArgumentError("u_max must have length nu*(N-1)"))
81+
82+
if !(eltype(A) <: Float64)
83+
A, B = convert(Matrix{Float64}, A), convert(Matrix{Float64}, B)
84+
end
85+
@ccall tinympc.tiny_codegen(Cint(nx)::Cint, Cint(nu)::Cint, Cint(N)::Cint, A::Ptr{Float64}, B::Ptr{Float64}, diag(Q1)::Ptr{Float64}, diag(Q2)::Ptr{Float64}, x_min::Ptr{Float64}, x_max::Ptr{Float64}, u_min::Ptr{Float64}, u_max::Ptr{Float64}, rho::Float64, abs_pri_tol::Float64, abs_dual_tol::Float64, Cint(max_iter)::Cint, Cint(check_termination)::Cint, Cint(verbose)::Cint, tinympc_dir::Ptr{UInt8}, output_dir::Ptr{UInt8})::Cint
86+
87+
library_path = joinpath(output_dir,"build","tinympc","libtinympcShared")*get_os_extension()
88+
if compile
89+
compile_lib(output_dir)
90+
library = Libc.dlopen(library_path)
91+
set_x0_ptr = Libc.dlsym(library, :set_x0)
92+
set_xref_ptr = Libc.dlsym(library, :set_xref)
93+
call_tiny_solve_ptr = Libc.dlsym(library, :call_tiny_solve)
94+
get_u_ptr = Libc.dlsym(library, :get_u)
95+
get_x_ptr = Libc.dlsym(library, :get_x)
96+
ccall(set_xref_ptr, Cvoid, (Ptr{Float32}, Cint), zeros(Float32, N*nx), Cint(0))
97+
else
98+
library = Ptr{Cvoid}(0)
99+
set_x0_ptr = Ptr{Cvoid}(0)
100+
set_xref_ptr = Ptr{Cvoid}(0)
101+
call_tiny_solve_ptr = Ptr{Cvoid}(0)
102+
get_u_ptr = Ptr{Cvoid}(0)
103+
get_x_ptr = Ptr{Cvoid}(0)
104+
end
105+
106+
MPCController2(
107+
zeros(Float32, sys.nx, N),
108+
zeros(Float32, sys.nu, (N-1)),
109+
library_path,
110+
library,
111+
set_x0_ptr,
112+
set_xref_ptr,
113+
call_tiny_solve_ptr,
114+
get_u_ptr,
115+
get_x_ptr,
116+
Float64[],
117+
)
118+
119+
end
120+
121+
function (controller::MPCController2)(x0::A32, r::Union{Nothing, A32}=nothing; verbose=false)
122+
(; set_x0_ptr,
123+
set_xref_ptr,
124+
call_tiny_solve_ptr,
125+
get_u_ptr,
126+
get_x_ptr) = controller
127+
128+
if r !== nothing
129+
if length(r) == size(controller.x, 1)
130+
r = repeat(r, size(controller.x, 2))
131+
end
132+
length(r) == length(controller.x) || throw(ArgumentError("r must have the same length as the state vector"))
133+
ccall(set_xref_ptr, Cvoid, (Ptr{Float32}, Cint), r, Cint(0))
134+
end
135+
136+
ccall(set_x0_ptr, Cvoid, (Ptr{Float32}, Cint), x0, Cint(0))
137+
138+
t = @elapsed ccall(call_tiny_solve_ptr, Cvoid, (Cint,), Cint(verbose))
139+
push!(controller.t, t)
140+
ccall(get_u_ptr, Cvoid, (Ptr{Float32}, Cint), controller.u, 0)
141+
ccall(get_x_ptr, Cvoid, (Ptr{Float32}, Cint), controller.x, 0)
142+
143+
(; controller.u, controller.x, t)
144+
end
145+
146+
##
147+
using RobustAndOptimalControl
148+
Ts = 0.05
149+
N = 100
150+
Pc = DemoSystems.double_mass_model(c0=0.01,c1=0.01,c2=0.01, outputs=1)
151+
P = c2d(Pc, Ts)
152+
P = add_input_integrator(P, 1; ϵ=1e-3)
153+
Q1 = Matrix(1.0*I(P.nx))
154+
Q2 = Matrix(0.1*I(P.nu))
155+
156+
157+
u_min = fill(-250.0, P.nu)
158+
u_max = fill(250.0, P.nu)
159+
x_max = [1000.0, 1000, 1000, 1000, 10]
160+
x_min = -x_max
161+
162+
##
163+
164+
Lmpc = MPCController2(P, Q1, Q2, N;
165+
x_min,
166+
x_max,
167+
u_min,
168+
u_max,
169+
rho = 1.0,
170+
max_iter = 3500,
171+
abs_pri_tol = 1.0e-3,
172+
abs_dual_tol = 1.0e-3,
173+
check_termination = 2,
174+
175+
)
176+
L = lqr(P, Q1, Q2)
177+
178+
# x0 = zeros(Float32, P.nx)
179+
x0 = Float32[20, 0, 0, 0, 0]
180+
r = reduce(hcat, fill(zeros(Float32, P.nx), N))
181+
182+
# u, x, t = Lmpc(x0, r, verbose=false)
183+
# plot((P.C*x)', layout=(2,1)); plot!(u', sp=2)
184+
185+
lqr_fun = (x,t)->clamp.(-L*x, u_min[1], u_max[1])
186+
mpc_closure =function (P, nu::Val{NU} = Val(P.nu)) where NU
187+
x_F32 = zeros(Float32, P.nx)
188+
u_F32 = zeros(Float32, P.nu)
189+
let Lmpc = Lmpc, u_min = u_min, u_max = u_max
190+
function (x,t)
191+
x_F32 .= x
192+
res = Lmpc(x_F32; verbose=false)
193+
@views u_F32 .= clamp.(
194+
res.u[1:NU, 1],
195+
u_min[1:NU, 1],
196+
u_max[1:NU, 1]
197+
)
198+
SVector{NU}(u_F32)
199+
end
200+
end
201+
202+
end
203+
204+
mpc_fun = mpc_closure(P)
205+
206+
@time "lqr" res_lqr = lsim(P, lqr_fun, 5; x0);
207+
@time "mpc" res_mpc = lsim(P, mpc_fun, 5; x0);
208+
209+
cost_lqr = dot(res_lqr.x, Q1, res_lqr.x) + dot(res_lqr.u, Q2, res_lqr.u)
210+
cost_mpc = dot(res_mpc.x, Q1, res_mpc.x) + dot(res_mpc.u, Q2, res_mpc.u)
211+
212+
maximum_constraint_violation = max(
213+
maximum(res_mpc.x .- x_max),
214+
maximum(x_min .- res_mpc.x),
215+
maximum(res_mpc.u .- u_max),
216+
maximum(u_min .- res_mpc.u)
217+
)
218+
219+
220+
fig1 = plot(res_lqr, label="LQR $cost_lqr", plotu=true, plotx=false, size=(800, 1200), margin=5Plots.mm)
221+
plot!(res_mpc, label="MPC $cost_mpc", plotu=true, plotx=false)
222+
fig2 = scatter(1e3 .* Lmpc.t, title="TimyMPC execution time [ms]", label=false)
223+
plot(fig1, fig2, layout=(1,2), size=(1200, 1200))
224+
display(current())
225+

0 commit comments

Comments
 (0)