Skip to content

Commit 8592d48

Browse files
committed
wip: exact hessian
1 parent a09b32f commit 8592d48

File tree

2 files changed

+41
-13
lines changed

2 files changed

+41
-13
lines changed

src/ModelPredictiveControl.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@ using Random: randn
77
using RecipesBase
88

99
using DifferentiationInterface: ADTypes.AbstractADType, AutoForwardDiff, AutoSparse
10-
using DifferentiationInterface: gradient!, jacobian!, prepare_gradient, prepare_jacobian
11-
using DifferentiationInterface: value_and_gradient!, value_and_jacobian!
10+
using DifferentiationInterface: gradient!, value_and_gradient!, prepare_gradient
11+
using DifferentiationInterface: jacobian!, value_and_jacobian!, prepare_jacobian
12+
using DifferentiationInterface: hessian!, value_gradient_and_hessian!, prepare_hessian
1213
using DifferentiationInterface: Constant, Cache
1314
using SparseConnectivityTracer: TracerSparsityDetector
1415
using SparseMatrixColorings: GreedyColoringAlgorithm, sparsity_pattern

src/controller/nonlinmpc.jl

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ const DEFAULT_NONLINMPC_JACSPARSE = AutoSparse(
77
sparsity_detector=TracerSparsityDetector(),
88
coloring_algorithm=GreedyColoringAlgorithm(),
99
)
10+
# Default AD backend for the Hessians of `NonLinMPC`: an alias of (the very same object
# as) `DEFAULT_NONLINMPC_JACSPARSE`, i.e. the sparse backend defined just above.
const DEFAULT_NONLINMPC_HESSIAN = DEFAULT_NONLINMPC_JACSPARSE
1011

