Skip to content

Commit 4a82a1a

Browse files
committed
added: assign! function to reduce allocations with BitVector
The field `mpc.con.i_g` is a `BitVector` and doing `b .= @views a[i]` is allocating when `i isa BitVector`.
1 parent 5389a4a commit 4a82a1a

File tree

2 files changed

+29
-10
lines changed

2 files changed

+29
-10
lines changed

src/controller/nonlinmpc.jl

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -737,7 +737,7 @@ function get_nonlinops(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where JNT<
737737
nu, ny, nx̂, nϵ = model.nu, model.ny, mpc.estim.nx̂, mpc.
738738
nk = get_nk(model, transcription)
739739
Hp, Hc = mpc.Hp, mpc.Hc
740-
ng, nc, neq = length(mpc.con.i_g), mpc.con.nc, mpc.con.neq
740+
ng, ng_i_g, nc, neq = length(mpc.con.i_g), sum(mpc.con.i_g), mpc.con.nc, mpc.con.neq
741741
nZ̃, nU, nŶ, nX̂, nK = length(mpc.Z̃), Hp*nu, Hp*ny, Hp*nx̂, Hp*nk
742742
nΔŨ, nUe, nŶe = nu*Hc + nϵ, nU + nu, nŶ + ny
743743
strict = Val(true)
@@ -762,11 +762,9 @@ function get_nonlinops(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where JNT<
762762
Cache(Û0), Cache(K0), Cache(X̂0),
763763
Cache(gc), Cache(geq),
764764
)
765-
## temporarily enable all the inequality constraints for sparsity detection:
766-
# mpc.con.i_g[1:end-nc] .= true
767765
∇g_prep = prepare_jacobian(g!, g, jac, Z̃_∇g, ∇g_context...; strict)
768-
# mpc.con.i_g[1:end-nc] .= false
769766
∇g = init_diffmat(JNT, jac, ∇g_prep, nZ̃, ng)
767+
∇g_i_g = ∇g[mpc.con.i_g, :]
770768
function update_con!(g, ∇g, Z̃_∇g, Z̃_arg)
771769
if isdifferent(Z̃_arg, Z̃_∇g)
772770
Z̃_∇g .= Z̃_arg
@@ -776,19 +774,18 @@ function get_nonlinops(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where JNT<
776774
end
777775
function gfunc_oracle!(g_arg, Z̃_arg)
778776
update_con!(g, ∇g, Z̃_∇g, Z̃_arg)
779-
g_arg .= @views g[mpc.con.i_g]
777+
assign!(g_arg, g, mpc.con.i_g)
780778
return nothing
781779
end
782-
∇g_i_g = ∇g[mpc.con.i_g, :]
783780
function ∇gfunc_oracle!(∇g_arg, Z̃_arg)
784781
update_con!(g, ∇g, Z̃_∇g, Z̃_arg)
785-
∇g_i_g .= @views ∇g[mpc.con.i_g, :]
782+
assign!(∇g_i_g, ∇g, mpc.con.i_g)
786783
diffmat2vec!(∇g_arg, ∇g_i_g)
787784
return nothing
788785
end
789-
g_min = fill(-myInf, sum(mpc.con.i_g))
790-
g_max = zeros(JNT, sum(mpc.con.i_g))
791-
∇g_structure = init_diffstructure(∇g[mpc.con.i_g, :])
786+
g_min = fill(-myInf, ng_i_g)
787+
g_max = zeros(JNT, ng_i_g)
788+
∇g_structure = init_diffstructure(∇g_i_g)
792789
g_oracle = Ipopt._VectorNonlinearOracle(;
793790
dimension = nZ̃,
794791
l = g_min,

src/general.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,28 @@ function backend_str(backend::AutoSparse)
8383
return str
8484
end
8585

86+
"Assign rows of of `g` to `gi` at the positions where `i` is true (without allocations)."
87+
function assign!(Ai::AbstractMatrix, A::AbstractMatrix, i::BitVector)
88+
k = 1
89+
@inbounds for j in axes(A, 1)
90+
if i[j]
91+
Ai[k, :] = @views A[j, :]
92+
k += 1
93+
end
94+
end
95+
return Ai
96+
end
97+
function assign!(ai::AbstractVector, a::AbstractVector, i::BitVector)
98+
k = 1
99+
@inbounds for j in eachindex(a)
100+
if i[j]
101+
ai[k] = a[j]
102+
k += 1
103+
end
104+
end
105+
return ai
106+
end
107+
86108
"Verify that x and y elements are different using `!==`."
87109
isdifferent(x, y) = any(xi !== yi for (xi, yi) in zip(x, y))
88110

0 commit comments

Comments
 (0)