Skip to content

Commit b28b702

Browse files
committed
changed: update_prediction! methods are now not nested
It is no longer necessary with `DI.Constant` context.
1 parent 30b2e4d commit b28b702

File tree

4 files changed

+57
-34
lines changed

4 files changed

+57
-34
lines changed

src/controller/execute.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,16 +123,16 @@ function getinfo(mpc::PredictiveController{NT}) where NT<:Real
123123
x̂0end = similar(mpc.estim.x̂0)
124124
Ue, Ŷe = Vector{NT}(undef, nUe), Vector{NT}(undef, nŶe)
125125
U0, Ŷ0 = similar(mpc.Uop), similar(mpc.Yop)
126-
X̂0, Û0 = Vector{NT}(undef, nX̂0), Vector{NT}(undef, nÛ0)
126+
Û0, X̂0 = Vector{NT}(undef, nÛ0), Vector{NT}(undef, nX̂0)
127127
U, Ŷ = similar(mpc.Uop), similar(mpc.Yop)
128-
Ŷs = similar(mpc.Yop)
129128
U0 = getU0!(U0, mpc, Z̃)
130-
ΔŨ = getΔŨ!(ΔŨ, mpc, mpc.transcription, Z̃)
129+
ΔŨ = getΔŨ!(ΔŨ, mpc, transcription, Z̃)
131130
Ŷ0, x̂0end = predict!(Ŷ0, x̂0end, X̂0, Û0, mpc, model, transcription, U0, Z̃)
132131
Ue, Ŷe = extended_vectors!(Ue, Ŷe, mpc, U0, Ŷ0)
133132
U .= U0 .+ mpc.Uop
134133
Ŷ .= Ŷ0 .+ mpc.Yop
135134
J = obj_nonlinprog!(Ŷ0, U0, mpc, model, Ue, Ŷe, ΔŨ)
135+
Ŷs = similar(mpc.Yop)
136136
predictstoch!(Ŷs, mpc, mpc.estim)
137137
info[:ΔU] = Z̃[1:mpc.Hc*model.nu]
138138
info[] = getϵ(mpc, Z̃)
@@ -370,6 +370,9 @@ function obj_nonlinprog!(
370370
return JR̂y + JΔŨ + JR̂u + E_JE
371371
end
372372

373+
"No custom nonlinear constraints `gc` by default, return `gc` unchanged."
374+
con_custom!(gc, ::PredictiveController, _ , _, _ ) = gc
375+
373376
"By default, the economic term is zero."
374377
function obj_econ(::PredictiveController, ::SimModel, _ , ::AbstractVector{NT}) where NT
375378
return zero(NT)

src/controller/nonlinmpc.jl

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -589,19 +589,6 @@ function get_optim_functions(
589589
grad_backend::AbstractADType,
590590
jac_backend ::AbstractADType
591591
) where JNT<:Real
592-
# ------ update simulation function (all args after `mpc` are mutated) ----------------
593-
function update_simulations!(Z̃, mpc, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
594-
model, transcription = mpc.estim.model, mpc.transcription
595-
U0 = getU0!(U0, mpc, Z̃)
596-
ΔŨ = getΔŨ!(ΔŨ, mpc, transcription, Z̃)
597-
Ŷ0, x̂0end = predict!(Ŷ0, x̂0end, X̂0, Û0, mpc, model, transcription, U0, Z̃)
598-
Ue, Ŷe = extended_vectors!(Ue, Ŷe, mpc, U0, Ŷ0)
599-
ϵ = getϵ(mpc, Z̃)
600-
gc = con_custom!(gc, mpc, Ue, Ŷe, ϵ)
601-
g = con_nonlinprog!(g, mpc, model, transcription, x̂0end, Ŷ0, gc, ϵ)
602-
geq = con_nonlinprogeq!(geq, X̂0, Û0, mpc, model, transcription, U0, Z̃)
603-
return nothing
604-
end
605592
# ----- common cache for Jfunc, gfuncs, geqfuncs called with floats -------------------
606593
model = mpc.estim.model
607594
nu, ny, nx̂, nϵ, Hp, Hc = model.nu, model.ny, mpc.estim.nx̂, mpc.nϵ, mpc.Hp, mpc.Hc
@@ -622,12 +609,12 @@ function get_optim_functions(
622609
function Jfunc(Z̃arg::Vararg{T, N}) where {N, T<:Real}
623610
if isdifferent(Z̃arg, Z̃)
624611
Z̃ .= Z̃arg
625-
update_simulations!(Z̃, mpc, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
612+
update_predictions!(ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq, mpc, Z̃)
626613
end
627614
return obj_nonlinprog!(Ŷ0, U0, mpc, model, Ue, Ŷe, ΔŨ)::T
628615
end
629616
function Jfunc!(Z̃, mpc, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
630-
update_simulations!(Z̃, mpc, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
617+
update_predictions!(ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq, mpc, Z̃)
631618
return obj_nonlinprog!(Ŷ0, U0, mpc, model, Ue, Ŷe, ΔŨ)
632619
end
633620
Z̃_∇J = fill(myNaN, nZ̃)
@@ -658,14 +645,14 @@ function get_optim_functions(
658645
gfunc_i = function (Z̃arg::Vararg{T, N}) where {N, T<:Real}
659646
if isdifferent(Z̃arg, Z̃)
660647
Z̃ .= Z̃arg
661-
update_simulations!(Z̃, mpc, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
648+
update_predictions!(ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq, mpc, Z̃)
662649
end
663650
return g[i]::T
664651
end
665652
gfuncs[i] = gfunc_i
666653
end
667654
function gfunc!(g, Z̃, mpc, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, geq)
668-
return update_simulations!(Z̃, mpc, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
655+
return update_predictions!(ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq, mpc, Z̃)
669656
end
670657
Z̃_∇g = fill(myNaN, nZ̃)
671658
∇g_context = (
@@ -706,14 +693,14 @@ function get_optim_functions(
706693
geqfunc_i = function (Z̃arg::Vararg{T, N}) where {N, T<:Real}
707694
if isdifferent(Z̃arg, Z̃)
708695
Z̃ .= Z̃arg
709-
update_simulations!(Z̃, mpc, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
696+
update_predictions!(ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq, mpc, Z̃)
710697
end
711698
return geq[i]::T
712699
end
713700
geqfuncs[i] = geqfunc_i
714701
end
715702
function geqfunc!(geq, Z̃, mpc, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g)
716-
return update_simulations!(Z̃, mpc, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
703+
return update_predictions!(ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq, mpc, Z̃)
717704
end
718705
Z̃_∇geq = fill(myNaN, nZ̃)
719706
∇geq_context = (
@@ -743,6 +730,31 @@ function get_optim_functions(
743730
return Jfunc, ∇Jfunc!, gfuncs, ∇gfuncs!, geqfuncs, ∇geqfuncs!
744731
end
745732

733+
"""
734+
update_predictions!(
735+
ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq,
736+
mpc::PredictiveController, Z̃
737+
) -> nothing
738+
739+
Update in-place all vectors for the predictions of `mpc` controller at decision vector `Z̃`.
740+
741+
The method mutates all the arguments before the `mpc` argument.
742+
"""
743+
function update_predictions!(
744+
ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq, mpc::PredictiveController, Z̃
745+
)
746+
model, transcription = mpc.estim.model, mpc.transcription
747+
U0 = getU0!(U0, mpc, Z̃)
748+
ΔŨ = getΔŨ!(ΔŨ, mpc, transcription, Z̃)
749+
Ŷ0, x̂0end = predict!(Ŷ0, x̂0end, X̂0, Û0, mpc, model, transcription, U0, Z̃)
750+
Ue, Ŷe = extended_vectors!(Ue, Ŷe, mpc, U0, Ŷ0)
751+
ϵ = getϵ(mpc, Z̃)
752+
gc = con_custom!(gc, mpc, Ue, Ŷe, ϵ)
753+
g = con_nonlinprog!(g, mpc, model, transcription, x̂0end, Ŷ0, gc, ϵ)
754+
geq = con_nonlinprogeq!(geq, X̂0, Û0, mpc, model, transcription, U0, Z̃)
755+
return nothing
756+
end
757+
746758
@doc raw"""
747759
con_custom!(gc, mpc::NonLinMPC, Ue, Ŷe, ϵ) -> gc
748760

src/estimator/mhe/construct.jl

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1332,14 +1332,6 @@ function get_optim_functions(
13321332
grad_backend::AbstractADType,
13331333
jac_backend::AbstractADType
13341334
) where {JNT <: Real}
1335-
# -------- update simulation function (all args after `estim` are mutated) ------------
1336-
function update_simulations!(Z̃, estim, V̂, X̂0, û0, ŷ0, g)
1337-
model = estim.model
1338-
V̂, X̂0 = predict!(V̂, X̂0, û0, ŷ0, estim, model, Z̃)
1339-
ϵ = getϵ(estim, Z̃)
1340-
g = con_nonlinprog!(g, estim, model, X̂0, V̂, ϵ)
1341-
return nothing
1342-
end
13431335
# ---------- common cache for Jfunc, gfuncs called with floats ------------------------
13441336
model, con = estim.model, estim.con
13451337
nx̂, nym, nŷ, nu, nϵ, He = estim.nx̂, estim.nym, model.ny, model.nu, estim.nϵ, estim.He
@@ -1355,12 +1347,12 @@ function get_optim_functions(
13551347
function Jfunc(Z̃arg::Vararg{T, N}) where {N, T<:Real}
13561348
if isdifferent(Z̃arg, Z̃)
13571349
Z̃ .= Z̃arg
1358-
update_simulations!(Z̃, estim, V̂, X̂0, û0, ŷ0, g)
1350+
update_prediction!(V̂, X̂0, û0, ŷ0, g, estim, Z̃)
13591351
end
13601352
return obj_nonlinprog!(x̄, estim, model, V̂, Z̃)::T
13611353
end
13621354
function Jfunc!(Z̃, estim, V̂, X̂0, û0, ŷ0, g, x̄)
1363-
update_simulations!(Z̃, estim, V̂, X̂0, û0, ŷ0, g)
1355+
update_prediction!(V̂, X̂0, û0, ŷ0, g, estim, Z̃)
13641356
return obj_nonlinprog!(x̄, estim, model, V̂, Z̃)
13651357
end
13661358
Z̃_∇J = fill(myNaN, nZ̃)
@@ -1392,14 +1384,14 @@ function get_optim_functions(
13921384
gfunc_i = function (Z̃arg::Vararg{T, N}) where {N, T<:Real}
13931385
if isdifferent(Z̃arg, Z̃)
13941386
Z̃ .= Z̃arg
1395-
update_simulations!(Z̃, estim, V̂, X̂0, û0, ŷ0, g)
1387+
update_prediction!(V̂, X̂0, û0, ŷ0, g, estim, Z̃)
13961388
end
13971389
return g[i]::T
13981390
end
13991391
gfuncs[i] = gfunc_i
14001392
end
14011393
function gfunc!(g, Z̃, estim, V̂, X̂0, û0, ŷ0)
1402-
return update_simulations!(Z̃, estim, V̂, X̂0, û0, ŷ0, g)
1394+
return update_prediction!(V̂, X̂0, û0, ŷ0, g, estim, Z̃)
14031395
end
14041396
Z̃_∇g = fill(myNaN, nZ̃)
14051397
∇g_context = (

src/estimator/mhe/execute.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -591,6 +591,22 @@ function predict!(V̂, X̂0, û0, ŷ0, estim::MovingHorizonEstimator, model::S
591591
return V̂, X̂0
592592
end
593593

594+
595+
"""
596+
update_predictions!(V̂, X̂0, û0, ŷ0, g, estim::MovingHorizonEstimator, Z̃)
597+
598+
Update in-place the vectors for the predictions of `estim` estimator at decision vector `Z̃`.
599+
600+
The method mutates all the arguments before `estim` argument.
601+
"""
602+
function update_prediction!(V̂, X̂0, û0, ŷ0, g, estim::MovingHorizonEstimator, Z̃)
603+
model = estim.model
604+
V̂, X̂0 = predict!(V̂, X̂0, û0, ŷ0, estim, model, Z̃)
605+
ϵ = getϵ(estim, Z̃)
606+
g = con_nonlinprog!(g, estim, model, X̂0, V̂, ϵ)
607+
return nothing
608+
end
609+
594610
"""
595611
con_nonlinprog!(g, estim::MovingHorizonEstimator, model::SimModel, X̂0, V̂, ϵ)
596612

0 commit comments

Comments
 (0)