1112
struct NonLinMPC{
1213
NT<:Real,
@@ -15,7 +16,8 @@ struct NonLinMPC{
1516
TM<:TranscriptionMethod,
1617
JM<:JuMP.GenericModel,
1718
GB<:AbstractADType,
18-
JB<:AbstractADType,
19+
JB<:AbstractADType,
20+
HB<:Union{AbstractADType, Nothing},
1921
PT<:Any,
2022
JEfunc<:Function,
2123
GCfunc<:Function
@@ -28,6 +30,7 @@ struct NonLinMPC{
2830
con::ControllerConstraint{NT, GCfunc}
2931
gradient::GB
3032
jacobian::JB
33+
hessian::HB
3134
oracle::Bool
3235
::Vector{NT}
3336
::Vector{NT}
@@ -117,9 +120,9 @@ struct NonLinMPC{
117120
nZ̃ = get_nZ(estim, transcription, Hp, Hc) +
118121
= zeros(NT, nZ̃)
119122
buffer = PredictiveControllerBuffer(estim, transcription, Hp, Hc, nϵ)
120-
mpc = new{NT, SE, CW, TM, JM, GB, JB, PT, JEfunc, GCfunc}(
123+
mpc = new{NT, SE, CW, TM, JM, GB, JB, HB, PT, JEfunc, GCfunc}(
121124
estim, transcription, optim, con,
122-
gradient, jacobian, oracle,
125+
gradient, jacobian, hessian, oracle,
123126
Z̃, ŷ,
124127
Hp, Hc, nϵ, nb,
125128
weights,
@@ -619,7 +622,7 @@ function get_nonlinops(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where JNT<
619622
# ----------- common cache for all functions ----------------------------------------
620623
model = mpc.estim.model
621624
transcription = mpc.transcription
622-
grad, jac = mpc.gradient, mpc.jacobian
625+
grad, jac, hess = mpc.gradient, mpc.jacobian, mpc.hessian
623626
nu, ny, nx̂, nϵ = model.nu, model.ny, mpc.estim.nx̂, mpc.
624627
nk = get_nk(model, transcription)
625628
Hp, Hc = mpc.Hp, mpc.Hc
@@ -645,14 +648,28 @@ function get_nonlinops(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where JNT<
645648
gi .= @views g[i_g]
646649
return nothing
647650
end
648-
Z̃_∇gi = fill(myNaN, nZ̃) # NaN to force update_predictions! at first call
651+
# Scalar function differentiated by the Hessian backend: the dot product of the
# multipliers `λ` with the inequality constraint values `gi`. The first argument stacks
# the decision vector (first `nZ̃` entries) and the multipliers (remaining entries); all
# the other arguments are work buffers mutated in-place.
function ℓ_gi(Z̃_λ_gi, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, geq, g, gi)
    i_first = firstindex(Z̃_λ_gi)
    # non-allocating views on the two halves of the stacked argument:
    Z̃ = view(Z̃_λ_gi, i_first:i_first+nZ̃-1)
    λ = view(Z̃_λ_gi, i_first+nZ̃:lastindex(Z̃_λ_gi))
    update_predictions!(ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g, geq, mpc, Z̃)
    # keep only the constraint entries selected by `i_g`:
    gi .= view(g, i_g)
    return dot(λ, gi)
end
657+
Z̃_∇gi = fill(myNaN, nZ̃) # NaN to force update at first call
658+
Z̃_λ_gi = fill(myNaN, nZ̃ + ngi)
649659
∇gi_context = (
650660
Cache(ΔŨ), Cache(x̂0end), Cache(Ue), Cache(Ŷe), Cache(U0), Cache(Ŷ0),
651661
Cache(Û0), Cache(K0), Cache(X̂0),
652662
Cache(gc), Cache(geq), Cache(g)
653663
)
654-
∇gi_prep = prepare_jacobian(gi!, gi, jac, Z̃_∇gi, ∇gi_context...; strict)
655-
∇gi = init_diffmat(JNT, jac, ∇gi_prep, nZ̃, ngi)
664+
∇gi_prep = prepare_jacobian(gi!, gi, jac, Z̃_∇gi, ∇gi_context...; strict)
665+
∇²gi_context = (
666+
Cache(ΔŨ), Cache(x̂0end), Cache(Ue), Cache(Ŷe), Cache(U0), Cache(Ŷ0),
667+
Cache(Û0), Cache(K0), Cache(X̂0),
668+
Cache(gc), Cache(geq), Cache(g), Cache(gi)
669+
)
670+
∇²gi_prep = prepare_hessian(ℓ_gi, hess, Z̃_λ_gi, ∇²gi_context...; strict)
671+
∇gi = init_diffmat(JNT, jac, ∇gi_prep, nZ̃, ngi)
672+
∇²ℓ_gi = init_diffmat(JNT, hess, ∇²gi_prep, nZ̃ + ngi, nZ̃ + ngi)
656673
function update_con!(gi, ∇gi, Z̃_∇gi, Z̃_arg)
657674
if isdifferent(Z̃_arg, Z̃_∇gi)
658675
Z̃_∇gi .= Z̃_arg
@@ -668,23 +685,33 @@ function get_nonlinops(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where JNT<
668685
update_con!(gi, ∇gi, Z̃_∇gi, Z̃_arg)
669686
return diffmat2vec!(∇gi_arg, ∇gi)
670687
end
688+
# Hessian-of-the-Lagrangian callback given to the inequality constraint oracle
# (`eval_hessian_lagrangian`). It stacks `Z̃_arg` and `λ_arg` in the `Z̃_λ_gi` buffer,
# differentiates `ℓ_gi` twice at that point, and writes the result in `∇²ℓ_arg`.
#
# Arguments:
# - `∇²ℓ_arg` : output vector, filled by `diffmat2vec!` from `∇²ℓ_gi`
# - `Z̃_arg`   : decision vector at which the Hessian is evaluated
# - `λ_arg`   : multipliers of the inequality constraints
#
# NOTE(review): `∇²ℓ_gi` is a (nZ̃+ngi)×(nZ̃+ngi) matrix since `ℓ_gi` is differentiated
# with respect to the stacked `[Z̃; λ]` vector, but the oracle is declared with
# `dimension = nZ̃`. The oracle presumably expects the Hessian with respect to `Z̃` only
# (e.g. by passing `λ` as a `Constant` context) -- TODO confirm against the MOI doc.
function ∇²gi_func!(∇²ℓ_arg, Z̃_arg, λ_arg)
    # first nZ̃ entries: decision vector (was the malformed range `1:begin:begin+nZ̃-1`)
    Z̃_λ_gi[begin:begin+nZ̃-1] .= Z̃_arg
    # remaining entries: multipliers (was `[[begin+nZ̃:end]]`, a vector holding a range,
    # which is not a valid index and throws)
    Z̃_λ_gi[begin+nZ̃:end] .= λ_arg
    # the contexts must be splatted, exactly like in the `prepare_hessian` call
    hessian!(ℓ_gi, ∇²ℓ_gi, ∇²gi_prep, hess, Z̃_λ_gi, ∇²gi_context...)
    return diffmat2vec!(∇²ℓ_arg, ∇²ℓ_gi)
end
671694
gi_min = fill(-myInf, ngi)
672695
gi_max = zeros(JNT, ngi)
673-
∇gi_structure = init_diffstructure(∇gi)
696+
∇gi_structure = init_diffstructure(∇gi)
697+
∇²gi_structure = init_diffstructure(∇²ℓ_gi)
698+
# Debugging aid: the Hessian buffer is only logged when debug logging is enabled,
# instead of being unconditionally printed each time the operators are built.
@debug "NonLinMPC: Hessian buffer of the inequality constraint Lagrangian" ∇²ℓ_gi
674699
g_oracle = MOI.VectorNonlinearOracle(;
675700
dimension = nZ̃,
676701
l = gi_min,
677702
u = gi_max,
678703
eval_f = gi_func!,
679704
jacobian_structure = ∇gi_structure,
680-
eval_jacobian = ∇gi_func!
705+
eval_jacobian = ∇gi_func!,
706+
hessian_lagrangian_structure = ∇²gi_structure,
707+
eval_hessian_lagrangian = ∇²gi_func!
681708
)
682709
# ------------- equality constraints : nonlinear oracle ------------------------------
683710
# In-place equality constraint function: its only statement is a call to
# `update_predictions!` with the provided buffers (including `geq`) and the decision
# vector `Z̃`. All the arguments after `Z̃` are work buffers. Always returns `nothing`.
function geq!(geq, Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g)
    update_predictions!(
        ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g, geq, mpc, Z̃
    )
    return nothing
end
687-
Z̃_∇geq = fill(myNaN, nZ̃) # NaN to force update_predictions! at first call
714+
Z̃_∇geq = fill(myNaN, nZ̃) # NaN to force update at first call
688715
∇geq_context = (
689716
Cache(ΔŨ), Cache(x̂0end), Cache(Ue), Cache(Ŷe), Cache(U0), Cache(Ŷ0),
690717
Cache(Û0), Cache(K0), Cache(X̂0),
@@ -722,7 +749,7 @@ function get_nonlinops(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where JNT<
722749
update_predictions!(ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g, geq, mpc, Z̃)
723750
return obj_nonlinprog!(Ŷ0, U0, mpc, model, Ue, Ŷe, ΔŨ)
724751
end
725-
Z̃_∇J = fill(myNaN, nZ̃) # NaN to force update_predictions! at first call
752+
Z̃_∇J = fill(myNaN, nZ̃) # NaN to force update at first call
726753
∇J_context = (
727754
Cache(ΔŨ), Cache(x̂0end), Cache(Ue), Cache(Ŷe), Cache(U0), Cache(Ŷ0),
728755
Cache(Û0), Cache(K0), Cache(X̂0),

0 commit comments

Comments
 (0)