Skip to content

Commit 908d8cf

Browse files
committed
added: multithreading for MultipleShooting and TrapezoidalCollocation
1 parent b360566 commit 908d8cf

File tree

2 files changed

+61
-54
lines changed

2 files changed

+61
-54
lines changed

src/controller/transcription.jl

Lines changed: 50 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ plant model/constraints.
3131
struct SingleShooting <: ShootingMethod end
3232

3333
@doc raw"""
34-
MultipleShooting(threads=false)
34+
MultipleShooting(; f_threads=false, h_threads=false)
3535
3636
Construct a direct multiple shooting [`TranscriptionMethod`](@ref).
3737
@@ -56,28 +56,33 @@ for [`NonLinModel`](@ref).
5656
This transcription computes the predictions by calling the augmented discrete-time model
5757
in the equality constraint function recursively over ``H_p``, or by updating the linear
5858
equality constraint vector for [`LinModel`](@ref). It is generally more efficient for large
59-
control horizon ``H_c``, unstable or highly nonlinear plant models/constraints.
59+
control horizon ``H_c``, unstable or highly nonlinear models/constraints. Multithreading
60+
with `f_threads` or `h_threads` keyword arguments can be advantageous if ``\mathbf{f}`` or
61+
``\mathbf{h}`` in the [`NonLinModel`](@ref) is expensive to evaluate, respectively.
6062
6163
Sparse optimizers like `OSQP` or `Ipopt` and sparse Jacobian computations are recommended
6264
for this transcription method.
6365
"""
64-
struct MultipleShooting{T} <: ShootingMethod
65-
function MultipleShooting(thread::Bool=false)
66-
return new{thread}()
66+
struct MultipleShooting <: ShootingMethod
67+
f_threads::Bool
68+
h_threads::Bool
69+
function MultipleShooting(; f_threads=false, h_threads=false)
70+
return new(f_threads, h_threads)
6771
end
6872
end
6973

