Skip to content

Commit 6887c62

Browse files
committed
WIP: using Cache instead of DiffCache
1 parent 56433fb commit 6887c62

File tree

2 files changed

+114
-66
lines changed

2 files changed

+114
-66
lines changed

src/ModelPredictiveControl.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ using ProgressLogging
99

1010
using DifferentiationInterface: ADTypes.AbstractADType, AutoForwardDiff, AutoSparse
1111
using DifferentiationInterface: gradient!, jacobian!, prepare_gradient, prepare_jacobian
12+
using DifferentiationInterface: Constant, Cache
1213
using SparseConnectivityTracer: TracerSparsityDetector
1314
using SparseMatrixColorings: GreedyColoringAlgorithm
1415

src/controller/nonlinmpc.jl

Lines changed: 113 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -582,43 +582,14 @@ function get_optim_functions(
582582
jac_backend ::AbstractADType
583583
) where JNT<:Real
584584
model, transcription = mpc.estim.model, mpc.transcription
585-
nu, ny, nx̂, nϵ, Hp, Hc = model.nu, model.ny, mpc.estim.nx̂, mpc.nϵ, mpc.Hp, mpc.Hc
586-
ng, nc, neq = length(mpc.con.i_g), mpc.con.nc, mpc.con.neq
587-
nZ̃, nU, nŶ, nX̂ = length(mpc.Z̃), Hp*nu, Hp*ny, Hp*nx̂
588-
nΔŨ, nUe, nŶe = nu*Hc + nϵ, nU + nu, nŶ + ny
589-
Ncache = nZ̃ + 3
590-
myNaN = convert(JNT, NaN) # fill Z̃ with NaNs to force update_simulations! at 1st call:
591-
# ---------------------- differentiation cache ---------------------------------------
592-
Z̃_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(fill(myNaN, nZ̃), Ncache)
593-
ΔŨ_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nΔŨ), Ncache)
594-
x̂0end_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nx̂), Ncache)
595-
Ŷe_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nŶe), Ncache)
596-
Ue_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nUe), Ncache)
597-
Ŷ0_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nŶ), Ncache)
598-
U0_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nU), Ncache)
599-
Û0_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nU), Ncache)
600-
X̂0_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nX̂), Ncache)
601-
gc_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nc), Ncache)
602-
g_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, ng), Ncache)
603-
geq_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, neq), Ncache)
604585
# --------------------- update simulation function ------------------------------------
605-
function update_simulations!(
606-
Z̃arg::Union{NTuple{N, T}, AbstractVector{T}}, Z̃cache
607-
) where {N, T<:Real}
608-
if isdifferent(Z̃cache, Z̃arg)
609-
for i in eachindex(Z̃cache)
586+
function update_simulations!(Z̃arg, Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
587+
if isdifferent(Z̃arg, Z̃)
588+
for i in eachindex(Z̃)
610589
# Z̃cache .= Z̃arg is type unstable with Z̃arg::NTuple{N, FowardDiff.Dual}
611-
Z̃cache[i] = Z̃arg[i]
590+
[i] = Z̃arg[i]
612591
end
613-
= Z̃cache
614592
ϵ = (nϵ 0) ? Z̃[end] : zero(T) # ϵ = 0 if nϵ == 0 (meaning no relaxation)
615-
ΔŨ = get_tmp(ΔŨ_cache, T)
616-
x̂0end = get_tmp(x̂0end_cache, T)
617-
Ue, Ŷe = get_tmp(Ue_cache, T), get_tmp(Ŷe_cache, T)
618-
U0, Ŷ0 = get_tmp(U0_cache, T), get_tmp(Ŷ0_cache, T)
619-
X̂0, Û0 = get_tmp(X̂0_cache, T), get_tmp(Û0_cache, T)
620-
gc, g = get_tmp(gc_cache, T), get_tmp(g_cache, T)
621-
geq = get_tmp(geq_cache, T)
622593
U0 = getU0!(U0, mpc, Z̃)
623594
ΔŨ = getΔŨ!(ΔŨ, mpc, transcription, Z̃)
624595
Ŷ0, x̂0end = predict!(Ŷ0, x̂0end, X̂0, Û0, mpc, model, transcription, U0, Z̃)
@@ -629,78 +600,151 @@ function get_optim_functions(
629600
end
630601
return nothing
631602
end
632-
# --------------------- objective functions -------------------------------------------
603+
# ---------------------- JNT vectors cache --------------------------------------------
604+
nu, ny, nx̂, nϵ, Hp, Hc = model.nu, model.ny, mpc.estim.nx̂, mpc.nϵ, mpc.Hp, mpc.Hc
605+
ng, nc, neq = length(mpc.con.i_g), mpc.con.nc, mpc.con.neq
606+
nZ̃, nU, nŶ, nX̂ = length(mpc.Z̃), Hp*nu, Hp*ny, Hp*nx̂
607+
nΔŨ, nUe, nŶe = nu*Hc + nϵ, nU + nu, nŶ + ny
608+
myNaN = convert(JNT, NaN)
609+
= fill(myNaN, nZ̃)
610+
ΔŨ = zeros(JNT, nΔŨ)
611+
x̂0end = zeros(JNT, nx̂)
612+
Ue, Ŷe = zeros(JNT, nUe), zeros(JNT, nŶe)
613+
U0, Ŷ0 = zeros(JNT, nU), zeros(JNT, nŶ)
614+
Û0, X̂0 = zeros(JNT, nU), zeros(JNT, nX̂)
615+
gc, g = zeros(JNT, nc), zeros(JNT, ng)
616+
geq = zeros(JNT, neq)
617+
618+
# WIP: still receive INVALID_MODEL once and a while, needs to investigate.
619+
620+
633621
function Jfunc(Z̃arg::Vararg{T, N}) where {N, T<:Real}
634-
update_simulations!(Z̃arg, get_tmp(Z̃_cache, T))
635-
ΔŨ = get_tmp(ΔŨ_cache, T)
636-
Ue, Ŷe = get_tmp(Ue_cache, T), get_tmp(Ŷe_cache, T)
637-
U0, Ŷ0 = get_tmp(U0_cache, T), get_tmp(Ŷ0_cache, T)
622+
update_simulations!(Z̃arg, Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
638623
return obj_nonlinprog!(Ŷ0, U0, mpc, model, Ue, Ŷe, ΔŨ)::T
639624
end
640-
function Jfunc_vec(Z̃arg::AbstractVector{T}) where T<:Real
641-
update_simulations!(Z̃arg, get_tmp(Z̃_cache, T))
642-
ΔŨ = get_tmp(ΔŨ_cache, T)
643-
Ue, Ŷe = get_tmp(Ue_cache, T), get_tmp(Ŷe_cache, T)
644-
U0, Ŷ0 = get_tmp(U0_cache, T), get_tmp(Ŷ0_cache, T)
645-
return obj_nonlinprog!(Ŷ0, U0, mpc, model, Ue, Ŷe, ΔŨ)::T
625+
function Jfunc_vec!(Z̃arg, Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
626+
update_simulations!(Z̃arg, Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
627+
return obj_nonlinprog!(Ŷ0, U0, mpc, model, Ue, Ŷe, ΔŨ)
646628
end
629+
630+
631+
632+
gfuncs = Vector{Function}(undef, ng)
633+
for i in eachindex(gfuncs)
634+
func_i = function (Z̃arg::Vararg{T, N}) where {N, T<:Real}
635+
update_simulations!(Z̃arg, Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
636+
return g[i]::T
637+
end
638+
gfuncs[i] = func_i
639+
end
640+
function gfunc_vec!(g, Z̃arg, Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, geq)
641+
update_simulations!(Z̃arg, Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
642+
return g
643+
end
644+
645+
646+
647+
648+
649+
650+
651+
652+
Z̃_∇J = fill(myNaN, nZ̃)
653+
ΔŨ_∇J = zeros(JNT, nΔŨ)
654+
x̂0end_∇J = zeros(JNT, nx̂)
655+
Ue_∇J, Ŷe_∇J = zeros(JNT, nUe), zeros(JNT, nŶe)
656+
U0_∇J, Ŷ0_∇J = zeros(JNT, nU), zeros(JNT, nŶ)
657+
Û0_∇J, X̂0_∇J = zeros(JNT, nU), zeros(JNT, nX̂)
658+
gc_∇J, g_∇J = zeros(JNT, nc), zeros(JNT, ng)
659+
geq_∇J = zeros(JNT, neq)
660+
661+
662+
663+
664+
665+
666+
647667
Z̃_∇J = fill(myNaN, nZ̃)
648668
∇J = Vector{JNT}(undef, nZ̃) # gradient of objective J
649-
∇J_prep = prepare_gradient(Jfunc_vec, grad_backend, Z̃_∇J)
669+
∇J_context = (
670+
Cache(Z̃_∇J), Cache(ΔŨ_∇J), Cache(x̂0end_∇J),
671+
Cache(Ue_∇J), Cache(Ŷe_∇J),
672+
Cache(U0_∇J), Cache(Ŷ0_∇J),
673+
Cache(Û0_∇J), Cache(X̂0_∇J),
674+
Cache(gc_∇J), Cache(g_∇J), Cache(geq_∇J)
675+
)
676+
∇J_prep = prepare_gradient(Jfunc_vec!, grad_backend, Z̃_∇J, ∇J_context...)
650677
∇Jfunc! = if nZ̃ == 1
651-
function (Z̃arg::T) where T<:Real
678+
function (Z̃arg)
652679
Z̃_∇J .= Z̃arg
653-
gradient!(Jfunc_vec, ∇J, ∇J_prep, grad_backend, Z̃_∇J)
680+
gradient!(Jfunc_vec!, ∇J, ∇J_prep, grad_backend, Z̃_∇J, ∇J_context...)
654681
return ∇J[begin] # univariate syntax, see JuMP.@operator doc
655682
end
656683
else
657684
function (∇J::AbstractVector{T}, Z̃arg::Vararg{T, N}) where {N, T<:Real}
658685
Z̃_∇J .= Z̃arg
659-
gradient!(Jfunc_vec, ∇J, ∇J_prep, grad_backend, Z̃_∇J)
686+
gradient!(Jfunc_vec!, ∇J, ∇J_prep, grad_backend, Z̃_∇J, ∇J_context...)
660687
return ∇J # multivariate syntax, see JuMP.@operator doc
661688
end
662689
end
663-
# --------------------- inequality constraint functions -------------------------------
664-
gfuncs = Vector{Function}(undef, ng)
665-
for i in eachindex(gfuncs)
666-
func_i = function (Z̃arg::Vararg{T, N}) where {N, T<:Real}
667-
update_simulations!(Z̃arg, get_tmp(Z̃_cache, T))
668-
g = get_tmp(g_cache, T)
669-
return g[i]::T
670-
end
671-
gfuncs[i] = func_i
672-
end
673-
function gfunc_vec!(g, Z̃vec::AbstractVector{T}) where T<:Real
674-
update_simulations!(Z̃vec, get_tmp(Z̃_cache, T))
675-
g .= get_tmp(g_cache, T)
676-
return g
677-
end
690+
691+
692+
693+
694+
695+
696+
697+
Z̃_∇g = fill(myNaN, nZ̃)
698+
ΔŨ_∇g = zeros(JNT, nΔŨ)
699+
x̂0end_∇g = zeros(JNT, nx̂)
700+
Ue_∇g, Ŷe_∇g = zeros(JNT, nUe), zeros(JNT, nŶe)
701+
U0_∇g, Ŷ0_∇g = zeros(JNT, nU), zeros(JNT, nŶ)
702+
Û0_∇g, X̂0_∇g = zeros(JNT, nU), zeros(JNT, nX̂)
703+
gc_∇g, g_∇g = zeros(JNT, nc), zeros(JNT, ng)
704+
geq_∇g = zeros(JNT, neq)
705+
706+
707+
708+
709+
678710
Z̃_∇g = fill(myNaN, nZ̃)
679711
g_vec = Vector{JNT}(undef, ng)
680712
∇g = Matrix{JNT}(undef, ng, nZ̃) # Jacobian of inequality constraints g
681-
∇g_prep = prepare_jacobian(gfunc_vec!, g_vec, jac_backend, Z̃_∇g)
713+
∇g_context = (
714+
Cache(Z̃_∇g), Cache(ΔŨ_∇g), Cache(x̂0end_∇g),
715+
Cache(Ue_∇g), Cache(Ŷe_∇g),
716+
Cache(U0_∇g), Cache(Ŷ0_∇g),
717+
Cache(Û0_∇g), Cache(X̂0_∇g),
718+
Cache(gc_∇g), Cache(geq_∇g)
719+
)
720+
∇g_prep = prepare_jacobian(gfunc_vec!, g_vec, jac_backend, Z̃_∇g, ∇g_context...)
682721
∇gfuncs! = Vector{Function}(undef, ng)
683722
for i in eachindex(∇gfuncs!)
684723
∇gfuncs![i] = if nZ̃ == 1
685724
function (Z̃arg::T) where T<:Real
686725
if isdifferent(Z̃arg, Z̃_∇g)
687726
Z̃_∇g .= Z̃arg
688-
jacobian!(gfunc_vec!, g_vec, ∇g, ∇g_prep, jac_backend, Z̃_∇g)
727+
jacobian!(
728+
gfunc_vec!, g_vec, ∇g, ∇g_prep, jac_backend, Z̃_∇g, ∇g_context...
729+
)
689730
end
690731
return ∇g[i, begin] # univariate syntax, see JuMP.@operator doc
691732
end
692733
else
693734
function (∇g_i, Z̃arg::Vararg{T, N}) where {N, T<:Real}
694735
if isdifferent(Z̃arg, Z̃_∇g)
695736
Z̃_∇g .= Z̃arg
696-
jacobian!(gfunc_vec!, g_vec, ∇g, ∇g_prep, jac_backend, Z̃_∇g)
737+
jacobian!(
738+
gfunc_vec!, g_vec, ∇g, ∇g_prep, jac_backend, Z̃_∇g, ∇g_context...
739+
)
697740
end
698741
return ∇g_i .= @views ∇g[i, :] # multivariate syntax, see JuMP.@operator doc
699742
end
700743
end
701744
end
702745
# --------------------- equality constraint functions ---------------------------------
703746
geqfuncs = Vector{Function}(undef, neq)
747+
#=
704748
for i in eachindex(geqfuncs)
705749
func_i = function (Z̃arg::Vararg{T, N}) where {N, T<:Real}
706750
update_simulations!(Z̃arg, get_tmp(Z̃_cache, T))
@@ -718,7 +762,9 @@ function get_optim_functions(
718762
geq_vec = Vector{JNT}(undef, neq)
719763
∇geq = Matrix{JNT}(undef, neq, nZ̃) # Jacobian of equality constraints geq
720764
∇geq_prep = prepare_jacobian(geqfunc_vec!, geq_vec, jac_backend, Z̃_∇geq)
765+
=#
721766
∇geqfuncs! = Vector{Function}(undef, neq)
767+
#=
722768
for i in eachindex(∇geqfuncs!)
723769
# only multivariate syntax, univariate is impossible since nonlinear equality
724770
# constraints imply MultipleShooting, thus input increment ΔU and state X̂0 in Z̃:
@@ -731,6 +777,7 @@ function get_optim_functions(
731777
return ∇geq_i .= @views ∇geq[i, :]
732778
end
733779
end
780+
=#
734781
return Jfunc, ∇Jfunc!, gfuncs, ∇gfuncs!, geqfuncs, ∇geqfuncs!
735782
end
736783

0 commit comments

Comments
 (0)