@@ -554,29 +554,27 @@ function init_optimization!(mpc::NonLinMPC, model::SimModel, optim::JuMP.Generic
554554 end
555555 end
556556 validate_backends (mpc. gradient, mpc. hessian)
557- Jfunc, ∇Jfunc!, ∇²Jfunc!, gfuncs, ∇gfuncs!, geqfuncs, ∇geqfuncs! = get_optim_functions (
558- mpc, optim
559- )
560- Jargs = isnothing (∇²Jfunc!) ? (Jfunc, ∇Jfunc!) : (Jfunc, ∇Jfunc!, ∇²Jfunc!)
561- @operator (optim, J, nZ̃, Jargs... )
557+ J_args, g_vec_args, geq_vec_args = get_optim_functions (mpc, optim)
558+ # display(J_args)
559+ @operator (optim, J, nZ̃, J_args... )
562560 @objective (optim, Min, J (Z̃var... ))
563- init_nonlincon! (mpc, model, transcription, gfuncs, ∇gfuncs!, geqfuncs, ∇geqfuncs! )
561+ init_nonlincon! (mpc, model, transcription, g_vec_args, geq_vec_args )
564562 set_nonlincon! (mpc, model, transcription, optim)
565563 return nothing
566564end
567565
568566"""
569567 get_optim_functions(
570568 mpc::NonLinMPC, optim::JuMP.GenericModel
571- ) -> Jfunc, ∇Jfunc!, ∇J²Jfunc!, gfuncs, ∇gfuncs!, geqfuncs, ∇geqfuncs!
569+ ) -> J_args, g_vec_args, geq_vec_args
572570
573571Return the functions for the nonlinear optimization of `mpc` [`NonLinMPC`](@ref) controller.
574-
575- Return the nonlinear objective `Jfunc` function, and `∇Jfunc!` and `∇²Jfunc!`, to compute
576- its gradient and hessian, respectively . Also return vectors with the nonlinear inequality
577- constraint functions `gfuncs`, and `∇gfuncs!`, for the associated gradients. Lastly, also
578- return vectors with the nonlinear equality constraint functions `geqfuncs` and gradients
579- `∇geqfuncs!` .
572+
573+ Return the tuple `J_args` containing the functions to compute the objective function
574+ value and its derivatives . Also return the tuple `g_vec_args` containing 2 vectors of
575+ functions to compute the nonlinear inequality values and associated gradients. Lastly, also
576+ return `geq_vec_args` containing 2 vectors of functions to compute the nonlinear equality
577+ values and associated gradients .
580578
581579This method is really intricate and I'm not proud of it. That's because of 3 elements:
582580
@@ -630,35 +628,53 @@ function get_optim_functions(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JNT
630628 end
631629 if ! isnothing (hess)
632630 prep_∇²J = prepare_hessian (Jfunc!, hess, Z̃_J, context_J... ; strict)
633- @warn " Here's the objective Hessian sparsity pattern:"
634631 display (sparsity_pattern (prep_∇²J))
635632 else
636633 prep_∇²J = nothing
637634 end
638635 ∇J = Vector {JNT} (undef, nZ̃)
639636 ∇²J = init_diffmat (JNT, hess, prep_∇²J, nZ̃, nZ̃)
640637 function Jfunc (Z̃arg:: Vararg{T, N} ) where {N, T<: Real }
641- update_diff_objective ! (
638+ update_memoized_diff ! (
642639 Z̃_J, J, ∇J, ∇²J, prep_∇J, prep_∇²J, context_J, grad, hess, Jfunc!, Z̃arg
643640 )
644641 return J[]:: T
645642 end
646- ∇Jfunc! = if nZ̃ == 1 # univariate syntax (see JuMP.@operator doc):
643+ ∇Jfunc! = if nZ̃ == 1 # univariate syntax (see JuMP.@operator doc):
647644 function (Z̃arg)
648- update_diff_objective ! (
645+ update_memoized_diff ! (
649646 Z̃_J, J, ∇J, ∇²J, prep_∇J, prep_∇²J, context_J, grad, hess, Jfunc!, Z̃arg
650647 )
651648 return ∇J[begin ]
652649 end
653- else # multivariate syntax (see JuMP.@operator doc):
650+ else # multivariate syntax (see JuMP.@operator doc):
654651 function (∇Jarg:: AbstractVector{T} , Z̃arg:: Vararg{T, N} ) where {N, T<: Real }
655- update_diff_objective ! (
652+ update_memoized_diff ! (
656653 Z̃_J, J, ∇J, ∇²J, prep_∇J, prep_∇²J, context_J, grad, hess, Jfunc!, Z̃arg
657654 )
658655 return ∇Jarg .= ∇J
659656 end
660657 end
661- ∇²Jfunc! = nothing
658+ ∇²Jfunc! = if nZ̃ == 1 # univariate syntax (see JuMP.@operator doc):
659+ function (Z̃arg)
660+ update_memoized_diff! (
661+ Z̃_J, J, ∇J, ∇²J, prep_∇J, prep_∇²J, context_J, grad, hess, Jfunc!, Z̃arg
662+ )
663+ return ∇²J[begin , begin ]
664+ end
665+ else # multivariate syntax (see JuMP.@operator doc):
666+ function (∇²Jarg:: AbstractMatrix{T} , Z̃arg:: Vararg{T, N} ) where {N, T<: Real }
667+ print (" d" )
668+ update_memoized_diff! (
669+ Z̃_J, J, ∇J, ∇²J, prep_∇J, prep_∇²J, context_J, grad, hess, Jfunc!, Z̃arg
670+ )
671+ for i in 1 : N, j in 1 : i
672+ ∇²Jarg[i, j] = ∇²J[i, j]
673+ end
674+ return ∇²Jarg
675+ end
676+ end
677+ J_args = isnothing (hess) ? (Jfunc, ∇Jfunc!) : (Jfunc, ∇Jfunc!, ∇²Jfunc!)
662678 # --------------------- inequality constraint functions -------------------------------
663679 function gfunc! (g, Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, geq)
664680 update_predictions! (ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g, geq, mpc, Z̃)
@@ -672,19 +688,13 @@ function get_optim_functions(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JNT
672688 )
673689 # temporarily enable all the inequality constraints for sparsity detection:
674690 mpc. con. i_g[1 : end - nc] .= true
675- ∇g_prep = prepare_jacobian (gfunc!, g, jac, Z̃_g, context_g... ; strict)
691+ prep_∇g = prepare_jacobian (gfunc!, g, jac, Z̃_g, context_g... ; strict)
676692 mpc. con. i_g[1 : end - nc] .= false
677- ∇g = init_diffmat (JNT, jac, ∇g_prep, nZ̃, ng)
678- function update_con! (g, ∇g, Z̃, Z̃arg)
679- if isdifferent (Z̃arg, Z̃)
680- Z̃ .= Z̃arg
681- value_and_jacobian! (gfunc!, g, ∇g, ∇g_prep, jac, Z̃, context_g... )
682- end
683- end
693+ ∇g = init_diffmat (JNT, jac, prep_∇g, nZ̃, ng)
684694 gfuncs = Vector {Function} (undef, ng)
685695 for i in eachindex (gfuncs)
686696 gfunc_i = function (Z̃arg:: Vararg{T, N} ) where {N, T<: Real }
687- update_con! ( g, ∇g, Z̃_g , Z̃arg)
697+ update_memoized_diff! (Z̃_g, g, ∇g, prep_∇g, context_g, jac, gfunc! , Z̃arg)
688698 return g[i]:: T
689699 end
690700 gfuncs[i] = gfunc_i
@@ -693,17 +703,18 @@ function get_optim_functions(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JNT
693703 for i in eachindex (∇gfuncs!)
694704 ∇gfuncs_i! = if nZ̃ == 1 # univariate syntax (see JuMP.@operator doc):
695705 function (Z̃arg:: T ) where T<: Real
696- update_con! ( g, ∇g, Z̃_g , Z̃arg)
706+ update_memoized_diff! (Z̃_g, g, ∇g, prep_∇g, context_g, jac, gfunc! , Z̃arg)
697707 return ∇g[i, begin ]
698708 end
699- else # multivariate syntax (see JuMP.@operator doc):
709+ else # multivariate syntax (see JuMP.@operator doc):
700710 function (∇g_i, Z̃arg:: Vararg{T, N} ) where {N, T<: Real }
701- update_con! ( g, ∇g, Z̃_g , Z̃arg)
711+ update_memoized_diff! (Z̃_g, g, ∇g, prep_∇g, context_g, jac, gfunc! , Z̃arg)
702712 return ∇g_i .= @views ∇g[i, :]
703713 end
704714 end
705715 ∇gfuncs![i] = ∇gfuncs_i!
706716 end
717+ g_vec_args = (gfuncs, ∇gfuncs!)
707718 # --------------------- equality constraint functions ---------------------------------
708719 function geqfunc! (geq, Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g)
709720 update_predictions! (ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g, geq, mpc, Z̃)
@@ -715,18 +726,14 @@ function get_optim_functions(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JNT
715726 Cache (Û0), Cache (K0), Cache (X̂0),
716727 Cache (gc), Cache (g)
717728 )
718- ∇geq_prep = prepare_jacobian (geqfunc!, geq, jac, Z̃_geq, context_geq... ; strict)
719- ∇geq = init_diffmat (JNT, jac, ∇geq_prep, nZ̃, neq)
720- function update_con_eq! (geq, ∇geq, Z̃, Z̃arg)
721- if isdifferent (Z̃arg, Z̃)
722- Z̃ .= Z̃arg
723- value_and_jacobian! (geqfunc!, geq, ∇geq, ∇geq_prep, jac, Z̃, context_geq... )
724- end
725- end
729+ prep_∇geq = prepare_jacobian (geqfunc!, geq, jac, Z̃_geq, context_geq... ; strict)
730+ ∇geq = init_diffmat (JNT, jac, prep_∇geq, nZ̃, neq)
726731 geqfuncs = Vector {Function} (undef, neq)
727732 for i in eachindex (geqfuncs)
728733 geqfunc_i = function (Z̃arg:: Vararg{T, N} ) where {N, T<: Real }
729- update_con_eq! (geq, ∇geq, Z̃_geq, Z̃arg)
734+ update_memoized_diff! (
735+ Z̃_geq, geq, ∇geq, prep_∇geq, context_geq, jac, geqfunc!, Z̃arg
736+ )
730737 return geq[i]:: T
731738 end
732739 geqfuncs[i] = geqfunc_i
@@ -737,12 +744,15 @@ function get_optim_functions(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JNT
737744 # constraints imply MultipleShooting, thus input increment ΔU and state X̂0 in Z̃:
738745 ∇geqfuncs_i! =
739746 function (∇geq_i, Z̃arg:: Vararg{T, N} ) where {N, T<: Real }
740- update_con_eq! (geq, ∇geq, Z̃_geq, Z̃arg)
747+ update_memoized_diff! (
748+ Z̃_geq, geq, ∇geq, prep_∇geq, context_geq, jac, geqfunc!, Z̃arg
749+ )
741750 return ∇geq_i .= @views ∇geq[i, :]
742751 end
743752 ∇geqfuncs![i] = ∇geqfuncs_i!
744753 end
745- return Jfunc, ∇Jfunc!, ∇²Jfunc!, gfuncs, ∇gfuncs!, geqfuncs, ∇geqfuncs!
754+ geq_vec_args = (geqfuncs, ∇geqfuncs!)
755+ return J_args, g_vec_args, geq_vec_args
746756end
747757
748758"""
@@ -770,52 +780,6 @@ function update_predictions!(
770780 return nothing
771781end
772782
773- """
774- update_diff_objective!(
775- Z̃_J, J, ∇J, ∇²J, prep_∇J, prep_∇²J , context_J,
776- grad::AbstractADType, hess::Nothing, Jfunc!, Z̃arg
777- )
778-
779- TBW
780- """
781- function update_diff_objective! (
782- Z̃_J, J, ∇J, ∇²J, prep_∇J, _ , context_J,
783- grad:: AbstractADType , hess:: Nothing , Jfunc!:: F , Z̃arg
784- ) where F <: Function
785- if isdifferent (Z̃arg, Z̃_J)
786- Z̃_J .= Z̃arg
787- J[], _ = value_and_gradient! (Jfunc!, ∇J, prep_∇J, grad, Z̃_J, context... )
788- end
789- return nothing
790- end
791-
792- function update_diff_objective! (
793- Z̃_J, J, ∇J, ∇²J, _ , prep_∇²J, context_J,
794- grad:: Nothing , hess:: AbstractADType , Jfunc!:: F , Z̃arg
795- ) where F <: Function
796- if isdifferent (Z̃arg, Z̃_J)
797- Z̃_J .= Z̃arg
798- J[], _ = value_gradient_and_hessian! (
799- Jfunc!, ∇J, ∇²J, prep_∇²J, hess, Z̃_J, context_J...
800- )
801- @warn " Uncomment the following line to print the current Hessian"
802- # println(∇²J)
803- end
804- return nothing
805- end
806-
807- function update_diff_objective! (
808- Z̃_J, J, ∇J, ∇²J, prep_∇J, prep_∇²J, context_J,
809- grad:: AbstractADType , hess:: AbstractADType , Jfunc!:: F , Z̃arg
810- ) where F<: Function
811- if isdifferent (Z̃arg, Z̃_J)
812- Z̃_J .= Z̃arg # inefficient, as warned by validate_backends(), but still possible:
813- hessian! (Jfunc!, ∇²J, prep_∇²J, hess, Z̃_J, context_J... )
814- J[], _ = value_and_gradient! (Jfunc!, ∇J, prep_∇J, grad, Z̃_J, context_J... )
815- end
816- return nothing
817- end
818-
819783@doc raw """
820784 con_custom!(gc, mpc::NonLinMPC, Ue, Ŷe, ϵ) -> gc
821785
0 commit comments