@@ -571,6 +571,7 @@ function get_optim_functions(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JNT
571571 end
572572 # --------------------- cache for the AD functions -----------------------------------
573573 Z̃arg_vec = Vector {JNT} (undef, nZ̃)
574+ ∇J = Vector {JNT} (undef, nZ̃) # gradient of J
574575 g_vec = Vector {JNT} (undef, ng)
575576 ∇g = Matrix {JNT} (undef, ng, nZ̃) # Jacobian of g
576577 geq_vec = Vector {JNT} (undef, neq)
@@ -591,10 +592,19 @@ function get_optim_functions(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JNT
591592 U0, Ŷ0 = get_tmp (U0_cache, T), get_tmp (Ŷ0_cache, T)
592593 return obj_nonlinprog! (Ŷ0, U0, mpc, model, Ue, Ŷe, ΔŨ)
593594 end
594- ∇J_buffer = GradientBuffer (Jfunc_vec, Z̃arg_vec)
595- function ∇Jfunc! (∇J, Z̃arg:: Vararg{T, N} ) where {N, T<: Real }
596- Z̃arg_vec .= Z̃arg
597- return gradient! (∇J, ∇J_buffer, Z̃arg_vec)
595+ ∇J_buffer = GradientBuffer (Jfunc_vec, Z̃arg_vec)
596+ ∇Jfunc! = if nZ̃ == 1
597+ function (Z̃arg:: T ) where T<: Real
598+ Z̃arg_vec .= Z̃arg
599+ gradient! (∇J, ∇J_buffer, Z̃arg_vec)
600+ return ∇J[begin ] # univariate syntax, see JuMP.@operator doc
601+ end
602+ else
603+ function (∇J:: AbstractVector{T} , Z̃arg:: Vararg{T, N} ) where {N, T<: Real }
604+ Z̃arg_vec .= Z̃arg
605+ gradient! (∇J, ∇J_buffer, Z̃arg_vec)
606+ return ∇J # multivariate syntax, see JuMP.@operator doc
607+ end
598608 end
599609 # --------------------- inequality constraint functions -------------------------------
600610 gfuncs = Vector {Function} (undef, ng)
@@ -614,11 +624,19 @@ function get_optim_functions(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JNT
614624 ∇g_buffer = JacobianBuffer (gfunc_vec!, g_vec, Z̃arg_vec)
615625 ∇gfuncs! = Vector {Function} (undef, ng)
616626 for i in eachindex (∇gfuncs!)
617- ∇gfuncs![i] = function (∇g_i, Z̃arg:: Vararg{T, N} ) where {N, T<: Real }
618- Z̃arg_vec .= Z̃arg
619- jacobian! (∇g, ∇g_buffer, g_vec, Z̃arg_vec)
620- ∇g_i .= @views ∇g[i, :]
621- return ∇g_i
627+ ∇gfuncs![i] = if nZ̃ == 1
628+ function (Z̃arg:: T ) where T<: Real
629+ Z̃arg_vec .= Z̃arg
630+ jacobian! (∇g, ∇g_buffer, g_vec, Z̃arg_vec)
631+ return ∇g[i, begin ] # univariate syntax, see JuMP.@operator doc
632+ end
633+ else
634+ function (∇g_i, Z̃arg:: Vararg{T, N} ) where {N, T<: Real }
635+ Z̃arg_vec .= Z̃arg
636+ jacobian! (∇g, ∇g_buffer, g_vec, Z̃arg_vec)
637+ ∇g_i .= @views ∇g[i, :]
638+ return ∇g_i # multivariate syntax, see JuMP.@operator doc
639+ end
622640 end
623641 end
624642 # --------------------- equality constraint functions ---------------------------------
@@ -639,11 +657,19 @@ function get_optim_functions(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JNT
639657 ∇geq_buffer = JacobianBuffer (geqfunc_vec!, geq_vec, Z̃arg_vec)
640658 ∇geqfuncs! = Vector {Function} (undef, neq)
641659 for i in eachindex (∇geqfuncs!)
642- ∇geqfuncs![i] = function (∇geq_i, Z̃arg:: Vararg{T, N} ) where {N, T<: Real }
643- Z̃arg_vec .= Z̃arg
644- jacobian! (∇geq, ∇geq_buffer, geq_vec, Z̃arg_vec)
645- ∇geq_i .= @views ∇geq[i, :]
646- return ∇geq_i
660+ ∇geqfuncs![i] = if nZ̃ == 1
661+ function (Z̃arg:: T ) where T<: Real
662+ Z̃arg_vec .= Z̃arg
663+ jacobian! (∇geq, ∇geq_buffer, geq_vec, Z̃arg_vec)
664+ return ∇geq[i, begin ] # univariate syntax, see JuMP.@operator doc
665+ end
666+ else
667+ function (∇geq_i, Z̃arg:: Vararg{T, N} ) where {N, T<: Real }
668+ Z̃arg_vec .= Z̃arg
669+ jacobian! (∇geq, ∇geq_buffer, geq_vec, Z̃arg_vec)
670+ ∇geq_i .= @views ∇geq[i, :]
671+ return ∇geq_i # multivariate syntax, see JuMP.@operator doc
672+ end
647673 end
648674 end
649675 return Jfunc, ∇Jfunc!, gfuncs, ∇gfuncs!, geqfuncs, ∇geqfuncs!
0 commit comments