Skip to content

Commit 9c1c4c0

Browse files
committed
debug: handle univariate syntax special case
`JuMP.operator` expects a different signature for the gradient when the function is univariate.
1 parent 3d4f917 commit 9c1c4c0

File tree

1 file changed

+40
-14
lines changed

1 file changed

+40
-14
lines changed

src/controller/nonlinmpc.jl

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -571,6 +571,7 @@ function get_optim_functions(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JNT
571571
end
572572
# --------------------- cache for the AD functions -----------------------------------
573573
Z̃arg_vec = Vector{JNT}(undef, nZ̃)
574+
∇J = Vector{JNT}(undef, nZ̃) # gradient of J
574575
g_vec = Vector{JNT}(undef, ng)
575576
∇g = Matrix{JNT}(undef, ng, nZ̃) # Jacobian of g
576577
geq_vec = Vector{JNT}(undef, neq)
@@ -591,10 +592,19 @@ function get_optim_functions(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JNT
591592
U0, Ŷ0 = get_tmp(U0_cache, T), get_tmp(Ŷ0_cache, T)
592593
return obj_nonlinprog!(Ŷ0, U0, mpc, model, Ue, Ŷe, ΔŨ)
593594
end
594-
∇J_buffer = GradientBuffer(Jfunc_vec, Z̃arg_vec)
595-
function ∇Jfunc!(∇J, Z̃arg::Vararg{T, N}) where {N, T<:Real}
596-
Z̃arg_vec .= Z̃arg
597-
return gradient!(∇J, ∇J_buffer, Z̃arg_vec)
595+
∇J_buffer = GradientBuffer(Jfunc_vec, Z̃arg_vec)
596+
∇Jfunc! = if nZ̃ == 1
597+
function (Z̃arg::T) where T<:Real
598+
Z̃arg_vec .= Z̃arg
599+
gradient!(∇J, ∇J_buffer, Z̃arg_vec)
600+
return ∇J[begin] # univariate syntax, see JuMP.@operator doc
601+
end
602+
else
603+
function (∇J::AbstractVector{T}, Z̃arg::Vararg{T, N}) where {N, T<:Real}
604+
Z̃arg_vec .= Z̃arg
605+
gradient!(∇J, ∇J_buffer, Z̃arg_vec)
606+
return ∇J # multivariate syntax, see JuMP.@operator doc
607+
end
598608
end
599609
# --------------------- inequality constraint functions -------------------------------
600610
gfuncs = Vector{Function}(undef, ng)
@@ -614,11 +624,19 @@ function get_optim_functions(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JNT
614624
∇g_buffer = JacobianBuffer(gfunc_vec!, g_vec, Z̃arg_vec)
615625
∇gfuncs! = Vector{Function}(undef, ng)
616626
for i in eachindex(∇gfuncs!)
617-
∇gfuncs![i] = function (∇g_i, Z̃arg::Vararg{T, N}) where {N, T<:Real}
618-
Z̃arg_vec .= Z̃arg
619-
jacobian!(∇g, ∇g_buffer, g_vec, Z̃arg_vec)
620-
∇g_i .= @views ∇g[i, :]
621-
return ∇g_i
627+
∇gfuncs![i] = if nZ̃ == 1
628+
function (Z̃arg::T) where T<:Real
629+
Z̃arg_vec .= Z̃arg
630+
jacobian!(∇g, ∇g_buffer, g_vec, Z̃arg_vec)
631+
return ∇g[i, begin] # univariate syntax, see JuMP.@operator doc
632+
end
633+
else
634+
function (∇g_i, Z̃arg::Vararg{T, N}) where {N, T<:Real}
635+
Z̃arg_vec .= Z̃arg
636+
jacobian!(∇g, ∇g_buffer, g_vec, Z̃arg_vec)
637+
∇g_i .= @views ∇g[i, :]
638+
return ∇g_i # multivariate syntax, see JuMP.@operator doc
639+
end
622640
end
623641
end
624642
# --------------------- equality constraint functions ---------------------------------
@@ -639,11 +657,19 @@ function get_optim_functions(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JNT
639657
∇geq_buffer = JacobianBuffer(geqfunc_vec!, geq_vec, Z̃arg_vec)
640658
∇geqfuncs! = Vector{Function}(undef, neq)
641659
for i in eachindex(∇geqfuncs!)
642-
∇geqfuncs![i] = function (∇geq_i, Z̃arg::Vararg{T, N}) where {N, T<:Real}
643-
Z̃arg_vec .= Z̃arg
644-
jacobian!(∇geq, ∇geq_buffer, geq_vec, Z̃arg_vec)
645-
∇geq_i .= @views ∇geq[i, :]
646-
return ∇geq_i
660+
∇geqfuncs![i] = if nZ̃ == 1
661+
function (Z̃arg::T) where T<:Real
662+
Z̃arg_vec .= Z̃arg
663+
jacobian!(∇geq, ∇geq_buffer, geq_vec, Z̃arg_vec)
664+
return ∇geq[i, begin] # univariate syntax, see JuMP.@operator doc
665+
end
666+
else
667+
function (∇geq_i, Z̃arg::Vararg{T, N}) where {N, T<:Real}
668+
Z̃arg_vec .= Z̃arg
669+
jacobian!(∇geq, ∇geq_buffer, geq_vec, Z̃arg_vec)
670+
∇geq_i .= @views ∇geq[i, :]
671+
return ∇geq_i # multivariate syntax, see JuMP.@operator doc
672+
end
647673
end
648674
end
649675
return Jfunc, ∇Jfunc!, gfuncs, ∇gfuncs!, geqfuncs, ∇geqfuncs!

0 commit comments

Comments
 (0)