Skip to content

Commit ef67344

Browse files
committed
debug: use prep in NonLinMPC objective hessian
1 parent 6fdbce5 commit ef67344

File tree

4 files changed

+332
-218
lines changed

4 files changed

+332
-218
lines changed

src/controller/nonlinmpc.jl

Lines changed: 131 additions & 155 deletions
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,7 @@ function NonLinMPC(
409409
validate_JE(NT, JE)
410410
gc! = get_mutating_gc(NT, gc)
411411
weights = ControllerWeights(estim.model, Hp, Hc, M_Hp, N_Hc, L_Hp, Cwt, Ewt)
412-
hessian = validate_hessian(hessian, gradient, oracle)
412+
hessian = validate_hessian(hessian, gradient, oracle, DEFAULT_NONLINMPC_HESSIAN)
413413
return NonLinMPC{NT}(
414414
estim, Hp, Hc, nb, weights, JE, gc!, nc, p,
415415
transcription, optim, gradient, jacobian, hessian, oracle
@@ -512,34 +512,6 @@ function test_custom_functions(NT, model::SimModel, JE, gc!, nc, Uop, Yop, Dop,
512512
return nothing
513513
end
514514

515-
"""
516-
validate_hessian(hessian, gradient, oracle) -> backend
517-
518-
Validate `hessian` argument and return the differentiation backend.
519-
"""
520-
function validate_hessian(hessian, gradient, oracle)
521-
if hessian == true
522-
backend = DEFAULT_NONLINMPC_HESSIAN
523-
elseif hessian == false || isnothing(hessian)
524-
backend = nothing
525-
else
526-
backend = hessian
527-
end
528-
if oracle == false && !isnothing(backend)
529-
error("Second order derivatives are only supported with oracle=true.")
530-
end
531-
if oracle == true && !isnothing(backend)
532-
hess = dense_backend(backend)
533-
grad = dense_backend(gradient)
534-
if hess != grad
535-
@info "The objective function gradient will be computed with the hessian "*
536-
"backend ($(backend_str(hess)))\n instead of the one in gradient "*
537-
"argument ($(backend_str(grad))) for efficiency."
538-
end
539-
end
540-
return backend
541-
end
542-
543515
"""
544516
addinfo!(info, mpc::NonLinMPC) -> info
545517
@@ -626,6 +598,134 @@ function reset_nonlincon!(mpc::NonLinMPC)
626598
end
627599
end
628600

601+
"""
602+
get_nonlinobj_op(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) -> J_op
603+
604+
Return the nonlinear operator for the objective of `mpc` [`NonLinMPC`](@ref).
605+
606+
It is based on the splatting syntax. This method is really intricate and that's because of:
607+
608+
- These functions are used inside the nonlinear optimization, so they must be type-stable
609+
and as efficient as possible. All the function outputs and derivatives are cached and
610+
updated in-place if required to use the efficient [`value_and_gradient!`](@extref DifferentiationInterface DifferentiationInterface.value_and_gradient!).
611+
- The splatting syntax for objective functions implies the use of `Vararg{T,N}` (see the [performance tip](@extref Julia Be-aware-of-when-Julia-avoids-specializing))
612+
and memoization to avoid redundant computations. This is already complex, but it's even
613+
worse knowing that the automatic differentiation tools do not support splatting.
614+
- The signature of gradient and hessian functions is not the same for univariate (`nZ̃ == 1`)
615+
and multivariate (`nZ̃ > 1`) operators in `JuMP`. Both must be defined.
616+
"""
617+
function get_nonlinobj_op(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where JNT<:Real
618+
model = mpc.estim.model
619+
transcription = mpc.transcription
620+
grad, hess = mpc.gradient, mpc.hessian
621+
nu, ny, nx̂, nϵ = model.nu, model.ny, mpc.estim.nx̂, mpc.nϵ
622+
nk = get_nk(model, transcription)
623+
Hp, Hc = mpc.Hp, mpc.Hc
624+
ng = length(mpc.con.i_g)
625+
nc, neq = mpc.con.nc, mpc.con.neq
626+
nZ̃, nU, nŶ, nX̂, nK = length(mpc.Z̃), Hp*nu, Hp*ny, Hp*nx̂, Hp*nk
627+
nΔŨ, nUe, nŶe = nu*Hc + nϵ, nU + nu, nŶ + ny
628+
strict = Val(true)
629+
myNaN = convert(JNT, NaN)
630+
J::Vector{JNT} = zeros(JNT, 1)
631+
ΔŨ::Vector{JNT} = zeros(JNT, nΔŨ)
632+
x̂0end::Vector{JNT} = zeros(JNT, nx̂)
633+
K0::Vector{JNT} = zeros(JNT, nK)
634+
Ue::Vector{JNT}, Ŷe::Vector{JNT} = zeros(JNT, nUe), zeros(JNT, nŶe)
635+
U0::Vector{JNT}, Ŷ0::Vector{JNT} = zeros(JNT, nU), zeros(JNT, nŶ)
636+
Û0::Vector{JNT}, X̂0::Vector{JNT} = zeros(JNT, nU), zeros(JNT, nX̂)
637+
gc::Vector{JNT}, g::Vector{JNT} = zeros(JNT, nc), zeros(JNT, ng)
638+
geq::Vector{JNT} = zeros(JNT, neq)
639+
function J!(Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g, geq)
640+
update_predictions!(ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g, geq, mpc, Z̃)
641+
return obj_nonlinprog!(Ŷ0, U0, mpc, model, Ue, Ŷe, ΔŨ)
642+
end
643+
Z̃_J = fill(myNaN, nZ̃) # NaN to force update at first call
644+
J_context = (
645+
Cache(ΔŨ), Cache(x̂0end), Cache(Ue), Cache(Ŷe), Cache(U0), Cache(Ŷ0),
646+
Cache(Û0), Cache(K0), Cache(X̂0),
647+
Cache(gc), Cache(g), Cache(geq),
648+
)
649+
∇J_prep = prepare_gradient(J!, grad, Z̃_J, J_context...; strict)
650+
∇J = Vector{JNT}(undef, nZ̃)
651+
if !isnothing(hess)
652+
∇²J_prep = prepare_hessian(J!, hess, Z̃_J, J_context...; strict)
653+
∇²J = init_diffmat(JNT, hess, ∇²J_prep, nZ̃, nZ̃)
654+
end
655+
update_objective! = if !isnothing(hess)
656+
function (J, ∇J, ∇²J, Z̃_J, Z̃_arg)
657+
if isdifferent(Z̃_arg, Z̃_J)
658+
Z̃_J .= Z̃_arg
659+
J[], _ = value_gradient_and_hessian!(
660+
J!, ∇J, ∇²J, ∇²J_prep, hess, Z̃_J, J_context...
661+
)
662+
end
663+
end
664+
else
665+
update_objective! = function (J, ∇J, Z̃_∇J, Z̃_arg)
666+
if isdifferent(Z̃_arg, Z̃_∇J)
667+
Z̃_∇J .= Z̃_arg
668+
J[], _ = value_and_gradient!(
669+
J!, ∇J, ∇J_prep, grad, Z̃_∇J, J_context...
670+
)
671+
end
672+
end
673+
end
674+
J_func = if !isnothing(hess)
675+
function (Z̃_arg::Vararg{T, N}) where {N, T<:Real}
676+
update_objective!(J, ∇J, ∇²J, Z̃_J, Z̃_arg)
677+
return J[]::T
678+
end
679+
else
680+
function (Z̃_arg::Vararg{T, N}) where {N, T<:Real}
681+
update_objective!(J, ∇J, Z̃_J, Z̃_arg)
682+
return J[]::T
683+
end
684+
end
685+
∇J_func! = if nZ̃ == 1 # univariate syntax (see JuMP.@operator doc):
686+
if !isnothing(hess)
687+
function (Z̃_arg)
688+
update_objective!(J, ∇J, ∇²J, Z̃_J, Z̃_arg)
689+
return ∇J[]
690+
end
691+
else
692+
function (Z̃_arg)
693+
update_objective!(J, ∇J, Z̃_J, Z̃_arg)
694+
return ∇J[]
695+
end
696+
end
697+
else # multivariate syntax (see JuMP.@operator doc):
698+
if !isnothing(hess)
699+
function (∇J_arg::AbstractVector{T}, Z̃_arg::Vararg{T, N}) where {N, T<:Real}
700+
update_objective!(J, ∇J, ∇²J, Z̃_J, Z̃_arg)
701+
return ∇J_arg .= ∇J
702+
end
703+
else
704+
function (∇J_arg::AbstractVector{T}, Z̃_arg::Vararg{T, N}) where {N, T<:Real}
705+
update_objective!(J, ∇J, Z̃_J, Z̃_arg)
706+
return ∇J_arg .= ∇J
707+
end
708+
end
709+
end
710+
∇²J_func! = if nZ̃ == 1 # univariate syntax (see JuMP.@operator doc):
711+
function (Z̃_arg)
712+
update_objective!(J, ∇J, ∇²J, Z̃_J, Z̃_arg)
713+
return ∇²J[]
714+
end
715+
else # multivariate syntax (see JuMP.@operator doc):
716+
function (∇²J_arg::AbstractMatrix{T}, Z̃_arg::Vararg{T, N}) where {N, T<:Real}
717+
update_objective!(J, ∇J, ∇²J, Z̃_J, Z̃_arg)
718+
return fill_lowertriangle!(∇²J_arg, ∇²J)
719+
end
720+
end
721+
if !isnothing(hess)
722+
@operator(optim, J_op, nZ̃, J_func, ∇J_func!, ∇²J_func!)
723+
else
724+
@operator(optim, J_op, nZ̃, J_func, ∇J_func!)
725+
end
726+
return J_op
727+
end
728+
629729
"""
630730
get_nonlincon_oracle(mpc::NonLinMPC, optim) -> g_oracle, geq_oracle
631731
@@ -661,7 +761,7 @@ function get_nonlincon_oracle(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JN
661761
Û0::Vector{JNT}, X̂0::Vector{JNT} = zeros(JNT, nU), zeros(JNT, nX̂)
662762
gc::Vector{JNT}, g::Vector{JNT} = zeros(JNT, nc), zeros(JNT, ng)
663763
gi::Vector{JNT}, geq::Vector{JNT} = zeros(JNT, ngi), zeros(JNT, neq)
664-
λi::Vector{JNT}, λeq::Vector{JNT} = ones(JNT, ngi), ones(JNT, neq)
764+
λi::Vector{JNT}, λeq::Vector{JNT} = ones(JNT, ngi), ones(JNT, neq)
665765
# -------------- inequality constraint: nonlinear oracle -----------------------------
666766
function gi!(gi, Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, geq, g)
667767
update_predictions!(ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g, geq, mpc, Z̃)
@@ -725,7 +825,7 @@ function get_nonlincon_oracle(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JN
725825
jacobian_structure = ∇gi_structure,
726826
eval_jacobian = ∇gi_func!,
727827
hessian_lagrangian_structure = isnothing(hess) ? Tuple{Int,Int}[] : ∇²gi_structure,
728-
eval_hessian_lagrangian = isnothing(hess) ? nothing : ∇²gi_func!
828+
eval_hessian_lagrangian = isnothing(hess) ? nothing : ∇²gi_func!
729829
)
730830
# ------------- equality constraints : nonlinear oracle ------------------------------
731831
function geq!(geq, Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g)
@@ -792,130 +892,6 @@ function get_nonlincon_oracle(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JN
792892
return g_oracle, geq_oracle
793893
end
794894

795-
"""
796-
get_nonlinobj_op(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) -> J_op
797-
798-
Return the nonlinear operator for the objective function of `mpc` [`NonLinMPC`](@ref).
799-
800-
It is based on the splatting syntax. This method is really intricate and that's because of:
801-
802-
- These functions are used inside the nonlinear optimization, so they must be type-stable
803-
and as efficient as possible. All the function outputs and derivatives are cached and
804-
updated in-place if required to use the efficient [`value_and_gradient!`](@extref DifferentiationInterface DifferentiationInterface.value_and_jacobian!).
805-
- The splatting syntax for objective functions implies the use of `Vararg{T,N}` (see the [performance tip](@extref Julia Be-aware-of-when-Julia-avoids-specializing))
806-
and memoization to avoid redundant computations. This is already complex, but it's even
807-
worse knowing that the automatic differentiation tools do not support splatting.
808-
- The signature of gradient and hessian functions is not the same for univariate (`nZ̃ == 1`)
809-
and multivariate (`nZ̃ > 1`) operators in `JuMP`. Both must be defined.
810-
"""
811-
function get_nonlinobj_op(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where JNT<:Real
812-
model = mpc.estim.model
813-
transcription = mpc.transcription
814-
grad, hess = mpc.gradient, mpc.hessian
815-
nu, ny, nx̂, nϵ = model.nu, model.ny, mpc.estim.nx̂, mpc.nϵ
816-
nk = get_nk(model, transcription)
817-
Hp, Hc = mpc.Hp, mpc.Hc
818-
ng = length(mpc.con.i_g)
819-
nc, neq = mpc.con.nc, mpc.con.neq
820-
nZ̃, nU, nŶ, nX̂, nK = length(mpc.Z̃), Hp*nu, Hp*ny, Hp*nx̂, Hp*nk
821-
nΔŨ, nUe, nŶe = nu*Hc + nϵ, nU + nu, nŶ + ny
822-
strict = Val(true)
823-
myNaN = convert(JNT, NaN)
824-
J::Vector{JNT} = zeros(JNT, 1)
825-
ΔŨ::Vector{JNT} = zeros(JNT, nΔŨ)
826-
x̂0end::Vector{JNT} = zeros(JNT, nx̂)
827-
K0::Vector{JNT} = zeros(JNT, nK)
828-
Ue::Vector{JNT}, Ŷe::Vector{JNT} = zeros(JNT, nUe), zeros(JNT, nŶe)
829-
U0::Vector{JNT}, Ŷ0::Vector{JNT} = zeros(JNT, nU), zeros(JNT, nŶ)
830-
Û0::Vector{JNT}, X̂0::Vector{JNT} = zeros(JNT, nU), zeros(JNT, nX̂)
831-
gc::Vector{JNT}, g::Vector{JNT} = zeros(JNT, nc), zeros(JNT, ng)
832-
geq::Vector{JNT} = zeros(JNT, neq)
833-
function J!(Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g, geq)
834-
update_predictions!(ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g, geq, mpc, Z̃)
835-
return obj_nonlinprog!(Ŷ0, U0, mpc, model, Ue, Ŷe, ΔŨ)
836-
end
837-
Z̃_J = fill(myNaN, nZ̃) # NaN to force update at first call
838-
J_context = (
839-
Cache(ΔŨ), Cache(x̂0end), Cache(Ue), Cache(Ŷe), Cache(U0), Cache(Ŷ0),
840-
Cache(Û0), Cache(K0), Cache(X̂0),
841-
Cache(gc), Cache(g), Cache(geq),
842-
)
843-
∇J_prep = prepare_gradient(J!, grad, Z̃_J, J_context...; strict)
844-
∇J = Vector{JNT}(undef, nZ̃)
845-
if !isnothing(hess)
846-
∇²J_prep = prepare_hessian(J!, hess, Z̃_J, J_context...; strict)
847-
∇²J = init_diffmat(JNT, hess, ∇²J_prep, nZ̃, nZ̃)
848-
end
849-
update_objective! = if !isnothing(hess)
850-
function (J, ∇J, ∇²J, Z̃_J, Z̃_arg)
851-
if isdifferent(Z̃_arg, Z̃_J)
852-
Z̃_J .= Z̃_arg
853-
J[], _ = value_gradient_and_hessian!(J!, ∇J, ∇²J, hess, Z̃_J, J_context...)
854-
end
855-
end
856-
else
857-
update_objective! = function (J, ∇J, Z̃_∇J, Z̃_arg)
858-
if isdifferent(Z̃_arg, Z̃_∇J)
859-
Z̃_∇J .= Z̃_arg
860-
J[], _ = value_and_gradient!(J!, ∇J, ∇J_prep, grad, Z̃_∇J, J_context...)
861-
end
862-
end
863-
end
864-
J_func = if !isnothing(hess)
865-
function (Z̃_arg::Vararg{T, N}) where {N, T<:Real}
866-
update_objective!(J, ∇J, ∇²J, Z̃_J, Z̃_arg)
867-
return J[]::T
868-
end
869-
else
870-
function (Z̃_arg::Vararg{T, N}) where {N, T<:Real}
871-
update_objective!(J, ∇J, Z̃_J, Z̃_arg)
872-
return J[]::T
873-
end
874-
end
875-
∇J_func! = if nZ̃ == 1 # univariate syntax (see JuMP.@operator doc):
876-
if !isnothing(hess)
877-
function (Z̃_arg)
878-
update_objective!(J, ∇J, ∇²J, Z̃_J, Z̃_arg)
879-
return ∇J[]
880-
end
881-
else
882-
function (Z̃_arg)
883-
update_objective!(J, ∇J, Z̃_J, Z̃_arg)
884-
return ∇J[]
885-
end
886-
end
887-
else # multivariate syntax (see JuMP.@operator doc):
888-
if !isnothing(hess)
889-
function (∇J_arg::AbstractVector{T}, Z̃_arg::Vararg{T, N}) where {N, T<:Real}
890-
update_objective!(J, ∇J, ∇²J, Z̃_J, Z̃_arg)
891-
return ∇J_arg .= ∇J
892-
end
893-
else
894-
function (∇J_arg::AbstractVector{T}, Z̃_arg::Vararg{T, N}) where {N, T<:Real}
895-
update_objective!(J, ∇J, Z̃_J, Z̃_arg)
896-
return ∇J_arg .= ∇J
897-
end
898-
end
899-
end
900-
∇²J_func! = if nZ̃ == 1 # univariate syntax (see JuMP.@operator doc):
901-
function (Z̃_arg)
902-
update_objective!(J, ∇J, ∇²J, Z̃_J, Z̃_arg)
903-
return ∇²J[]
904-
end
905-
else # multivariate syntax (see JuMP.@operator doc):
906-
function (∇²J_arg::AbstractMatrix{T}, Z̃_arg::Vararg{T, N}) where {N, T<:Real}
907-
update_objective!(J, ∇J, ∇²J, Z̃_J, Z̃_arg)
908-
return fill_lowertriangle!(∇²J_arg, ∇²J)
909-
end
910-
end
911-
if !isnothing(hess)
912-
@operator(optim, J_op, nZ̃, J_func, ∇J_func!, ∇²J_func!)
913-
else
914-
@operator(optim, J_op, nZ̃, J_func, ∇J_func!)
915-
end
916-
return J_op
917-
end
918-
919895
"""
920896
update_predictions!(
921897
ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g, geq,

src/estimator/mhe.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ end
1313
function print_backends(io::IO, estim::MovingHorizonEstimator, ::SimModel)
1414
println(io, "├ gradient: $(backend_str(estim.gradient))")
1515
println(io, "├ jacobian: $(backend_str(estim.jacobian))")
16+
println(io, "├ hessian: $(backend_str(estim.hessian))")
1617
end
1718
"No differentiation backends to print for `LinModel`."
1819
print_backends(::IO, ::MovingHorizonEstimator, ::LinModel) = nothing

0 commit comments

Comments
 (0)