@@ -577,7 +577,8 @@ return vectors with the nonlinear equality constraint functions `geqfuncs` and g
577577This method is really intricate and I'm not proud of it. That's because of 3 elements:
578578
579579- These functions are used inside the nonlinear optimization, so they must be type-stable
580- and as efficient as possible.
580+ and as efficient as possible. All the function outputs and derivatives are cached and
581+ updated in-place if required to use the efficient [`value_and_jacobian!`](@extref DifferentiationInterface DifferentiationInterface.value_and_jacobian!).
581582- The `JuMP` NLP syntax forces splatting for the decision variable, which implies use
582583 of `Vararg{T,N}` (see the [performance tip][@extref Julia Be-aware-of-when-Julia-avoids-specializing]
583584 ) and memoization to avoid redundant computations. This is already complex, but it's even
@@ -588,16 +589,17 @@ This method is really intricate and I'm not proud of it. That's because of 3 ele
588589Inspired from: [User-defined operators with vector outputs](@extref JuMP User-defined-operators-with-vector-outputs)
589590"""
590591function get_optim_functions (mpc:: NonLinMPC , :: JuMP.GenericModel{JNT} ) where JNT<: Real
591- # ----- common cache for Jfunc, gfuncs, geqfuncs called with floats -------------------
592+ # ----------- common cache for Jfunc, gfuncs and geqfuncs --------- -------------------
592593 model = mpc. estim. model
594+ grad, jac = mpc. gradient, mpc. jacobian
593595 nu, ny, nx̂, nϵ, nk = model. nu, model. ny, mpc. estim. nx̂, mpc. nϵ, model. nk
594596 Hp, Hc = mpc. Hp, mpc. Hc
595597 ng, nc, neq = length (mpc. con. i_g), mpc. con. nc, mpc. con. neq
596598 nZ̃, nU, nŶ, nX̂, nK = length (mpc. Z̃), Hp* nu, Hp* ny, Hp* nx̂, Hp* nk
597599 nΔŨ, nUe, nŶe = nu* Hc + nϵ, nU + nu, nŶ + ny
598600 strict = Val (true )
599- myNaN = convert (JNT, NaN ) # NaN to force update_simulations! at first call:
600- Z̃ :: Vector{JNT} = fill (myNaN, nZ̃ )
601+ myNaN = convert (JNT, NaN )
602+ J :: Vector{JNT} = zeros (JNT, 1 )
601603 ΔŨ:: Vector{JNT} = zeros (JNT, nΔŨ)
602604 x̂0end:: Vector{JNT} = zeros (JNT, nx̂)
603605 K0:: Vector{JNT} = zeros (JNT, nK)
@@ -607,129 +609,119 @@ function get_optim_functions(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JNT
607609 gc:: Vector{JNT} , g:: Vector{JNT} = zeros (JNT, nc), zeros (JNT, ng)
608610 geq:: Vector{JNT} = zeros (JNT, neq)
609611 # ---------------------- objective function -------------------------------------------
610- function Jfunc (Z̃arg:: Vararg{T, N} ) where {N, T<: Real }
611- if isdifferent (Z̃arg, Z̃)
612- Z̃ .= Z̃arg
613- update_predictions! (ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g, geq, mpc, Z̃)
614- end
615- return obj_nonlinprog! (Ŷ0, U0, mpc, model, Ue, Ŷe, ΔŨ):: T
616- end
617612 function Jfunc! (Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g, geq)
618613 update_predictions! (ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g, geq, mpc, Z̃)
619614 return obj_nonlinprog! (Ŷ0, U0, mpc, model, Ue, Ŷe, ΔŨ)
620615 end
621- Z̃_∇J = fill (myNaN, nZ̃)
616+ Z̃_∇J = fill (myNaN, nZ̃) # NaN to force update_predictions! at first call
622617 ∇J_context = (
623618 Cache (ΔŨ), Cache (x̂0end), Cache (Ue), Cache (Ŷe), Cache (U0), Cache (Ŷ0),
624619 Cache (Û0), Cache (K0), Cache (X̂0),
625620 Cache (gc), Cache (g), Cache (geq),
626621 )
627- ∇J_prep = prepare_gradient (Jfunc!, mpc . gradient , Z̃_∇J, ∇J_context... ; strict)
622+ ∇J_prep = prepare_gradient (Jfunc!, grad , Z̃_∇J, ∇J_context... ; strict)
628623 ∇J = Vector {JNT} (undef, nZ̃)
629- ∇Jfunc! = if nZ̃ == 1
624+ function update_objective! (J, ∇J, Z̃, Z̃arg)
625+ if isdifferent (Z̃arg, Z̃)
626+ Z̃ .= Z̃arg
627+ J[], _ = value_and_gradient! (Jfunc!, ∇J, ∇J_prep, grad, Z̃_∇J, ∇J_context... )
628+ end
629+ end
630+ function Jfunc (Z̃arg:: Vararg{T, N} ) where {N, T<: Real }
631+ update_objective! (J, ∇J, Z̃_∇J, Z̃arg)
632+ return J[]:: T
633+ end
634+ ∇Jfunc! = if nZ̃ == 1 # univariate syntax (see JuMP.@operator doc):
630635 function (Z̃arg)
631- Z̃_∇J .= Z̃arg
632- gradient! (Jfunc!, ∇J, ∇J_prep, mpc. gradient, Z̃_∇J, ∇J_context... )
633- return ∇J[begin ] # univariate syntax, see JuMP.@operator doc
636+ update_objective! (J, ∇J, Z̃_∇J, Z̃arg)
637+ return ∇J[begin ]
634638 end
635- else
636- function (∇J:: AbstractVector{T} , Z̃arg:: Vararg{T, N} ) where {N, T<: Real }
637- Z̃_∇J .= Z̃arg
638- gradient! (Jfunc!, ∇J, ∇J_prep, mpc. gradient, Z̃_∇J, ∇J_context... )
639- return ∇J # multivariate syntax, see JuMP.@operator doc
639+ else # multivariate syntax (see JuMP.@operator doc):
640+ function (∇Jarg:: AbstractVector{T} , Z̃arg:: Vararg{T, N} ) where {N, T<: Real }
641+ update_objective! (J, ∇J, Z̃_∇J, Z̃arg)
642+ return ∇Jarg .= ∇J
640643 end
641644 end
642645 ∇²Jfunc! = nothing
643646 # --------------------- inequality constraint functions -------------------------------
644- gfuncs = Vector {Function} (undef, ng)
645- for i in eachindex (gfuncs)
646- gfunc_i = function (Z̃arg:: Vararg{T, N} ) where {N, T<: Real }
647- if isdifferent (Z̃arg, Z̃)
648- Z̃ .= Z̃arg
649- update_predictions! (
650- ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g, geq, mpc, Z̃
651- )
652- end
653- return g[i]:: T
654- end
655- gfuncs[i] = gfunc_i
656- end
657647 function gfunc! (g, Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, geq)
658- return update_predictions! (
659- ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g, geq, mpc, Z̃
660- )
648+ update_predictions! (ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g, geq, mpc, Z̃)
649+ return g
661650 end
662- Z̃_∇g = fill (myNaN, nZ̃)
651+ Z̃_∇g = fill (myNaN, nZ̃) # NaN to force update_predictions! at first call
663652 ∇g_context = (
664653 Cache (ΔŨ), Cache (x̂0end), Cache (Ue), Cache (Ŷe), Cache (U0), Cache (Ŷ0),
665- Cache (Û0), Cache (K0), Cache (X̂0),
654+ Cache (Û0), Cache (K0), Cache (X̂0),
666655 Cache (gc), Cache (geq),
667656 )
668657 # temporarily enable all the inequality constraints for sparsity detection:
669658 mpc. con. i_g[1 : end - nc] .= true
670- ∇g_prep = prepare_jacobian (gfunc!, g, mpc . jacobian , Z̃_∇g, ∇g_context... ; strict)
659+ ∇g_prep = prepare_jacobian (gfunc!, g, jac , Z̃_∇g, ∇g_context... ; strict)
671660 mpc. con. i_g[1 : end - nc] .= false
672- ∇g = init_diffmat (JNT, mpc. jacobian, ∇g_prep, nZ̃, ng)
661+ ∇g = init_diffmat (JNT, jac, ∇g_prep, nZ̃, ng)
662+ function update_con! (g, ∇g, Z̃, Z̃arg)
663+ if isdifferent (Z̃arg, Z̃)
664+ Z̃ .= Z̃arg
665+ value_and_jacobian! (gfunc!, g, ∇g, ∇g_prep, jac, Z̃, ∇g_context... )
666+ end
667+ end
668+ gfuncs = Vector {Function} (undef, ng)
669+ for i in eachindex (gfuncs)
670+ gfunc_i = function (Z̃arg:: Vararg{T, N} ) where {N, T<: Real }
671+ update_con! (g, ∇g, Z̃_∇g, Z̃arg)
672+ return g[i]:: T
673+ end
674+ gfuncs[i] = gfunc_i
675+ end
673676 ∇gfuncs! = Vector {Function} (undef, ng)
674677 for i in eachindex (∇gfuncs!)
675- ∇gfuncs_i! = if nZ̃ == 1
678+ ∇gfuncs_i! = if nZ̃ == 1 # univariate syntax (see JuMP.@operator doc):
676679 function (Z̃arg:: T ) where T<: Real
677- if isdifferent (Z̃arg, Z̃_∇g)
678- Z̃_∇g .= Z̃arg
679- jacobian! (gfunc!, g, ∇g, ∇g_prep, mpc. jacobian, Z̃_∇g, ∇g_context... )
680- end
681- return ∇g[i, begin ] # univariate syntax, see JuMP.@operator doc
680+ update_con! (g, ∇g, Z̃_∇g, Z̃arg)
681+ return ∇g[i, begin ]
682682 end
683- else
683+ else # multivariate syntax (see JuMP.@operator doc):
684684 function (∇g_i, Z̃arg:: Vararg{T, N} ) where {N, T<: Real }
685- if isdifferent (Z̃arg, Z̃_∇g)
686- Z̃_∇g .= Z̃arg
687- jacobian! (gfunc!, g, ∇g, ∇g_prep, mpc. jacobian, Z̃_∇g, ∇g_context... )
688- end
689- return ∇g_i .= @views ∇g[i, :] # multivariate syntax, see JuMP.@operator doc
685+ update_con! (g, ∇g, Z̃_∇g, Z̃arg)
686+ return ∇g_i .= @views ∇g[i, :]
690687 end
691688 end
692689 ∇gfuncs![i] = ∇gfuncs_i!
693690 end
694691 # --------------------- equality constraint functions ---------------------------------
695- geqfuncs = Vector {Function} (undef, neq)
696- for i in eachindex (geqfuncs)
697- geqfunc_i = function (Z̃arg:: Vararg{T, N} ) where {N, T<: Real }
698- if isdifferent (Z̃arg, Z̃)
699- Z̃ .= Z̃arg
700- update_predictions! (
701- ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g, geq, mpc, Z̃
702- )
703- end
704- return geq[i]:: T
705- end
706- geqfuncs[i] = geqfunc_i
707- end
708692 function geqfunc! (geq, Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g)
709- return update_predictions! (
710- ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g, geq, mpc, Z̃
711- )
693+ update_predictions! (ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g, geq, mpc, Z̃)
694+ return geq
712695 end
713- Z̃_∇geq = fill (myNaN, nZ̃)
696+ Z̃_∇geq = fill (myNaN, nZ̃) # NaN to force update_predictions! at first call
714697 ∇geq_context = (
715698 Cache (ΔŨ), Cache (x̂0end), Cache (Ue), Cache (Ŷe), Cache (U0), Cache (Ŷ0),
716699 Cache (Û0), Cache (K0), Cache (X̂0),
717700 Cache (gc), Cache (g)
718701 )
719- ∇geq_prep = prepare_jacobian (geqfunc!, geq, mpc. jacobian, Z̃_∇geq, ∇geq_context... ; strict)
720- ∇geq = init_diffmat (JNT, mpc. jacobian, ∇geq_prep, nZ̃, neq)
702+ ∇geq_prep = prepare_jacobian (geqfunc!, geq, jac, Z̃_∇geq, ∇geq_context... ; strict)
703+ ∇geq = init_diffmat (JNT, jac, ∇geq_prep, nZ̃, neq)
704+ function update_con_eq! (geq, ∇geq, Z̃, Z̃arg)
705+ if isdifferent (Z̃arg, Z̃)
706+ Z̃ .= Z̃arg
707+ value_and_jacobian! (geqfunc!, geq, ∇geq, ∇geq_prep, jac, Z̃, ∇geq_context... )
708+ end
709+ end
710+ geqfuncs = Vector {Function} (undef, neq)
711+ for i in eachindex (geqfuncs)
712+ geqfunc_i = function (Z̃arg:: Vararg{T, N} ) where {N, T<: Real }
713+ update_con_eq! (geq, ∇geq, Z̃_∇geq, Z̃arg)
714+ return geq[i]:: T
715+ end
716+ geqfuncs[i] = geqfunc_i
717+ end
721718 ∇geqfuncs! = Vector {Function} (undef, neq)
722719 for i in eachindex (∇geqfuncs!)
723720 # only multivariate syntax, univariate is impossible since nonlinear equality
724721 # constraints imply MultipleShooting, thus input increment ΔU and state X̂0 in Z̃:
725722 ∇geqfuncs_i! =
726723 function (∇geq_i, Z̃arg:: Vararg{T, N} ) where {N, T<: Real }
727- if isdifferent (Z̃arg, Z̃_∇geq)
728- Z̃_∇geq .= Z̃arg
729- jacobian! (
730- geqfunc!, geq, ∇geq, ∇geq_prep, mpc. jacobian, Z̃_∇geq, ∇geq_context...
731- )
732- end
724+ update_con_eq! (geq, ∇geq, Z̃_∇geq, Z̃arg)
733725 return ∇geq_i .= @views ∇geq[i, :]
734726 end
735727 ∇geqfuncs![i] = ∇geqfuncs_i!
0 commit comments