7074
@doc raw"""
71-
TrapezoidalCollocation(h::Int=0)
75+
TrapezoidalCollocation(h::Int=0; f_threads=false, h_threads=false)
7276
7377
Construct an implicit trapezoidal [`TranscriptionMethod`](@ref) with `h`th order hold.
7478
7579
This is the simplest collocation method. It supports continuous-time [`NonLinModel`](@ref)s
7680
only. The decision variables are the same as for [`MultipleShooting`](@ref), hence similar
77-
computational costs. The `h` argument is `0` or `1`, for piecewise constant or linear
78-
manipulated inputs ``\mathbf{u}`` (`h=1` is slightly less expensive). Note that the various
79-
[`DiffSolver`](@ref) here assume zero-order hold, so `h=1` will induce a plant-model
80-
mismatch if the plant is simulated with these solvers.
81+
computational costs. See that docstring for descriptions of the `f_threads` and `h_threads`
82+
keywords. The `h` argument is `0` or `1`, for piecewise constant or linear manipulated
83+
inputs ``\mathbf{u}`` (`h=1` is slightly less expensive). Note that the various [`DiffSolver`](@ref)
84+
here assume zero-order hold, so `h=1` will induce a plant-model mismatch if the plant is
85+
simulated with these solvers.
8186
8287
This transcription computes the predictions by calling the continuous-time model in the
8388
equality constraint function and by using the implicit trapezoidal rule. It can handle
@@ -100,12 +105,14 @@ transcription method.
100105
struct TrapezoidalCollocation <: CollocationMethod
101106
h::Int
102107
nc::Int
103-
function TrapezoidalCollocation(h::Int=0)
108+
f_threads::Bool
109+
h_threads::Bool
110+
function TrapezoidalCollocation(h::Int=0; f_threads=false, h_threads=false)
104111
if !(h == 0 || h == 1)
105112
throw(ArgumentError("h argument must be 0 or 1 for TrapezoidalCollocation."))
106113
end
107114
nc = 2 # 2 collocation points per interval for trapezoidal rule
108-
return new(h, nc)
115+
return new(h, nc, f_threads, h_threads)
109116
end
110117
end
111118

@@ -1210,13 +1217,14 @@ in which ``\mathbf{x̂_0}`` is the augmented state extracted from the decision v
12101217
"""
12111218
function predict!(
12121219
Ŷ0, x̂0end, _, _, _,
1213-
mpc::PredictiveController, model::NonLinModel, ::TranscriptionMethod,
1220+
mpc::PredictiveController, model::NonLinModel, transcription::TranscriptionMethod,
12141221
_ , Z̃
12151222
)
12161223
nu, ny, nd, nx̂, Hp, Hc = model.nu, model.ny, model.nd, mpc.estim.nx̂, mpc.Hp, mpc.Hc
1224+
h_threads = transcription.h_threads
12171225
X̂0 = @views Z̃[(nu*Hc+1):(nu*Hc+nx̂*Hp)] # Z̃ = [ΔU; X̂0; ϵ]
12181226
D̂0 = mpc.D̂0
1219-
for j=1:Hp
1227+
@threadsif h_threads for j=1:Hp
12201228
x̂0 = @views X̂0[(1 + nx̂*(j-1)):(nx̂*j)]
12211229
d̂0 = @views D̂0[(1 + nd*(j-1)):(nd*j)]
12221230
ŷ0 = @views Ŷ0[(1 + ny*(j-1)):(ny*j)]
@@ -1332,14 +1340,15 @@ in which the augmented state ``\mathbf{x̂_0}`` are extracted from the decision
13321340
"""
13331341
function con_nonlinprogeq!(
13341342
geq, X̂0, Û0, K0,
1335-
mpc::PredictiveController, model::NonLinModel, ::MultipleShooting{T}, U0, Z̃
1336-
) where T
1343+
mpc::PredictiveController, model::NonLinModel, transcription::MultipleShooting, U0, Z̃
1344+
)
13371345
nu, nx̂, nd, nk = model.nu, mpc.estim.nx̂, model.nd, model.nk
13381346
Hp, Hc = mpc.Hp, mpc.Hc
13391347
nΔU, nX̂ = nu*Hc, nx̂*Hp
1348+
f_threads = transcription.f_threads
13401349
D̂0 = mpc.D̂0
13411350
X̂0_Z̃ = @views Z̃[(nΔU+1):(nΔU+nX̂)]
1342-
Threads.@threads for j=1:Hp
1351+
@threadsif f_threads for j=1:Hp
13431352
if j < 2
13441353
x̂0 = @views mpc.estim.x̂0[1:nx̂]
13451354
d̂0 = @views mpc.d0[1:nd]
@@ -1404,23 +1413,21 @@ function con_nonlinprogeq!(
14041413
nu, nx̂, nd, nx, h = model.nu, mpc.estim.nx̂, model.nd, model.nx, transcription.h
14051414
Hp, Hc = mpc.Hp, mpc.Hc
14061415
nΔU, nX̂ = nu*Hc, nx̂*Hp
1416+
f_threads = transcription.f_threads
14071417
Ts, p = model.Ts, model.p
14081418
As, Cs_u = mpc.estim.As, mpc.estim.Cs_u
14091419
nk = get_nk(model, transcription)
14101420
D̂0 = mpc.D̂0
14111421
X̂0_Z̃ = @views Z̃[(nΔU+1):(nΔU+nX̂)]
1412-
x̂0 = @views mpc.estim.x̂0[1:nx̂]
1413-
d̂0 = @views mpc.d0[1:nd]
1414-
if !iszero(h)
1415-
k1, u0, û0 = @views K0[1:nx], U0[1:nu], Û0[1:nu]
1416-
x0, xs = @views x̂0[1:nx], x̂0[nx+1:end]
1417-
mul!(û0, Cs_u, xs)
1418-
û0 .+= u0
1419-
model.f!(k1, x0, û0, d̂0, p)
1420-
lastk2 = k1
1421-
end
14221422
#TODO: allow parallel for loop or threads?
1423-
for j=1:Hp
1423+
@threadsif f_threads for j=1:Hp
1424+
if j < 2
1425+
x̂0 = @views mpc.estim.x̂0[1:nx̂]
1426+
d̂0 = @views mpc.d0[1:nd]
1427+
else
1428+
x̂0 = @views X̂0_Z̃[(1 + nx̂*(j-2)):(nx̂*(j-1))]
1429+
d̂0 = @views D̂0[(1 + nd*(j-2)):(nd*(j-1))]
1430+
end
14241431
k0 = @views K0[(1 + nk*(j-1)):(nk*j)]
14251432
d̂0next = @views D̂0[(1 + nd*(j-1)):(nd*j)]
14261433
x̂0next = @views X̂0[(1 + nx̂*(j-1)):(nx̂*j)]
@@ -1435,39 +1442,28 @@ function con_nonlinprogeq!(
14351442
mul!(xsnext, As, xs)
14361443
ssnext .= @. xsnext - xsnext_Z̃
14371444
# ----------------- deterministic defects --------------------------------------
1438-
if iszero(h) # piecewise constant manipulated inputs u:
1439-
u0 = @views U0[(1 + nu*(j-1)):(nu*j)]
1440-
û0 = @views Û0[(1 + nu*(j-1)):(nu*j)]
1441-
mul!(û0, Cs_u, xs) # ys_u(k) = Cs_u*xs(k)
1442-
û0 .+= u0 # û0(k) = u0(k) + ys_u(k)
1445+
u0 = @views U0[(1 + nu*(j-1)):(nu*j)]
1446+
û0 = @views Û0[(1 + nu*(j-1)):(nu*j)]
1447+
mul!(û0, Cs_u, xs) # ys_u(k) = Cs_u*xs(k)
1448+
û0 .+= u0 # û0(k) = u0(k) + ys_u(k)
1449+
if f_threads || h < 1 || j < 2
1450+
# we need to recompute k1 with multi-threading, even with h==1, since the
1451+
# previous iteration (j-1) may not have been executed yet (iterations are re-orderable)
14431452
model.f!(k1, x0, û0, d̂0, p)
1444-
model.f!(k2, x0next_Z̃, û0, d̂0next, p)
1445-
else # piecewise linear manipulated inputs u:
1446-
k1 .= lastk2
1447-
j == Hp && break # special case, treated after the loop
1453+
else
1454+
k1 .= @views K0[(1 + nk*(j-1)-nx):(nk*(j-1))] # k2 of the last iter. j-1
1455+
end
1456+
if h < 1 || j == Hp
1457+
# j = Hp special case: u(k+Hp-1) = u(k+Hp) since Hc ≤ Hp implies Δu(k+Hp) = 0
1458+
û0next = û0
1459+
else
14481460
u0next = @views U0[(1 + nu*j):(nu*(j+1))]
14491461
û0next = @views Û0[(1 + nu*j):(nu*(j+1))]
14501462
mul!(û0next, Cs_u, xsnext_Z̃) # ys_u(k+1) = Cs_u*xs(k+1)
14511463
û0next .+= u0next # û0(k+1) = u0(k+1) + ys_u(k+1)
1452-
model.f!(k2, x0next_Z̃, û0next, d̂0next, p)
1453-
lastk2 = k2
14541464
end
1455-
sdnext .= @. x0 - x0next_Z̃ + 0.5*Ts*(k1 + k2)
1456-
x̂0 = x̂0next_Z̃ # using states in Z̃ for next iteration (allow parallel for)
1457-
d̂0 = d̂0next
1458-
end
1459-
if !iszero(h)
1460-
# j = Hp special case: u(k+Hp-1) = u(k+Hp) since Hc ≤ Hp implies Δu(k+Hp)=0
1461-
x̂0, x̂0next_Z̃ = @views X̂0_Z̃[end-2nx̂+1:end-nx̂], X̂0_Z̃[end-nx̂+1:end]
1462-
k1, k2 = @views K0[end-2nx+1:end-nx], K0[end-nx+1:end] # k1 already filled
1463-
d̂0next = @views D̂0[end-nd+1:end]
1464-
û0next, u0next = @views Û0[end-nu+1:end], U0[end-nu+1:end] # correspond to u(k+Hp-1)
1465-
x0, x0next_Z̃, xsnext_Z̃ = @views x̂0[1:nx], x̂0next_Z̃[1:nx], x̂0next_Z̃[nx+1:end]
1466-
sdnext = @views geq[end-nx̂+1:end-nx̂+nx] # ssnext already filled
1467-
mul!(û0next, Cs_u, xsnext_Z̃)
1468-
û0next .+= u0next
14691465
model.f!(k2, x0next_Z̃, û0next, d̂0next, p)
1470-
sdnext .= @. x0 - x0next_Z̃ + (Ts/2)*(k1 + k2)
1466+
sdnext .= @. x0 - x0next_Z̃ + 0.5*Ts*(k1 + k2)
14711467
end
14721468
return geq
14731469
end

src/general.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,4 +122,15 @@ function inv!(A::Hermitian{<:Real, <:AbstractMatrix})
122122
invA = inv(Achol)
123123
A .= Hermitian(invA, :L)
124124
return A
125+
end
126+
127+
"Add `Threads.@threads` to a `for` loop if `flag==true`, else leave the loop as is."
128+
macro threadsif(flag, expr)
129+
quote
130+
if $(flag)
131+
Threads.@threads $expr
132+
else
133+
$expr
134+
end
135+
end |> esc
125136
end

0 commit comments

Comments
 (0)