Skip to content

Commit 1ed36a8

Browse files
committed
added: use DI.Cache in MovingHorizonEstimator
1 parent 22a6b9c commit 1ed36a8

File tree

3 files changed

+61
-52
lines changed

3 files changed

+61
-52
lines changed

src/controller/nonlinmpc.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -590,6 +590,9 @@ function get_optim_functions(
590590
jac_backend ::AbstractADType
591591
) where JNT<:Real
592592
model, transcription = mpc.estim.model, mpc.transcription
593+
#TODO: initialize jacobian as sparsed if it's the case?
594+
#TODO: fix type of all cache to ::Vector{JNT} (verify performance difference with and w/o)
595+
#TODO: mêmes choses pour le MHE
593596
# --------------------- update simulation function ------------------------------------
594597
function update_simulations!(Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
595598
U0 = getU0!(U0, mpc, Z̃)

src/estimator/mhe/construct.jl

Lines changed: 49 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1333,100 +1333,97 @@ function get_optim_functions(
13331333
jac_backend::AbstractADType
13341334
) where {JNT <: Real}
13351335
model, con = estim.model, estim.con
1336-
nx̂, nym, nŷ, nu, nϵ, He = estim.nx̂, estim.nym, model.ny, model.nu, estim.nϵ, estim.He
1337-
nV̂, nX̂, ng, nZ̃ = He*nym, He*nx̂, length(con.i_g), length(estim.Z̃)
1338-
Ncache = nZ̃ + 3
1339-
myNaN = convert(JNT, NaN) # fill Z̃ with NaNs to force update_simulations! at 1st call:
1340-
# ---------------------- differentiation cache ---------------------------------------
1341-
Z̃_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(fill(myNaN, nZ̃), Ncache)
1342-
V̂_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nV̂), Ncache)
1343-
g_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, ng), Ncache)
1344-
X̂0_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nX̂), Ncache)
1345-
x̄_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nx̂), Ncache)
1346-
û0_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nu), Ncache)
1347-
ŷ0_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nŷ), Ncache)
13481336
# --------------------- update simulation function ------------------------------------
1349-
function update_simulations!(
1350-
Z̃arg::Union{NTuple{N, T}, AbstractVector{T}}, Z̃cache
1351-
) where {N, T <:Real}
1352-
if isdifferent(Z̃cache, Z̃arg)
1353-
for i in eachindex(Z̃cache)
1354-
# Z̃cache .= Z̃arg is type unstable with Z̃arg::NTuple{N, FowardDiff.Dual}
1355-
Z̃cache[i] = Z̃arg[i]
1356-
end
1357-
= Z̃cache
1358-
ϵ = (nϵ 0) ? Z̃[begin] : zero(T) # ϵ = 0 if Cwt=Inf (meaning: no relaxation)
1359-
V̂, X̂0 = get_tmp(V̂_cache, T), get_tmp(X̂0_cache, T)
1360-
û0, ŷ0 = get_tmp(û0_cache, T), get_tmp(ŷ0_cache, T)
1361-
g = get_tmp(g_cache, T)
1362-
V̂, X̂0 = predict!(V̂, X̂0, û0, ŷ0, estim, model, Z̃)
1363-
g = con_nonlinprog!(g, estim, model, X̂0, V̂, ϵ)
1364-
end
1337+
function update_simulations!(Z̃, V̂, X̂0, û0, ŷ0, g)
1338+
V̂, X̂0 = predict!(V̂, X̂0, û0, ŷ0, estim, model, Z̃)
1339+
ϵ = getϵ(estim, Z̃)
1340+
g = con_nonlinprog!(g, estim, model, X̂0, V̂, ϵ)
13651341
return nothing
13661342
end
1343+
# ---------- common cache for Jfunc, gfuncs called with floats ------------------------
1344+
nx̂, nym, nŷ, nu, nϵ, He = estim.nx̂, estim.nym, model.ny, model.nu, estim.nϵ, estim.He
1345+
nV̂, nX̂, ng, nZ̃ = He*nym, He*nx̂, length(con.i_g), length(estim.Z̃)
1346+
myNaN = convert(JNT, NaN)
1347+
= fill(myNaN, nZ̃) # NaN to force update_simulations! at first call
1348+
V̂, X̂0 = zeros(JNT, nV̂), zeros(JNT, nX̂)
1349+
û0, ŷ0 = zeros(JNT, nu), zeros(JNT, nŷ)
1350+
g = zeros(JNT, ng)
1351+
= zeros(JNT, nx̂)
13671352
# --------------------- objective functions -------------------------------------------
13681353
function Jfunc(Z̃arg::Vararg{T, N}) where {N, T<:Real}
1369-
= get_tmp(Z̃_cache, T)
1370-
update_simulations!(Z̃arg, Z̃)
1371-
x̄, V̂ = get_tmp(x̄_cache, T), get_tmp(V̂_cache, T)
1354+
if isdifferent(Z̃arg, Z̃)
1355+
Z̃ .= Z̃arg
1356+
update_simulations!(Z̃, V̂, X̂0, û0, ŷ0, g)
1357+
end
13721358
return obj_nonlinprog!(x̄, estim, model, V̂, Z̃)::T
13731359
end
1374-
function Jfunc_vec(Z̃arg::AbstractVector{T}) where T<:Real
1375-
= get_tmp(Z̃_cache, T)
1376-
update_simulations!(Z̃arg, Z̃)
1377-
x̄, V̂ = get_tmp(x̄_cache, T), get_tmp(V̂_cache, T)
1378-
return obj_nonlinprog!(x̄, estim, model, V̂, Z̃)::T
1360+
function Jfunc!(Z̃, V̂, X̂0, û0, ŷ0, g, x̄)
1361+
update_simulations!(Z̃, V̂, X̂0, û0, ŷ0, g)
1362+
return obj_nonlinprog!(x̄, estim, model, V̂, Z̃)
13791363
end
13801364
Z̃_∇J = fill(myNaN, nZ̃)
13811365
∇J = Vector{JNT}(undef, nZ̃) # gradient of objective J
1382-
∇J_prep = prepare_gradient(Jfunc_vec, grad_backend, Z̃_∇J)
1366+
∇J_context = (
1367+
Cache(V̂), Cache(X̂0),
1368+
Cache(û0), Cache(ŷ0),
1369+
Cache(g),
1370+
Cache(x̄),
1371+
)
1372+
∇J_prep = prepare_gradient(Jfunc!, grad_backend, Z̃_∇J, ∇J_context...)
13831373
∇Jfunc! = if nZ̃ == 1
1384-
function (Z̃arg::T) where T<:Real
1374+
function (Z̃arg)
13851375
Z̃_∇J .= Z̃arg
1386-
gradient!(Jfunc_vec, ∇J, ∇J_prep, grad_backend, Z̃_∇J)
1376+
gradient!(Jfunc!, ∇J, ∇J_prep, grad_backend, Z̃_∇J, ∇J_context...)
13871377
return ∇J[begin] # univariate syntax, see JuMP.@operator doc
13881378
end
13891379
else
13901380
function (∇J::AbstractVector{T}, Z̃arg::Vararg{T, N}) where {N, T<:Real}
13911381
Z̃_∇J .= Z̃arg
1392-
gradient!(Jfunc_vec, ∇J, ∇J_prep, grad_backend, Z̃_∇J)
1382+
gradient!(Jfunc!, ∇J, ∇J_prep, grad_backend, Z̃_∇J, ∇J_context...)
13931383
return ∇J # multivariate syntax, see JuMP.@operator doc
13941384
end
13951385
end
13961386
# --------------------- inequality constraint functions -------------------------------
13971387
gfuncs = Vector{Function}(undef, ng)
13981388
for i in eachindex(gfuncs)
1399-
func_i = function (Z̃arg::Vararg{T, N}) where {N, T<:Real}
1400-
update_simulations!(Z̃arg, get_tmp(Z̃_cache, T))
1401-
g = get_tmp(g_cache, T)
1389+
gfunc_i = function (Z̃arg::Vararg{T, N}) where {N, T<:Real}
1390+
if isdifferent(Z̃arg, Z̃)
1391+
Z̃ .= Z̃arg
1392+
update_simulations!(Z̃, V̂, X̂0, û0, ŷ0, g)
1393+
end
14021394
return g[i]::T
14031395
end
1404-
gfuncs[i] = func_i
1396+
gfuncs[i] = gfunc_i
14051397
end
1406-
function gfunc_vec!(g, Z̃vec::AbstractVector{T}) where T<:Real
1407-
update_simulations!(Z̃vec, get_tmp(Z̃_cache, T))
1408-
g .= get_tmp(g_cache, T)
1409-
return g
1398+
function gfunc!(g, Z̃, V̂, X̂0, û0, ŷ0)
1399+
return update_simulations!(Z̃, V̂, X̂0, û0, ŷ0, g)
14101400
end
14111401
Z̃_∇g = fill(myNaN, nZ̃)
1412-
g_vec = Vector{JNT}(undef, ng)
14131402
∇g = Matrix{JNT}(undef, ng, nZ̃) # Jacobian of inequality constraints g
1414-
∇g_prep = prepare_jacobian(gfunc_vec!, g_vec, jac_backend, Z̃_∇g)
1403+
∇g_context = (
1404+
Cache(V̂), Cache(X̂0),
1405+
Cache(û0), Cache(ŷ0),
1406+
)
1407+
# temporarily enable all the inequality constraints for sparsity pattern detection:
1408+
i_g_old = copy(estim.con.i_g)
1409+
estim.con.i_g .= true
1410+
∇g_prep = prepare_jacobian(gfunc!, g, jac_backend, Z̃_∇g, ∇g_context...)
1411+
estim.con.i_g .= i_g_old
14151412
∇gfuncs! = Vector{Function}(undef, ng)
14161413
for i in eachindex(∇gfuncs!)
14171414
∇gfuncs![i] = if nZ̃ == 1
14181415
function (Z̃arg::T) where T<:Real
14191416
if isdifferent(Z̃arg, Z̃_∇g)
14201417
Z̃_∇g .= Z̃arg
1421-
jacobian!(gfunc_vec!, g_vec, ∇g, ∇g_prep, jac_backend, Z̃_∇g)
1418+
jacobian!(gfunc!, g, ∇g, ∇g_prep, jac_backend, Z̃_∇g. ∇g_context...)
14221419
end
14231420
return ∇g[i, begin] # univariate syntax, see JuMP.@operator doc
14241421
end
14251422
else
14261423
function (∇g_i, Z̃arg::Vararg{T, N}) where {N, T<:Real}
14271424
if isdifferent(Z̃arg, Z̃_∇g)
14281425
Z̃_∇g .= Z̃arg
1429-
jacobian!(gfunc_vec!, g_vec, ∇g, ∇g_prep, jac_backend, Z̃_∇g)
1426+
jacobian!(gfunc!, g, ∇g, ∇g_prep, jac_backend, Z̃_∇g, ∇g_context...)
14301427
end
14311428
return ∇g_i .= @views ∇g[i, :] # multivariate syntax, see JuMP.@operator doc
14321429
end

src/estimator/mhe/execute.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,15 @@ function getinfo(estim::MovingHorizonEstimator{NT}) where NT<:Real
160160
return info
161161
end
162162

163+
"""
164+
getϵ(estim::MovingHorizonEstimator, Z̃) -> ϵ
165+
166+
Get the slack `ϵ` from the decision vector `Z̃` if present, otherwise return 0.
167+
"""
168+
function getϵ(estim::MovingHorizonEstimator, Z̃::AbstractVector{NT}) where NT<:Real
169+
return estim. 0 ? Z̃[begin] : zero(NT)
170+
end
171+
163172
"""
164173
add_data_windows!(estim::MovingHorizonEstimator, y0m, d0, u0=estim.lastu0) -> ismoving
165174

0 commit comments

Comments
 (0)