@@ -630,45 +630,31 @@ function get_optim_functions(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JNT
630630 end
631631 if ! isnothing (hess)
632632 prep_∇²J = prepare_hessian (Jfunc!, hess, Z̃_J, context_J... ; strict)
633+ @warn " Here's the objective Hessian sparsity pattern:"
633634 display (sparsity_pattern (prep_∇²J))
634635 else
635636 prep_∇²J = nothing
636637 end
637638 ∇J = Vector {JNT} (undef, nZ̃)
638639 ∇²J = init_diffmat (JNT, hess, prep_∇²J, nZ̃, nZ̃)
639-
640-
641-
642- function update_objective! (J, ∇J, Z̃, Z̃arg, hess:: Nothing , grad:: AbstractADType )
643- if isdifferent (Z̃arg, Z̃)
644- Z̃ .= Z̃arg
645- J[], _ = value_and_gradient! (Jfunc!, ∇J, prep_∇J, grad, Z̃_J, context_J... )
646- end
647- end
648- function update_objective! (J, ∇J, Z̃, Z̃arg, hess:: AbstractADType , grad:: Nothing )
649- if isdifferent (Z̃arg, Z̃)
650- Z̃ .= Z̃arg
651- J[], _ = value_gradient_and_hessian! (
652- Jfunc!, ∇J, ∇²J, prep_∇²J, hess, Z̃, context_J...
653- )
654- # display(∇J)
655- # display(∇²J)
656- # println(∇²J)
657- end
658- end
659-
660640 function Jfunc (Z̃arg:: Vararg{T, N} ) where {N, T<: Real }
661- update_objective! (J, ∇J, Z̃_J, Z̃arg, hess, grad)
641+ update_diff_objective! (
642+ Z̃_J, J, ∇J, ∇²J, prep_∇J, prep_∇²J, context_J, grad, hess, Jfunc!, Z̃arg
643+ )
662644 return J[]:: T
663645 end
664646 ∇Jfunc! = if nZ̃ == 1 # univariate syntax (see JuMP.@operator doc):
665647 function (Z̃arg)
666- update_objective! (J, ∇J, Z̃_J, Z̃arg, hess, grad)
648+ update_diff_objective! (
649+ Z̃_J, J, ∇J, ∇²J, prep_∇J, prep_∇²J, context_J, grad, hess, Jfunc!, Z̃arg
650+ )
667651 return ∇J[begin ]
668652 end
669653 else # multivariate syntax (see JuMP.@operator doc):
670654 function (∇Jarg:: AbstractVector{T} , Z̃arg:: Vararg{T, N} ) where {N, T<: Real }
671- update_objective! (J, ∇J, Z̃_J, Z̃arg, hess, grad)
655+ update_diff_objective! (
656+ Z̃_J, J, ∇J, ∇²J, prep_∇J, prep_∇²J, context_J, grad, hess, Jfunc!, Z̃arg
657+ )
672658 return ∇Jarg .= ∇J
673659 end
674660 end
@@ -784,6 +770,52 @@ function update_predictions!(
784770 return nothing
785771end
786772
773+ """
774+ update_diff_objective!(
775+ Z̃_J, J, ∇J, ∇²J, prep_∇J, prep_∇²J , context_J,
776+ grad::AbstractADType, hess::Nothing, Jfunc!, Z̃arg
777+ )
778+
779+ TBW
780+ """
781+ function update_diff_objective! (
782+ Z̃_J, J, ∇J, ∇²J, prep_∇J, _ , context_J,
783+ grad:: AbstractADType , hess:: Nothing , Jfunc!:: F , Z̃arg
784+ ) where F <: Function
785+ if isdifferent (Z̃arg, Z̃_J)
786+ Z̃_J .= Z̃arg
787+ J[], _ = value_and_gradient! (Jfunc!, ∇J, prep_∇J, grad, Z̃_J, context... )
788+ end
789+ return nothing
790+ end
791+
792+ function update_diff_objective! (
793+ Z̃_J, J, ∇J, ∇²J, _ , prep_∇²J, context_J,
794+ grad:: Nothing , hess:: AbstractADType , Jfunc!:: F , Z̃arg
795+ ) where F <: Function
796+ if isdifferent (Z̃arg, Z̃_J)
797+ Z̃_J .= Z̃arg
798+ J[], _ = value_gradient_and_hessian! (
799+ Jfunc!, ∇J, ∇²J, prep_∇²J, hess, Z̃_J, context_J...
800+ )
801+ @warn " Here's the current Hessian:"
802+ println (∇²J)
803+ end
804+ return nothing
805+ end
806+
807+ function update_diff_objective! (
808+ Z̃_J, J, ∇J, ∇²J, prep_∇J, prep_∇²J, context_J,
809+ grad:: AbstractADType , hess:: AbstractADType , Jfunc!:: F , Z̃arg
810+ ) where F<: Function
811+ if isdifferent (Z̃arg, Z̃_J)
812+ Z̃_J .= Z̃arg # inefficient, as warned by validate_backends(), but still possible:
813+ hessian! (Jfunc!, ∇²J, prep_∇²J, hess, Z̃_J, context_J... )
814+ J[], _ = value_and_gradient! (Jfunc!, ∇J, prep_∇J, grad, Z̃_J, context_J... )
815+ end
816+ return nothing
817+ end
818+
787819@doc raw """
788820 con_custom!(gc, mpc::NonLinMPC, Ue, Ŷe, ϵ) -> gc
789821
0 commit comments