Skip to content

Commit 82b70f9

Browse files
committed
changed: trunk g vector earlier
This is more efficient with dense differentiation, in theory. I will compare the perf. in a separate PR.
1 parent 65693cf commit 82b70f9

File tree

1 file changed

+23
-29
lines changed

1 file changed

+23
-29
lines changed

src/controller/nonlinmpc.jl

Lines changed: 23 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -737,7 +737,8 @@ function get_nonlinops(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where JNT<
737737
nu, ny, nx̂, nϵ = model.nu, model.ny, mpc.estim.nx̂, mpc.
738738
nk = get_nk(model, transcription)
739739
Hp, Hc = mpc.Hp, mpc.Hc
740-
ng, nc, neq = length(mpc.con.i_g), mpc.con.nc, mpc.con.neq
740+
ng, ng_i_g = length(mpc.con.i_g), sum(mpc.con.i_g)
741+
nc, neq = mpc.con.nc, mpc.con.neq
741742
nZ̃, nU, nŶ, nX̂, nK = length(mpc.Z̃), Hp*nu, Hp*ny, Hp*nx̂, Hp*nk
742743
nΔŨ, nUe, nŶe = nu*Hc + nϵ, nU + nu, nŶ + ny
743744
strict = Val(true)
@@ -750,45 +751,40 @@ function get_nonlinops(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where JNT<
750751
U0::Vector{JNT}, Ŷ0::Vector{JNT} = zeros(JNT, nU), zeros(JNT, nŶ)
751752
Û0::Vector{JNT}, X̂0::Vector{JNT} = zeros(JNT, nU), zeros(JNT, nX̂)
752753
gc::Vector{JNT}, g::Vector{JNT} = zeros(JNT, nc), zeros(JNT, ng)
754+
g_i_g::Vector{JNT} = zeros(JNT, ng_i_g)
753755
geq::Vector{JNT} = zeros(JNT, neq)
754756
# -------------- inequality constraint: nonlinear oracle -----------------------------
755-
function g!(g, Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, geq)
757+
function g_i_g!(g_i_g, Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, geq, g)
756758
update_predictions!(ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g, geq, mpc, Z̃)
759+
g_i_g .= @views g[mpc.con.i_g]
757760
return nothing
758761
end
759762
Z̃_∇g = fill(myNaN, nZ̃) # NaN to force update_predictions! at first call
760763
∇g_context = (
761764
Cache(ΔŨ), Cache(x̂0end), Cache(Ue), Cache(Ŷe), Cache(U0), Cache(Ŷ0),
762765
Cache(Û0), Cache(K0), Cache(X̂0),
763-
Cache(gc), Cache(geq),
766+
Cache(gc), Cache(geq), Cache(g)
764767
)
765-
## temporarily enable all the inequality constraints for sparsity detection:
766-
# mpc.con.i_g[1:end-nc] .= true
767-
∇g_prep = prepare_jacobian(g!, g, jac, Z̃_∇g, ∇g_context...; strict)
768-
# mpc.con.i_g[1:end-nc] .= false
769-
∇g = init_diffmat(JNT, jac, ∇g_prep, nZ̃, ng)
770-
function update_con!(g, ∇g, Z̃_∇g, Z̃_arg)
768+
∇g_prep = prepare_jacobian(g_i_g!, g_i_g, jac, Z̃_∇g, ∇g_context...; strict)
769+
∇g_i_g = init_diffmat(JNT, jac, ∇g_prep, nZ̃, ng)
770+
function update_con!(g_i_g, ∇g_i_g, Z̃_∇g, Z̃_arg)
771771
if isdifferent(Z̃_arg, Z̃_∇g)
772772
Z̃_∇g .= Z̃_arg
773-
value_and_jacobian!(g!, g, ∇g, ∇g_prep, jac, Z̃_∇g, ∇g_context...)
773+
value_and_jacobian!(g_i_g!, g_i_g, ∇g_i_g, ∇g_prep, jac, Z̃_∇g, ∇g_context...)
774774
end
775775
return nothing
776776
end
777-
function gfunc_oracle!(g_arg, Z̃_arg)
778-
update_con!(g, ∇g, Z̃_∇g, Z̃_arg)
779-
g_arg .= @views g[mpc.con.i_g]
780-
return nothing
777+
function gfunc_oracle!(g_vec, Z̃_arg)
778+
update_con!(g_i_g, ∇g_i_g, Z̃_∇g, Z̃_arg)
779+
return g_vec .= g_i_g
781780
end
782-
∇g_i_g = ∇g[mpc.con.i_g, :]
783-
function ∇gfunc_oracle!(∇g_arg, Z̃_arg)
784-
update_con!(g, ∇g, Z̃_∇g, Z̃_arg)
785-
∇g_i_g .= @views ∇g[mpc.con.i_g, :]
786-
diffmat2vec!(∇g_arg, ∇g_i_g)
787-
return nothing
781+
function ∇gfunc_oracle!(∇g_vec, Z̃_arg)
782+
update_con!(g_i_g, ∇g_i_g, Z̃_∇g, Z̃_arg)
783+
return diffmat2vec!(∇g_vec, ∇g_i_g)
788784
end
789-
g_min = fill(-myInf, sum(mpc.con.i_g))
790-
g_max = zeros(JNT, sum(mpc.con.i_g))
791-
∇g_structure = init_diffstructure(∇g[mpc.con.i_g, :])
785+
g_min = fill(-myInf, ng_i_g)
786+
g_max = zeros(JNT, ng_i_g)
787+
∇g_structure = init_diffstructure(∇g_i_g)
792788
g_oracle = Ipopt._VectorNonlinearOracle(;
793789
dimension = nZ̃,
794790
l = g_min,
@@ -817,15 +813,13 @@ function get_nonlinops(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where JNT<
817813
end
818814
return nothing
819815
end
820-
function geq_oracle!(geq_arg, Z̃_arg)
816+
function geq_oracle!(geq_vec, Z̃_arg)
821817
update_con_eq!(geq, ∇geq, Z̃_∇geq, Z̃_arg)
822-
geq_arg .= geq
823-
return nothing
818+
return geq_vec .= geq
824819
end
825-
function ∇geq_oracle!(∇geq_arg, Z̃_arg)
820+
function ∇geq_oracle!(∇geq_vec, Z̃_arg)
826821
update_con_eq!(geq, ∇geq, Z̃_∇geq, Z̃_arg)
827-
diffmat2vec!(∇geq_arg, ∇geq)
828-
return nothing
822+
return diffmat2vec!(∇geq_vec, ∇geq)
829823
end
830824
geq_min = geq_max = zeros(JNT, neq)
831825
∇geq_structure = init_diffstructure(∇geq)

0 commit comments

Comments
 (0)