Skip to content

Commit bffa05a

Browse files
committed
changed: two separate functions for objective and constraints
1 parent aefc4ff commit bffa05a

File tree

1 file changed

+64
-28
lines changed

1 file changed

+64
-28
lines changed

src/controller/nonlinmpc.jl

Lines changed: 64 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -589,8 +589,9 @@ function init_optimization!(
589589
end
590590
end
591591
if mpc.oracle
592-
g_oracle, geq_oracle, J_op = get_nonlinops(mpc, optim)
593-
optim[:J_op] = J_op
592+
J_op = get_nonlinobj_op(mpc, optim)
593+
g_oracle, geq_oracle = get_nonlincon_oracle(mpc, optim)
594+
594595
else
595596
J_func, ∇J_func!, g_funcs, ∇g_funcs!, geq_funcs, ∇geq_funcs! = get_optim_functions(
596597
mpc, optim
@@ -616,38 +617,30 @@ Re-construct nonlinear constraints and add them to `mpc.optim`.
616617
"""
617618
function reset_nonlincon!(mpc::NonLinMPC)
618619
if mpc.oracle
619-
g_oracle, geq_oracle = get_nonlinops(mpc, mpc.optim)
620+
g_oracle, geq_oracle = get_nonlincon_oracle(mpc, mpc.optim)
620621
set_nonlincon!(mpc, mpc.optim, g_oracle, geq_oracle)
621622
else
622623
set_nonlincon_leg!(mpc, mpc.estim.model, mpc.transcription, mpc.optim)
623624
end
624625
end
625626

626627
"""
627-
get_nonlinops(mpc::NonLinMPC, optim) -> g_oracle, geq_oracle, J_op
628+
get_nonlincon_oracle(mpc::NonLinMPC, optim) -> g_oracle, geq_oracle
628629
629-
Return the operators for the nonlinear optimization of `mpc` [`NonLinMPC`](@ref) controller.
630+
Return the nonlinear constraint oracles for [`NonLinMPC`](@ref) `mpc`.
630631
631632
Return `g_oracle` and `geq_oracle`, the inequality and equality [`VectorNonlinearOracle`](@extref MathOptInterface MathOptInterface.VectorNonlinearOracle)
632633
for the two respective constraints. Note that `g_oracle` only includes the non-`Inf`
633-
inequality constraints, thus it must be re-constructed if they change. Also return `J_op`,
634-
the [`NonlinearOperator`](@extref JuMP NonlinearOperator) for the objective function, based
635-
on the splatting syntax. This method is really intricate and that's because of 3 elements:
636-
637-
- These functions are used inside the nonlinear optimization, so they must be type-stable
638-
and as efficient as possible. All the function outputs and derivatives are cached and
639-
updated in-place if required to use the efficient [`value_and_jacobian!`](@extref DifferentiationInterface DifferentiationInterface.value_and_jacobian!).
640-
- The splatting syntax for objective functions implies the use of `Vararg{T,N}` (see the [performance tip](@extref Julia Be-aware-of-when-Julia-avoids-specializing))
641-
and memoization to avoid redundant computations. This is already complex, but it's even
642-
worse knowing that the automatic differentiation tools do not support splatting.
643-
- The signature of gradient and hessian functions is not the same for univariate (`nZ̃ == 1`)
644-
and multivariate (`nZ̃ > 1`) operators in `JuMP`. Both must be defined.
634+
inequality constraints, thus it must be re-constructed if they change. This method is really
635+
intricate because the oracles are used inside the nonlinear optimization, so they must be
636+
type-stable and as efficient as possible. All the function outputs and derivatives are
637+
cached and updated in-place if required to use the efficient [`value_and_jacobian!`](@extref DifferentiationInterface DifferentiationInterface.value_and_jacobian!).
645638
"""
646-
function get_nonlinops(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where JNT<:Real
639+
function get_nonlincon_oracle(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JNT<:Real
647640
# ----------- common cache for all functions ----------------------------------------
648641
model = mpc.estim.model
649642
transcription = mpc.transcription
650-
grad, jac, hess = mpc.gradient, mpc.jacobian, mpc.hessian
643+
jac, hess = mpc.jacobian, mpc.hessian
651644
nu, ny, nx̂, nϵ = model.nu, model.ny, mpc.estim.nx̂, mpc.nϵ
652645
nk = get_nk(model, transcription)
653646
Hp, Hc = mpc.Hp, mpc.Hc
@@ -658,7 +651,6 @@ function get_nonlinops(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where JNT<
658651
nΔŨ, nUe, nŶe = nu*Hc + nϵ, nU + nu, nŶ + ny
659652
strict = Val(true)
660653
myNaN, myInf = convert(JNT, NaN), convert(JNT, Inf)
661-
J::Vector{JNT} = zeros(JNT, 1)
662654
ΔŨ::Vector{JNT} = zeros(JNT, nΔŨ)
663655
x̂0end::Vector{JNT} = zeros(JNT, nx̂)
664656
K0::Vector{JNT} = zeros(JNT, nK)
@@ -667,7 +659,7 @@ function get_nonlinops(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where JNT<
667659
Û0::Vector{JNT}, X̂0::Vector{JNT} = zeros(JNT, nU), zeros(JNT, nX̂)
668660
gc::Vector{JNT}, g::Vector{JNT} = zeros(JNT, nc), zeros(JNT, ng)
669661
gi::Vector{JNT}, geq::Vector{JNT} = zeros(JNT, ngi), zeros(JNT, neq)
670-
λi::Vector{JNT}, λeq::Vector{JNT} = zeros(JNT, ngi), zeros(JNT, neq)
662+
λi::Vector{JNT}, λeq::Vector{JNT} = ones(JNT, ngi), ones(JNT, neq)
671663
# -------------- inequality constraint: nonlinear oracle -----------------------------
672664
function gi!(gi, Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, geq, g)
673665
update_predictions!(ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g, geq, mpc, Z̃)
@@ -694,7 +686,9 @@ function get_nonlinops(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where JNT<
694686
Cache(Û0), Cache(K0), Cache(X̂0),
695687
Cache(gc), Cache(geq), Cache(g), Cache(gi)
696688
)
697-
∇²gi_prep = prepare_hessian(ℓ_gi, hess, Z̃_∇gi, Constant(λi), ∇²gi_context...; strict)
689+
∇²gi_prep = prepare_hessian(
690+
ℓ_gi, hess, Z̃_∇gi, Constant(λi), ∇²gi_context...; strict
691+
)
698692
∇²ℓ_gi = init_diffmat(JNT, hess, ∇²gi_prep, nZ̃, nZ̃)
699693
∇²gi_structure = lowertriangle_indices(init_diffstructure(∇²ℓ_gi))
700694
end
@@ -755,7 +749,9 @@ function get_nonlinops(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where JNT<
755749
Cache(Û0), Cache(K0), Cache(X̂0),
756750
Cache(gc), Cache(geq), Cache(g)
757751
)
758-
∇²geq_prep = prepare_hessian(ℓ_geq, hess, Z̃_∇geq, Constant(λeq), ∇²geq_context...; strict)
752+
∇²geq_prep = prepare_hessian(
753+
ℓ_geq, hess, Z̃_∇geq, Constant(λeq), ∇²geq_context...; strict
754+
)
759755
∇²ℓ_geq = init_diffmat(JNT, hess, ∇²geq_prep, nZ̃, nZ̃)
760756
∇²geq_structure = lowertriangle_indices(init_diffstructure(∇²ℓ_geq))
761757
end
@@ -791,7 +787,47 @@ function get_nonlinops(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where JNT<
791787
hessian_lagrangian_structure = isnothing(hess) ? Tuple{Int,Int}[] : ∇²geq_structure,
792788
eval_hessian_lagrangian = isnothing(hess) ? nothing : ∇²geq_func!
793789
)
794-
# ------------- objective function: splatting syntax ---------------------------------
790+
return g_oracle, geq_oracle
791+
end
792+
793+
"""
794+
get_nonlinobj_op(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) -> J_op
795+
796+
Return the nonlinear operator for the objective function of [`NonLinMPC`](@ref) `mpc`.
797+
798+
It is based on the splatting syntax. This method is really intricate because of 3 elements:
799+
800+
- These functions are used inside the nonlinear optimization, so they must be type-stable
801+
and as efficient as possible. All the function outputs and derivatives are cached and
802+
updated in-place if required to use the efficient [`value_and_gradient!`](@extref DifferentiationInterface DifferentiationInterface.value_and_gradient!).
803+
- The splatting syntax for objective functions implies the use of `Vararg{T,N}` (see the [performance tip](@extref Julia Be-aware-of-when-Julia-avoids-specializing))
804+
and memoization to avoid redundant computations. This is already complex, but it's even
805+
worse knowing that the automatic differentiation tools do not support splatting.
806+
- The signature of gradient and hessian functions is not the same for univariate (`nZ̃ == 1`)
807+
and multivariate (`nZ̃ > 1`) operators in `JuMP`. Both must be defined.
808+
"""
809+
function get_nonlinobj_op(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where JNT<:Real
810+
model = mpc.estim.model
811+
transcription = mpc.transcription
812+
grad, hess = mpc.gradient, mpc.hessian
813+
nu, ny, nx̂, nϵ = model.nu, model.ny, mpc.estim.nx̂, mpc.nϵ
814+
nk = get_nk(model, transcription)
815+
Hp, Hc = mpc.Hp, mpc.Hc
816+
ng = length(mpc.con.i_g)
817+
nc, neq = mpc.con.nc, mpc.con.neq
818+
nZ̃, nU, nŶ, nX̂, nK = length(mpc.Z̃), Hp*nu, Hp*ny, Hp*nx̂, Hp*nk
819+
nΔŨ, nUe, nŶe = nu*Hc + nϵ, nU + nu, nŶ + ny
820+
strict = Val(true)
821+
myNaN = convert(JNT, NaN)
822+
J::Vector{JNT} = zeros(JNT, 1)
823+
ΔŨ::Vector{JNT} = zeros(JNT, nΔŨ)
824+
x̂0end::Vector{JNT} = zeros(JNT, nx̂)
825+
K0::Vector{JNT} = zeros(JNT, nK)
826+
Ue::Vector{JNT}, Ŷe::Vector{JNT} = zeros(JNT, nUe), zeros(JNT, nŶe)
827+
U0::Vector{JNT}, Ŷ0::Vector{JNT} = zeros(JNT, nU), zeros(JNT, nŶ)
828+
Û0::Vector{JNT}, X̂0::Vector{JNT} = zeros(JNT, nU), zeros(JNT, nX̂)
829+
gc::Vector{JNT}, g::Vector{JNT} = zeros(JNT, nc), zeros(JNT, ng)
830+
geq::Vector{JNT} = zeros(JNT, neq)
795831
function J!(Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g, geq)
796832
update_predictions!(ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g, geq, mpc, Z̃)
797833
return obj_nonlinprog!(Ŷ0, U0, mpc, model, Ue, Ŷe, ΔŨ)
@@ -870,12 +906,12 @@ function get_nonlinops(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where JNT<
870906
return fill_lowertriangle!(∇²J_arg, ∇²J)
871907
end
872908
end
873-
J_op = if !isnothing(hess)
874-
JuMP.add_nonlinear_operator(optim, nZ̃, J_func, ∇J_func!, ∇²J_func!, name=:J_op)
909+
if !isnothing(hess)
910+
@operator(optim, J_op, nZ̃, J_func, ∇J_func!, ∇²J_func!)
875911
else
876-
JuMP.add_nonlinear_operator(optim, nZ̃, J_func, ∇J_func!, name=:J_op)
912+
@operator(optim, J_op, nZ̃, J_func, ∇J_func!)
877913
end
878-
return g_oracle, geq_oracle, J_op
914+
return J_op
879915
end
880916

881917
"""

0 commit comments

Comments
 (0)