Skip to content

Commit cacb2be

Browse files
committed
added: derivatives in getinfo(::MovingHorizonEstimator)
1 parent ea4c42a commit cacb2be

File tree

3 files changed

+92
-11
lines changed

3 files changed

+92
-11
lines changed

src/controller/nonlinmpc.jl

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -530,13 +530,18 @@ For [`NonLinMPC`](@ref), add `:sol`, the custom nonlinear objective `:JE`, the c
530530
constraint `:gc`, and the various derivatives.
531531
"""
532532
function addinfo!(info, mpc::NonLinMPC{NT}) where NT<:Real
533+
# --- variables specific to NonLinMPC ---
533534
U, Ŷ, D̂, ŷ, d, ϵ = info[:U], info[:Ŷ], info[:D̂], info[:ŷ], info[:d], info[]
534535
Ue = [U; U[(end - mpc.estim.model.nu + 1):end]]
535536
Ŷe = [ŷ; Ŷ]
536537
D̂e = [d; D̂]
537538
JE = mpc.JE(Ue, Ŷe, D̂e, mpc.p)
538539
LHS = Vector{NT}(undef, mpc.con.nc)
539540
mpc.con.gc!(LHS, Ue, Ŷe, D̂e, mpc.p, ϵ)
541+
info[:JE] = JE
542+
info[:gc] = LHS
543+
info[:sol] = JuMP.solution_summary(mpc.optim, verbose=true)
544+
# --- derivatives ---
540545
model, optim = mpc.estim.model, mpc.optim
541546
transcription = mpc.transcription
542547
nu, ny, nx̂, nϵ = model.nu, model.ny, mpc.estim.nx̂, mpc.
@@ -568,15 +573,19 @@ function addinfo!(info, mpc::NonLinMPC{NT}) where NT<:Real
568573
else
569574
∇J, ∇²J = gradient(J!, mpc.gradient, mpc.Z̃, J_cache...), nothing
570575
end
571-
JNT = typeof(mpc.optim).parameters[1]
576+
JNT = typeof(optim).parameters[1]
572577
nonlin_constraints = JuMP.all_constraints(
573578
optim, JuMP.Vector{JuMP.VariableRef}, MOI.VectorNonlinearOracle{JNT}
574579
)
575-
g_con, geq_func = nonlin_constraints
576-
λ, λeq = JuMP.dual.(g_con), JuMP.dual.(geq_func)
580+
g_con, geq_con = nonlin_constraints
581+
display(g_con)
582+
display(geq_con)
583+
λ, λeq = JuMP.dual.(g_con), JuMP.dual.(geq_con)
584+
println(JuMP.dual.(JuMP.FixBoundRef.(optim[:Z̃var])))
577585
display(λ)
578586
display(λeq)
579-
g_cache = (
587+
588+
∇g_cache = (
580589
Cache(ΔŨ), Cache(x̂0end), Cache(Ue), Cache(Ŷe), Cache(U0), Cache(Ŷ0),
581590
Cache(Û0), Cache(K0), Cache(X̂0),
582591
Cache(gc), Cache(geq),
@@ -585,7 +594,7 @@ function addinfo!(info, mpc::NonLinMPC{NT}) where NT<:Real
585594
update_predictions!(ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g, geq, mpc, Z̃)
586595
return nothing
587596
end
588-
∇g = jacobian(g!, g, mpc.jacobian, mpc.Z̃, g_cache...)
597+
∇g = jacobian(g!, g, mpc.jacobian, mpc.Z̃, g_cache...)
589598
#=
590599
if !isnothing(mpc.hessian)
591600
function ℓ_g(Z̃, λ, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, geq, g)
@@ -614,15 +623,13 @@ function addinfo!(info, mpc::NonLinMPC{NT}) where NT<:Real
614623
∇geq = jacobian(geq!, geq, mpc.jacobian, mpc.Z̃, geq_cache...)
615624
∇²ℓgeq = nothing # TODO: implement later
616625

617-
info[:JE] = JE
618-
info[:gc] = LHS
626+
619627
info[:∇J] = ∇J
620628
info[:∇²J] = ∇²J
621629
info[:∇g] = ∇g
622630
info[:∇²ℓg] = ∇²ℓg
623631
info[:∇geq] = ∇geq
624632
info[:∇²ℓgeq] = ∇²ℓgeq
625-
info[:sol] = JuMP.solution_summary(mpc.optim, verbose=true)
626633
# --- non-Unicode fields ---
627634
info[:nablaJ] = ∇J
628635
info[:nabla2J] = ∇²J

src/estimator/mhe/construct.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1504,9 +1504,7 @@ function get_nonlincon_oracle(
15041504
return dot(λi, gi)
15051505
end
15061506
Z̃_∇gi = fill(myNaN, nZ̃) # NaN to force update_predictions! at first call
1507-
∇gi_cache = (
1508-
Cache(V̂), Cache(X̂0), Cache(û0), Cache(k0), Cache(ŷ0), Cache(g)
1509-
)
1507+
∇gi_cache = (Cache(V̂), Cache(X̂0), Cache(û0), Cache(k0), Cache(ŷ0), Cache(g))
15101508
# temporarily "fill" the estimation window for the preparation of the gradient:
15111509
estim.Nk[] = He
15121510
∇gi_prep = prepare_jacobian(gi!, gi, jac, Z̃_∇gi, ∇gi_cache...; strict)

src/estimator/mhe/execute.jl

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,9 +163,85 @@ function getinfo(estim::MovingHorizonEstimator{NT}) where NT<:Real
163163
info[:Yhatm] = info[:Ŷm]
164164
# --- deprecated fields ---
165165
info[] = info[]
166+
info = addinfo!(info, estim, model)
166167
return info
167168
end
168169

170+
171+
"""
172+
addinfo!(info, estim::MovingHorizonEstimator, model::NonLinModel)
173+
174+
For [`NonLinModel`](@ref), add the various derivatives.
175+
"""
176+
function addinfo!(
177+
info, estim::MovingHorizonEstimator{NT}, model::NonLinModel
178+
) where NT <:Real
179+
# --- derivatives ---
180+
optim, con = estim.optim, estim.con
181+
nx̂, nym, nŷ, nu, nk = estim.nx̂, estim.nym, model.ny, model.nu, model.nk
182+
He = estim.He
183+
ng = length(con.i_g)
184+
nV̂, nX̂, ng = He*nym, He*nx̂, length(con.i_g)
185+
V̂, X̂0 = zeros(NT, nV̂), zeros(NT, nX̂)
186+
k0 = zeros(NT, nk)
187+
û0, ŷ0 = zeros(NT, nu), zeros(NT, nŷ)
188+
g = zeros(NT, ng)
189+
= zeros(NT, nx̂)
190+
J_cache = (
191+
Cache(V̂), Cache(X̂0), Cache(û0), Cache(k0), Cache(ŷ0),
192+
Cache(g),
193+
Cache(x̄),
194+
)
195+
function J!(Z̃, V̂, X̂0, û0, k0, ŷ0, g, x̄)
196+
update_prediction!(V̂, X̂0, û0, k0, ŷ0, g, estim, Z̃)
197+
return obj_nonlinprog!(x̄, estim, model, V̂, Z̃)
198+
end
199+
if !isnothing(estim.hessian)
200+
_, ∇J, ∇²J = value_gradient_and_hessian(J!, estim.hessian, estim.Z̃, J_cache...)
201+
else
202+
∇J, ∇²J = gradient(J!, estim.gradient, estim.Z̃, J_cache...), nothing
203+
end
204+
JNT = typeof(optim).parameters[1]
205+
nonlin_constraints = JuMP.all_constraints(
206+
optim, JuMP.Vector{JuMP.VariableRef}, MOI.VectorNonlinearOracle{JNT}
207+
)
208+
g_con = nonlin_constraints[1]
209+
λ = JuMP.dual.(g_con)
210+
display(λ)
211+
∇g_cache = (Cache(V̂), Cache(X̂0), Cache(û0), Cache(k0), Cache(ŷ0))
212+
function g!(g, Z̃, V̂, X̂0, û0, k0, ŷ0)
213+
update_prediction!(V̂, X̂0, û0, k0, ŷ0, g, estim, Z̃)
214+
return nothing
215+
end
216+
∇g = jacobian(g!, g, estim.jacobian, estim.Z̃, ∇g_cache...)
217+
#=
218+
if !isnothing(estim.hessian)
219+
function ℓ_g(Z̃, λ, V̂, X̂0, û0, k0, ŷ0, g)
220+
update_prediction!(V̂, X̂0, û0, k0, ŷ0, g, estim, Z̃)
221+
return dot(λ, g)
222+
end
223+
∇²g_cache = (Cache(V̂), Cache(X̂0), Cache(û0), Cache(k0), Cache(ŷ0), Cache(g))
224+
∇²ℓg = hessian(ℓ_g, estim.hessian, estim.Z̃, Constant(λ), ∇²g_cache...)
225+
else
226+
∇²ℓg = nothing
227+
end
228+
=# ∇²ℓg = nothing #TODO: delete this line when enabling the above block
229+
230+
info[:∇J] = ∇J
231+
info[:∇²J] = ∇²J
232+
info[:∇g] = ∇g
233+
info[:∇²ℓg] = ∇²ℓg
234+
# --- non-Unicode fields ---
235+
info[:nablaJ] = ∇J
236+
info[:nabla2J] = ∇²J
237+
info[:nablag] = ∇g
238+
info[:nabla2lg] = ∇²ℓg
239+
return info
240+
end
241+
242+
"Nothing to add in the `info` dict for [`LinModel`](@ref)."
243+
addinfo!(info, ::MovingHorizonEstimator, ::LinModel) = info
244+
169245
"""
170246
getε(estim::MovingHorizonEstimator, Z̃) -> ε
171247

0 commit comments

Comments
 (0)