@@ -409,7 +409,7 @@ function NonLinMPC(
409409     validate_JE(NT, JE)
410410     gc! = get_mutating_gc(NT, gc)
411411     weights = ControllerWeights(estim.model, Hp, Hc, M_Hp, N_Hc, L_Hp, Cwt, Ewt)
412-     hessian = validate_hessian(hessian, gradient, oracle)
412+     hessian = validate_hessian(hessian, gradient, oracle, DEFAULT_NONLINMPC_HESSIAN)
413413     return NonLinMPC{NT}(
414414         estim, Hp, Hc, nb, weights, JE, gc!, nc, p,
415415         transcription, optim, gradient, jacobian, hessian, oracle
@@ -512,34 +512,6 @@ function test_custom_functions(NT, model::SimModel, JE, gc!, nc, Uop, Yop, Dop,
512512     return nothing
513513 end
514514
515- """
516- validate_hessian(hessian, gradient, oracle) -> backend
517-
518- Validate `hessian` argument and return the differentiation backend.
519- """
520- function validate_hessian (hessian, gradient, oracle)
521- if hessian == true
522- backend = DEFAULT_NONLINMPC_HESSIAN
523- elseif hessian == false || isnothing (hessian)
524- backend = nothing
525- else
526- backend = hessian
527- end
528- if oracle == false && ! isnothing (backend)
529- error (" Second order derivatives are only supported with oracle=true." )
530- end
531- if oracle == true && ! isnothing (backend)
532- hess = dense_backend (backend)
533- grad = dense_backend (gradient)
534- if hess != grad
535- @info " The objective function gradient will be computed with the hessian " *
536- " backend ($(backend_str (hess)) )\n instead of the one in gradient " *
537- " argument ($(backend_str (grad)) ) for efficiency."
538- end
539- end
540- return backend
541- end
542-
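For reference, a hedged sketch (not the relocated method itself) of how the new four-argument call shown in the first hunk could map the `hessian` argument to a backend, based on the logic removed above; `default` and `validate_hessian_sketch` are illustrative names, not from this diff:

# Illustrative only: `default` stands in for DEFAULT_NONLINMPC_HESSIAN.
function validate_hessian_sketch(hessian, default)
    hessian == true && return default                           # use the default backend
    (hessian == false || isnothing(hessian)) && return nothing  # no second-order derivatives
    return hessian                                               # user-supplied AD backend
end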
543515"""
544516 addinfo!(info, mpc::NonLinMPC) -> info
545517
@@ -626,6 +598,134 @@ function reset_nonlincon!(mpc::NonLinMPC)
626598     end
627599 end
628600
601+ """
602+     get_nonlinobj_op(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) -> J_op
603+
604+ Return the nonlinear operator for the objective of `mpc` [`NonLinMPC`](@ref).
605+
606+ It is based on the splatting syntax. This method is intricate for the following reasons:
607+
608+ - These functions are used inside the nonlinear optimization, so they must be type-stable
609+   and as efficient as possible. All the function outputs and derivatives are cached and
610+   updated in place when required, using the efficient [`value_and_gradient!`](@extref DifferentiationInterface DifferentiationInterface.value_and_gradient!).
611+ - The splatting syntax for objective functions implies the use of `Vararg{T,N}` (see the [performance tip](@extref Julia Be-aware-of-when-Julia-avoids-specializing))
612+   and memoization to avoid redundant computations. This is already complex, and it is made
613+   worse by the fact that the automatic differentiation tools do not support splatting.
614+ - The signatures of the gradient and hessian functions differ between univariate (`nZ̃ == 1`)
615+   and multivariate (`nZ̃ > 1`) operators in `JuMP`, so both must be defined (a simplified, standalone sketch of this pattern follows the function below).
616+ """
617+ function get_nonlinobj_op(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where JNT<:Real
618+     model = mpc.estim.model
619+     transcription = mpc.transcription
620+     grad, hess = mpc.gradient, mpc.hessian
621+     nu, ny, nx̂, nϵ = model.nu, model.ny, mpc.estim.nx̂, mpc.nϵ
622+     nk = get_nk(model, transcription)
623+     Hp, Hc = mpc.Hp, mpc.Hc
624+     ng = length(mpc.con.i_g)
625+     nc, neq = mpc.con.nc, mpc.con.neq
626+     nZ̃, nU, nŶ, nX̂, nK = length(mpc.Z̃), Hp*nu, Hp*ny, Hp*nx̂, Hp*nk
627+     nΔŨ, nUe, nŶe = nu*Hc + nϵ, nU + nu, nŶ + ny
628+     strict = Val(true)
629+     myNaN = convert(JNT, NaN)
630+     J::Vector{JNT} = zeros(JNT, 1)
631+     ΔŨ::Vector{JNT} = zeros(JNT, nΔŨ)
632+     x̂0end::Vector{JNT} = zeros(JNT, nx̂)
633+     K0::Vector{JNT} = zeros(JNT, nK)
634+     Ue::Vector{JNT}, Ŷe::Vector{JNT} = zeros(JNT, nUe), zeros(JNT, nŶe)
635+     U0::Vector{JNT}, Ŷ0::Vector{JNT} = zeros(JNT, nU), zeros(JNT, nŶ)
636+     Û0::Vector{JNT}, X̂0::Vector{JNT} = zeros(JNT, nU), zeros(JNT, nX̂)
637+     gc::Vector{JNT}, g::Vector{JNT} = zeros(JNT, nc), zeros(JNT, ng)
638+     geq::Vector{JNT} = zeros(JNT, neq)
639+     function J!(Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g, geq)
640+         update_predictions!(ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g, geq, mpc, Z̃)
641+         return obj_nonlinprog!(Ŷ0, U0, mpc, model, Ue, Ŷe, ΔŨ)
642+     end
643+     Z̃_J = fill(myNaN, nZ̃) # NaN to force update at first call
644+     J_context = (
645+         Cache(ΔŨ), Cache(x̂0end), Cache(Ue), Cache(Ŷe), Cache(U0), Cache(Ŷ0),
646+         Cache(Û0), Cache(K0), Cache(X̂0),
647+         Cache(gc), Cache(g), Cache(geq),
648+     )
649+     ∇J_prep = prepare_gradient(J!, grad, Z̃_J, J_context...; strict)
650+     ∇J = Vector{JNT}(undef, nZ̃)
651+     if !isnothing(hess)
652+         ∇²J_prep = prepare_hessian(J!, hess, Z̃_J, J_context...; strict)
653+         ∇²J = init_diffmat(JNT, hess, ∇²J_prep, nZ̃, nZ̃)
654+     end
655+     update_objective! = if !isnothing(hess)
656+         function (J, ∇J, ∇²J, Z̃_J, Z̃_arg)
657+             if isdifferent(Z̃_arg, Z̃_J)
658+                 Z̃_J .= Z̃_arg
659+                 J[], _ = value_gradient_and_hessian!(
660+                     J!, ∇J, ∇²J, ∇²J_prep, hess, Z̃_J, J_context...
661+                 )
662+             end
663+         end
664+     else
665+         update_objective! = function (J, ∇J, Z̃_∇J, Z̃_arg)
666+             if isdifferent(Z̃_arg, Z̃_∇J)
667+                 Z̃_∇J .= Z̃_arg
668+                 J[], _ = value_and_gradient!(
669+                     J!, ∇J, ∇J_prep, grad, Z̃_∇J, J_context...
670+                 )
671+             end
672+         end
673+     end
674+     J_func = if !isnothing(hess)
675+         function (Z̃_arg::Vararg{T, N}) where {N, T<:Real}
676+             update_objective!(J, ∇J, ∇²J, Z̃_J, Z̃_arg)
677+             return J[]::T
678+         end
679+     else
680+         function (Z̃_arg::Vararg{T, N}) where {N, T<:Real}
681+             update_objective!(J, ∇J, Z̃_J, Z̃_arg)
682+             return J[]::T
683+         end
684+     end
685+     ∇J_func! = if nZ̃ == 1 # univariate syntax (see JuMP.@operator doc):
686+         if !isnothing(hess)
687+             function (Z̃_arg)
688+                 update_objective!(J, ∇J, ∇²J, Z̃_J, Z̃_arg)
689+                 return ∇J[]
690+             end
691+         else
692+             function (Z̃_arg)
693+                 update_objective!(J, ∇J, Z̃_J, Z̃_arg)
694+                 return ∇J[]
695+             end
696+         end
697+     else # multivariate syntax (see JuMP.@operator doc):
698+         if !isnothing(hess)
699+             function (∇J_arg::AbstractVector{T}, Z̃_arg::Vararg{T, N}) where {N, T<:Real}
700+                 update_objective!(J, ∇J, ∇²J, Z̃_J, Z̃_arg)
701+                 return ∇J_arg .= ∇J
702+             end
703+         else
704+             function (∇J_arg::AbstractVector{T}, Z̃_arg::Vararg{T, N}) where {N, T<:Real}
705+                 update_objective!(J, ∇J, Z̃_J, Z̃_arg)
706+                 return ∇J_arg .= ∇J
707+             end
708+         end
709+     end
710+     ∇²J_func! = if nZ̃ == 1 # univariate syntax (see JuMP.@operator doc):
711+         function (Z̃_arg)
712+             update_objective!(J, ∇J, ∇²J, Z̃_J, Z̃_arg)
713+             return ∇²J[]
714+         end
715+     else # multivariate syntax (see JuMP.@operator doc):
716+         function (∇²J_arg::AbstractMatrix{T}, Z̃_arg::Vararg{T, N}) where {N, T<:Real}
717+             update_objective!(J, ∇J, ∇²J, Z̃_J, Z̃_arg)
718+             return fill_lowertriangle!(∇²J_arg, ∇²J)
719+         end
720+     end
721+     if !isnothing(hess)
722+         @operator(optim, J_op, nZ̃, J_func, ∇J_func!, ∇²J_func!)
723+     else
724+         @operator(optim, J_op, nZ̃, J_func, ∇J_func!)
725+     end
726+     return J_op
727+ end
728+
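A minimal, standalone sketch of the caching, splatting, and memoization pattern described in the docstring above, using a toy two-variable quadratic objective. The names `toy_memoized_operator`, `op_f`, `update!` and the cost itself are illustrative, not from this package; only `JuMP.@operator` and its multivariate gradient/hessian signatures are taken from the JuMP documentation:

using JuMP

function toy_memoized_operator(optim)
    n = 2
    xlast = fill(NaN, n)        # NaN forces an update at the first call (memoization key)
    f, ∇f = zeros(1), zeros(n)  # cached objective value and gradient
    function update!(xargs)
        if any(xargs .!= xlast) # recompute only when the decision vector changed
            xlast .= xargs
            f[]  = xlast[1]^2 + 2*xlast[2]^2
            ∇f  .= (2*xlast[1], 4*xlast[2])
        end
        return nothing
    end
    # splatted value function, specialized on the number of arguments N:
    f_func(x::Vararg{T,N}) where {N,T<:Real} = (update!(x); f[]::T)
    # multivariate gradient signature expected by JuMP.@operator when n > 1:
    function ∇f_func!(g::AbstractVector{T}, x::Vararg{T,N}) where {N,T<:Real}
        update!(x)
        return g .= ∇f
    end
    # multivariate hessian signature: only the lower-triangular part must be filled:
    function ∇²f_func!(H::AbstractMatrix{T}, x::Vararg{T,N}) where {N,T<:Real}
        update!(x)
        H[1, 1], H[2, 2] = 2, 4
        return H
    end
    return @operator(optim, op_f, n, f_func, ∇f_func!, ∇²f_func!)
end

optim = Model()
@variable(optim, x[1:2])
op_f = toy_memoized_operator(optim)
@objective(optim, Min, op_f(x...))

In the real method above, `update_objective!` plays the role of `update!`, the cached `Z̃_J` vector is the memoization key, and the univariate (`nZ̃ == 1`) forms return the scalars `∇J[]` and `∇²J[]` instead of filling arrays.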
629729"""
630730 get_nonlincon_oracle(mpc::NonLinMPC, optim) -> g_oracle, geq_oracle
631731
@@ -661,7 +761,7 @@ function get_nonlincon_oracle(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JN
661761     Û0::Vector{JNT}, X̂0::Vector{JNT} = zeros(JNT, nU), zeros(JNT, nX̂)
662762     gc::Vector{JNT}, g::Vector{JNT} = zeros(JNT, nc), zeros(JNT, ng)
663763     gi::Vector{JNT}, geq::Vector{JNT} = zeros(JNT, ngi), zeros(JNT, neq)
664-     λi::Vector{JNT}, λeq::Vector{JNT} = ones(JNT, ngi), ones(JNT, neq)
764+     λi::Vector{JNT}, λeq::Vector{JNT} = ones(JNT, ngi), ones(JNT, neq)
665765     # -------------- inequality constraint: nonlinear oracle -----------------------------
666766     function gi!(gi, Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, geq, g)
667767         update_predictions!(ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g, geq, mpc, Z̃)
@@ -725,7 +825,7 @@ function get_nonlincon_oracle(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JN
725825         jacobian_structure = ∇gi_structure,
726826         eval_jacobian = ∇gi_func!,
727827         hessian_lagrangian_structure = isnothing(hess) ? Tuple{Int,Int}[] : ∇²gi_structure,
728-         eval_hessian_lagrangian = isnothing(hess) ? nothing : ∇²gi_func!
828+         eval_hessian_lagrangian = isnothing(hess) ? nothing : ∇²gi_func!
729829     )
730830     # ------------- equality constraints: nonlinear oracle ------------------------------
731831     function geq!(geq, Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g)
@@ -792,130 +892,6 @@ function get_nonlincon_oracle(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JN
792892     return g_oracle, geq_oracle
793893 end
794894
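The unit-valued `λi` and `λeq` vectors above, together with the `hessian_lagrangian_structure` and `eval_hessian_lagrangian` fields, suggest that the constraint Hessians are prepared as Hessians of a multiplier-weighted sum of the constraints. A minimal DifferentiationInterface sketch of that idea, under this assumption and with purely illustrative names (`g_toy`, `lag`) and an arbitrary `AutoForwardDiff()` backend:

using DifferentiationInterface, LinearAlgebra
import ForwardDiff

g_toy(Z̃) = [Z̃[1]^2 + Z̃[2] - 1, Z̃[1] - Z̃[2]^2]  # toy constraint vector g(Z̃)
lag(Z̃, λ) = dot(λ, g_toy(Z̃))                   # multiplier-weighted sum λ⋅g(Z̃)

backend = AutoForwardDiff()
Z̃, λ = zeros(2), ones(2)                        # unit multipliers, as for λi/λeq above
prep = prepare_hessian(lag, backend, Z̃, Constant(λ))
∇²L = similar(Z̃, 2, 2)
hessian!(lag, ∇²L, prep, backend, Z̃, Constant(λ))  # ∇²(λ⋅g) evaluated at Z̃

In the oracles above, the solver presumably supplies the actual multipliers at evaluation time; the unit vectors appear to serve only for preparing the differentiation.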
795- """
796- get_nonlinobj_op(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) -> J_op
797-
798- Return the nonlinear operator for the objective function of `mpc` [`NonLinMPC`](@ref).
799-
800- It is based on the splatting syntax. This method is really intricate and that's because of:
801-
802- - These functions are used inside the nonlinear optimization, so they must be type-stable
803- and as efficient as possible. All the function outputs and derivatives are cached and
804- updated in-place if required to use the efficient [`value_and_gradient!`](@extref DifferentiationInterface DifferentiationInterface.value_and_jacobian!).
805- - The splatting syntax for objective functions implies the use of `Vararg{T,N}` (see the [performance tip](@extref Julia Be-aware-of-when-Julia-avoids-specializing))
806- and memoization to avoid redundant computations. This is already complex, but it's even
807- worse knowing that the automatic differentiation tools do not support splatting.
808- - The signature of gradient and hessian functions is not the same for univariate (`nZ̃ == 1`)
809- and multivariate (`nZ̃ > 1`) operators in `JuMP`. Both must be defined.
810- """
811- function get_nonlinobj_op(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where JNT<:Real
812-     model = mpc.estim.model
813-     transcription = mpc.transcription
814-     grad, hess = mpc.gradient, mpc.hessian
815-     nu, ny, nx̂, nϵ = model.nu, model.ny, mpc.estim.nx̂, mpc.nϵ
816-     nk = get_nk(model, transcription)
817-     Hp, Hc = mpc.Hp, mpc.Hc
818-     ng = length(mpc.con.i_g)
819-     nc, neq = mpc.con.nc, mpc.con.neq
820-     nZ̃, nU, nŶ, nX̂, nK = length(mpc.Z̃), Hp*nu, Hp*ny, Hp*nx̂, Hp*nk
821-     nΔŨ, nUe, nŶe = nu*Hc + nϵ, nU + nu, nŶ + ny
822-     strict = Val(true)
823-     myNaN = convert(JNT, NaN)
824-     J::Vector{JNT} = zeros(JNT, 1)
825-     ΔŨ::Vector{JNT} = zeros(JNT, nΔŨ)
826-     x̂0end::Vector{JNT} = zeros(JNT, nx̂)
827-     K0::Vector{JNT} = zeros(JNT, nK)
828-     Ue::Vector{JNT}, Ŷe::Vector{JNT} = zeros(JNT, nUe), zeros(JNT, nŶe)
829-     U0::Vector{JNT}, Ŷ0::Vector{JNT} = zeros(JNT, nU), zeros(JNT, nŶ)
830-     Û0::Vector{JNT}, X̂0::Vector{JNT} = zeros(JNT, nU), zeros(JNT, nX̂)
831-     gc::Vector{JNT}, g::Vector{JNT} = zeros(JNT, nc), zeros(JNT, ng)
832-     geq::Vector{JNT} = zeros(JNT, neq)
833-     function J!(Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g, geq)
834-         update_predictions!(ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g, geq, mpc, Z̃)
835-         return obj_nonlinprog!(Ŷ0, U0, mpc, model, Ue, Ŷe, ΔŨ)
836-     end
837-     Z̃_J = fill(myNaN, nZ̃) # NaN to force update at first call
838-     J_context = (
839-         Cache(ΔŨ), Cache(x̂0end), Cache(Ue), Cache(Ŷe), Cache(U0), Cache(Ŷ0),
840-         Cache(Û0), Cache(K0), Cache(X̂0),
841-         Cache(gc), Cache(g), Cache(geq),
842-     )
843-     ∇J_prep = prepare_gradient(J!, grad, Z̃_J, J_context...; strict)
844-     ∇J = Vector{JNT}(undef, nZ̃)
845-     if !isnothing(hess)
846-         ∇²J_prep = prepare_hessian(J!, hess, Z̃_J, J_context...; strict)
847-         ∇²J = init_diffmat(JNT, hess, ∇²J_prep, nZ̃, nZ̃)
848-     end
849-     update_objective! = if !isnothing(hess)
850-         function (J, ∇J, ∇²J, Z̃_J, Z̃_arg)
851-             if isdifferent(Z̃_arg, Z̃_J)
852-                 Z̃_J .= Z̃_arg
853-                 J[], _ = value_gradient_and_hessian!(J!, ∇J, ∇²J, hess, Z̃_J, J_context...)
854-             end
855-         end
856-     else
857-         update_objective! = function (J, ∇J, Z̃_∇J, Z̃_arg)
858-             if isdifferent(Z̃_arg, Z̃_∇J)
859-                 Z̃_∇J .= Z̃_arg
860-                 J[], _ = value_and_gradient!(J!, ∇J, ∇J_prep, grad, Z̃_∇J, J_context...)
861-             end
862-         end
863-     end
864-     J_func = if !isnothing(hess)
865-         function (Z̃_arg::Vararg{T, N}) where {N, T<:Real}
866-             update_objective!(J, ∇J, ∇²J, Z̃_J, Z̃_arg)
867-             return J[]::T
868-         end
869-     else
870-         function (Z̃_arg::Vararg{T, N}) where {N, T<:Real}
871-             update_objective!(J, ∇J, Z̃_J, Z̃_arg)
872-             return J[]::T
873-         end
874-     end
875-     ∇J_func! = if nZ̃ == 1 # univariate syntax (see JuMP.@operator doc):
876-         if !isnothing(hess)
877-             function (Z̃_arg)
878-                 update_objective!(J, ∇J, ∇²J, Z̃_J, Z̃_arg)
879-                 return ∇J[]
880-             end
881-         else
882-             function (Z̃_arg)
883-                 update_objective!(J, ∇J, Z̃_J, Z̃_arg)
884-                 return ∇J[]
885-             end
886-         end
887-     else # multivariate syntax (see JuMP.@operator doc):
888-         if !isnothing(hess)
889-             function (∇J_arg::AbstractVector{T}, Z̃_arg::Vararg{T, N}) where {N, T<:Real}
890-                 update_objective!(J, ∇J, ∇²J, Z̃_J, Z̃_arg)
891-                 return ∇J_arg .= ∇J
892-             end
893-         else
894-             function (∇J_arg::AbstractVector{T}, Z̃_arg::Vararg{T, N}) where {N, T<:Real}
895-                 update_objective!(J, ∇J, Z̃_J, Z̃_arg)
896-                 return ∇J_arg .= ∇J
897-             end
898-         end
899-     end
900-     ∇²J_func! = if nZ̃ == 1 # univariate syntax (see JuMP.@operator doc):
901-         function (Z̃_arg)
902-             update_objective!(J, ∇J, ∇²J, Z̃_J, Z̃_arg)
903-             return ∇²J[]
904-         end
905-     else # multivariate syntax (see JuMP.@operator doc):
906-         function (∇²J_arg::AbstractMatrix{T}, Z̃_arg::Vararg{T, N}) where {N, T<:Real}
907-             update_objective!(J, ∇J, ∇²J, Z̃_J, Z̃_arg)
908-             return fill_lowertriangle!(∇²J_arg, ∇²J)
909-         end
910-     end
911-     if !isnothing(hess)
912-         @operator(optim, J_op, nZ̃, J_func, ∇J_func!, ∇²J_func!)
913-     else
914-         @operator(optim, J_op, nZ̃, J_func, ∇J_func!)
915-     end
916-     return J_op
917- end
918-
919895"""
920896 update_predictions!(
921897 ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g, geq,