Skip to content

Commit 9f4c357

Browse files
committed
added: NonLinMPC refactor for custom objective gradient
1 parent e19783e commit 9f4c357

File tree

2 files changed

+104
-50
lines changed

2 files changed

+104
-50
lines changed

src/controller/nonlinmpc.jl

Lines changed: 101 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -505,19 +505,19 @@ function init_optimization!(mpc::NonLinMPC, model::SimModel, optim)
505505
JuMP.set_attribute(optim, "nlp_scaling_max_gradient", 10.0/C)
506506
end
507507
end
508-
Jfunc, gfuncs, geqfuncs = get_optim_functions(mpc, optim)
509-
@operator(optim, J, nZ̃, Jfunc)
508+
Jfunc, ∇Jfunc!, gfuncs, geqfuncs, ∇geqfuncs! = get_optim_functions(mpc, optim)
509+
@operator(optim, J, nZ̃, Jfunc, ∇Jfunc!)
510510
@objective(optim, Min, J(Z̃var...))
511511
init_nonlincon!(mpc, model, transcription, gfuncs, geqfuncs)
512512
set_nonlincon!(mpc, model, optim)
513513
return nothing
514514
end
515515

516516
"""
517-
get_optim_functions(mpc::NonLinMPC, ::JuMP.GenericModel) -> Jfunc, gfuncs, geqfuncs
517+
get_optim_functions(mpc::NonLinMPC, ::JuMP.GenericModel) -> Jfunc, ∇Jfunc!, gfuncs, geqfuncs
518518
519-
Get the objective `Jfunc` function, and constraint `gfuncs` and `geqfuncs` function vectors
520-
for [`NonLinMPC`](@ref).
519+
Get the objective `Jfunc` function and `∇Jfunc!` to compute its gradient, and constraint
520+
`gfuncs` and `geqfuncs` function vectors for [`NonLinMPC`](@ref).
521521
522522
Inspired from: [User-defined operators with vector outputs](https://jump.dev/JuMP.jl/stable/tutorials/nonlinear/tips_and_tricks/#User-defined-operators-with-vector-outputs)
523523
"""
@@ -541,22 +541,25 @@ function get_optim_functions(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JNT
541541
gc_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nc), Ncache)
542542
g_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, ng), Ncache)
543543
geq_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, neq), Ncache)
544-
function update_simulations!(Z̃, Z̃tup::NTuple{N, T}) where {N, T<:Real}
545-
if any(new !== old for (new, old) in zip(Z̃tup, Z̃)) # new Z̃tup, update predictions:
546-
Z̃1 = Z̃tup[begin]
547-
for i in eachindex(Z̃tup)
548-
Z̃[i] = Z̃tup[i] # Z̃ .= Z̃tup seems to produce a type instability
544+
function update_simulations!(
545+
Z̃arg::Union{NTuple{N, T}, AbstractVector{T}}, Z̃cache
546+
) where {N, T<:Real}
547+
if any(cache !== arg for (cache, arg) in zip(Z̃cache, Z̃arg)) # new Z̃, update:
548+
for i in eachindex(Z̃cache)
549+
# Z̃cache .= Z̃arg is type unstable with Z̃arg::NTuple{N, FowardDiff.Dual}
550+
Z̃cache[i] = Z̃arg[i]
549551
end
552+
= Z̃cache
550553
ϵ = (nϵ 0) ? Z̃[end] : zero(T) # ϵ = 0 if nϵ == 0 (meaning no relaxation)
551-
ΔŨ = get_tmp(ΔŨ_cache, Z̃1)
552-
x̂0end = get_tmp(x̂0end_cache, Z̃1)
553-
Ue, Ŷe = get_tmp(Ue_cache, Z̃1), get_tmp(Ŷe_cache, Z̃1)
554-
U0, Ŷ0 = get_tmp(U0_cache, Z̃1), get_tmp(Ŷ0_cache, Z̃1)
555-
X̂0, Û0 = get_tmp(X̂0_cache, Z̃1), get_tmp(Û0_cache, Z̃1)
556-
gc, g = get_tmp(gc_cache, Z̃1), get_tmp(g_cache, Z̃1)
557-
geq = get_tmp(geq_cache, Z̃1)
554+
ΔŨ = get_tmp(ΔŨ_cache, T)
555+
x̂0end = get_tmp(x̂0end_cache, T)
556+
Ue, Ŷe = get_tmp(Ue_cache, T), get_tmp(Ŷe_cache, T)
557+
U0, Ŷ0 = get_tmp(U0_cache, T), get_tmp(Ŷ0_cache, T)
558+
X̂0, Û0 = get_tmp(X̂0_cache, T), get_tmp(Û0_cache, T)
559+
gc, g = get_tmp(gc_cache, T), get_tmp(g_cache, T)
560+
geq = get_tmp(geq_cache, T)
558561
U0 = getU0!(U0, mpc, Z̃)
559-
ΔŨ = getΔŨ!(ΔŨ, mpc, mpc.transcription, Z̃)
562+
ΔŨ = getΔŨ!(ΔŨ, mpc, transcription, Z̃)
560563
Ŷ0, x̂0end = predict!(Ŷ0, x̂0end, X̂0, Û0, mpc, model, transcription, U0, Z̃)
561564
Ue, Ŷe = extended_vectors!(Ue, Ŷe, mpc, U0, Ŷ0)
562565
gc = con_custom!(gc, mpc, Ue, Ŷe, ϵ)
@@ -565,43 +568,94 @@ function get_optim_functions(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JNT
565568
end
566569
return nothing
567570
end
568-
function Jfunc(Z̃tup::Vararg{T, N}) where {N, T<:Real}
569-
Z̃1 = Z̃tup[begin]
570-
= get_tmp(Z̃_cache, Z̃1)
571-
update_simulations!(Z̃, Z̃tup)
572-
ΔŨ = get_tmp(ΔŨ_cache, Z̃1)
573-
Ue, Ŷe = get_tmp(Ue_cache, Z̃1), get_tmp(Ŷe_cache, Z̃1)
574-
U0, Ŷ0 = get_tmp(U0_cache, Z̃1), get_tmp(Ŷ0_cache, Z̃1)
575-
return obj_nonlinprog!(Ŷ0, U0, mpc, model, Ue, Ŷe, ΔŨ)::T
576-
end
577-
function gfunc_i(i, Z̃tup::NTuple{N, T}) where {N, T<:Real}
578-
Z̃1 = Z̃tup[begin]
579-
= get_tmp(Z̃_cache, Z̃1)
580-
update_simulations!(Z̃, Z̃tup)
581-
g = get_tmp(g_cache, Z̃1)
582-
return g[i]::T
571+
# force specialization using Vararg, see https://docs.julialang.org/en/v1/manual/performance-tips/#Be-aware-of-when-Julia-avoids-specializing
572+
function Jfunc(Z̃arg::Vararg{T, N}) where {N, T<:Real}
573+
return Jfunc(Z̃arg, get_tmp(Z̃_cache, T))::T
583574
end
584-
gfuncs = Vector{Function}(undef, ng)
585-
for i in 1:ng
586-
# this is another syntax for anonymous function, allowing parameters T and N:
587-
gfuncs[i] = function (Z̃tup::Vararg{T, N}) where {N, T<:Real}
588-
return gfunc_i(i, Z̃tup)
589-
end
575+
# method with the additional cache argument:
576+
function Jfunc(Z̃arg::Union{NTuple{N, T}, AbstractVector{T}}, Z̃cache) where {N, T<:Real}
577+
update_simulations!(Z̃arg, Z̃cache)
578+
ΔŨ = get_tmp(ΔŨ_cache, T)
579+
Ue, Ŷe = get_tmp(Ue_cache, T), get_tmp(Ŷe_cache, T)
580+
U0, Ŷ0 = get_tmp(U0_cache, T), get_tmp(Ŷ0_cache, T)
581+
return obj_nonlinprog!(Ŷ0, U0, mpc, model, Ue, Ŷe, ΔŨ)
590582
end
591-
function gfunceq_i(i, Z̃tup::NTuple{N, T}) where {N, T<:Real}
592-
Z̃1 = Z̃tup[begin]
593-
= get_tmp(Z̃_cache, Z̃1)
594-
update_simulations!(Z̃, Z̃tup)
595-
geq = get_tmp(geq_cache, Z̃1)
583+
Jfunc_vec(Z̃vec) = Jfunc(Z̃vec, get_tmp(Z̃_cache, Z̃vec[1]))
584+
Z̃vec = Vector{JNT}(undef, nZ̃)
585+
∇Jbuffer = GradientBuffer(Jfunc_vec, Z̃vec)
586+
function ∇Jfunc!(∇J, Z̃arg::Vararg{T, N}) where {N, T<:Real}
587+
Z̃vec .= Z̃arg
588+
gradient!(∇J, ∇Jbuffer, Z̃vec)
589+
return nothing
590+
end
591+
592+
593+
function gfunceq_i(i, Z̃arg::NTuple{N, T}) where {N, T<:Real}
594+
update_simulations!(Z̃arg, get_tmp(Z̃_cache, T))
595+
geq = get_tmp(geq_cache, T)
596596
return geq[i]::T
597597
end
598598
geqfuncs = Vector{Function}(undef, neq)
599599
for i in 1:neq
600-
geqfuncs[i] = function (Z̃tup::Vararg{T, N}) where {N, T<:Real}
601-
return gfunceq_i(i, Z̃tup)
600+
geqfuncs[i] = function (Z̃arg::Vararg{T, N}) where {N, T<:Real}
601+
return gfunceq_i(i, Z̃arg)
602+
end
603+
end
604+
605+
606+
∇geqfuncs! = nothing
607+
#=
608+
function geqfunc!(geq, Z̃)
609+
update_simulations!(Z̃, get_tmp(Z̃_cache, T))
610+
geq = get_tmp(geq_cache, T)
611+
return
612+
end
613+
=#
614+
615+
#=
616+
617+
618+
619+
∇geq = Matrix{JNT}(undef, neq, nZ̃) # Jacobian of geq
620+
function ∇geqfunc_vec!(∇geq, Z̃vec)
621+
update_simulations!(Z̃arg, get_tmp(Z̃_cache, T))
622+
return nothing
623+
end
624+
625+
626+
627+
628+
629+
function ∇geqfuncs_i!(∇geq_i, i, Z̃arg::NTuple{N, T}) where {N, T<:Real}
630+
Z̃arg_vec .= Z̃arg
631+
ForwardDiff
632+
633+
634+
end
635+
636+
∇geqfuncs! = Vector{Function}(undef, neq)
637+
for i in 1:neq
638+
∇eqfuncs![i] = function (∇geq, Z̃arg::Vararg{T, N}) where {N, T<:Real}
639+
return ∇geqfuncs_i!(∇geq, i, Z̃arg)
640+
end
641+
end
642+
=#
643+
644+
645+
# TODO:re-déplacer en haut:
646+
function gfunc_i(i, Z̃arg::NTuple{N, T}) where {N, T<:Real}
647+
update_simulations!(Z̃arg, get_tmp(Z̃_cache, T))
648+
g = get_tmp(g_cache, T)
649+
return g[i]::T
650+
end
651+
gfuncs = Vector{Function}(undef, ng)
652+
for i in 1:ng
653+
# this is another syntax for anonymous function, allowing parameters T and N:
654+
gfuncs[i] = function (Z̃arg::Vararg{T, N}) where {N, T<:Real}
655+
return gfunc_i(i, Z̃arg)
602656
end
603657
end
604-
return Jfunc, gfuncs, geqfuncs
658+
return Jfunc, ∇Jfunc!, gfuncs, geqfuncs, ∇geqfuncs!
605659
end
606660

607661
"""

src/general.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,16 @@ end
1515

1616
"Struct with both function and configuration for ForwardDiff gradient."
1717
struct GradientBuffer{FT<:Function, CT<:ForwardDiff.GradientConfig} <: DifferentiationBuffer
18-
f!::FT
18+
f::FT
1919
config::CT
2020
end
2121

22-
GradientBuffer(f!, x) = GradientBuffer(f!, ForwardDiff.GradientConfig(f!, x))
22+
GradientBuffer(f, x) = GradientBuffer(f, ForwardDiff.GradientConfig(f, x))
2323

2424
function gradient!(
2525
g, buffer::GradientBuffer, x
2626
)
27-
return ForwardDiff.gradient!(g, buffer.f!, x, buffer.config)
27+
return ForwardDiff.gradient!(g, buffer.f, x, buffer.config)
2828
end
2929

3030
"Struct with both function and configuration for ForwardDiff Jacobian."

0 commit comments

Comments
 (0)