Skip to content

Commit 8dbcdec

Browse files
committed
debug: only fills non-zero in ∇²f of @operator
1 parent 01bf244 commit 8dbcdec

File tree

3 files changed

+27
-11
lines changed

3 files changed

+27
-11
lines changed

src/controller/nonlinmpc.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -651,6 +651,7 @@ function get_nonlinobj_op(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where J
651651
if !isnothing(hess)
652652
∇²J_prep = prepare_hessian(J!, hess, Z̃_J, J_cache...; strict)
653653
∇²J = init_diffmat(JNT, hess, ∇²J_prep, nZ̃, nZ̃)
654+
∇²J_structure = lowertriangle_indices(init_diffstructure(∇²J))
654655
end
655656
update_objective! = if !isnothing(hess)
656657
function (J, ∇J, ∇²J, Z̃_J, Z̃_arg)
@@ -715,7 +716,8 @@ function get_nonlinobj_op(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where J
715716
else # multivariate syntax (see JuMP.@operator doc):
716717
function (∇²J_arg::AbstractMatrix{T}, Z̃_arg::Vararg{T, N}) where {N, T<:Real}
717718
update_objective!(J, ∇J, ∇²J, Z̃_J, Z̃_arg)
718-
return fill_lowertriangle!(∇²J_arg, ∇²J)
719+
#println(typeof(∇²J_arg))
720+
return fill_diffstructure!(∇²J_arg, ∇²J, ∇²J_structure)
719721
end
720722
end
721723
if !isnothing(hess)
@@ -807,13 +809,13 @@ function get_nonlincon_oracle(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JN
807809
end
808810
function ∇gi_func!(∇gi_arg, Z̃_arg)
809811
update_con!(gi, ∇gi, Z̃_∇gi, Z̃_arg)
810-
return diffmat2vec!(∇gi_arg, ∇gi, ∇gi_structure)
812+
return fill_diffstructure!(∇gi_arg, ∇gi, ∇gi_structure)
811813
end
812814
function ∇²gi_func!(∇²ℓ_arg, Z̃_arg, λ_arg)
813815
Z̃_∇gi .= Z̃_arg
814816
λi .= λ_arg
815817
hessian!(ℓ_gi, ∇²ℓ_gi, ∇²gi_prep, hess, Z̃_∇gi, Constant(λi), ∇²gi_cache...)
816-
return diffmat2vec!(∇²ℓ_arg, ∇²ℓ_gi, ∇²gi_structure)
818+
return fill_diffstructure!(∇²ℓ_arg, ∇²ℓ_gi, ∇²gi_structure)
817819
end
818820
gi_min = fill(-myInf, ngi)
819821
gi_max = zeros(JNT, ngi)
@@ -870,13 +872,13 @@ function get_nonlincon_oracle(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JN
870872
end
871873
function ∇geq_func!(∇geq_arg, Z̃_arg)
872874
update_con_eq!(geq, ∇geq, Z̃_∇geq, Z̃_arg)
873-
return diffmat2vec!(∇geq_arg, ∇geq, ∇geq_structure)
875+
return fill_diffstructure!(∇geq_arg, ∇geq, ∇geq_structure)
874876
end
875877
function ∇²geq_func!(∇²ℓ_arg, Z̃_arg, λ_arg)
876878
Z̃_∇geq .= Z̃_arg
877879
λeq .= λ_arg
878880
hessian!(ℓ_geq, ∇²ℓ_geq, ∇²geq_prep, hess, Z̃_∇geq, Constant(λeq), ∇²geq_cache...)
879-
return diffmat2vec!(∇²ℓ_arg, ∇²ℓ_geq, ∇²geq_structure)
881+
return fill_diffstructure!(∇²ℓ_arg, ∇²ℓ_geq, ∇²geq_structure)
880882
end
881883
geq_min = geq_max = zeros(JNT, neq)
882884
geq_oracle = MOI.VectorNonlinearOracle(;

src/estimator/mhe/construct.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1395,6 +1395,7 @@ function get_nonlinobj_op(
13951395
∇²J_prep = prepare_hessian(J!, hess, Z̃_J, J_cache...; strict)
13961396
estim.Nk[] = 0
13971397
∇²J = init_diffmat(JNT, hess, ∇²J_prep, nZ̃, nZ̃)
1398+
∇²J_structure = lowertriangle_indices(init_diffstructure(∇²J))
13981399
end
13991400
update_objective! = if !isnothing(hess)
14001401
function (J, ∇J, ∇²J, Z̃_∇J, Z̃_arg)
@@ -1439,7 +1440,7 @@ function get_nonlinobj_op(
14391440
end
14401441
function ∇²J_func!(∇²J_arg::AbstractMatrix{T}, Z̃_arg::Vararg{T, N}) where {N, T<:Real}
14411442
update_objective!(J, ∇J, ∇²J, Z̃_J, Z̃_arg)
1442-
return fill_lowertriangle!(∇²J_arg, ∇²J)
1443+
return fill_diffstructure!(∇²J_arg, ∇²J, ∇²J_structure)
14431444
end
14441445
if !isnothing(hess)
14451446
@operator(optim, J_op, nZ̃, J_func, ∇J_func!, ∇²J_func!)
@@ -1525,13 +1526,13 @@ function get_nonlincon_oracle(
15251526
end
15261527
function ∇gi_func!(∇gi_arg, Z̃_arg)
15271528
update_con!(gi, ∇gi, Z̃_∇gi, Z̃_arg)
1528-
return diffmat2vec!(∇gi_arg, ∇gi, ∇gi_structure)
1529+
return fill_diffstructure!(∇gi_arg, ∇gi, ∇gi_structure)
15291530
end
15301531
function ∇²gi_func!(∇²ℓ_arg, Z̃_arg, λ_arg)
15311532
Z̃_∇gi .= Z̃_arg
15321533
λi .= λ_arg
15331534
hessian!(ℓ_gi, ∇²ℓ_gi, ∇²gi_prep, hess, Z̃_∇gi, Constant(λi), ∇²gi_cache...)
1534-
return diffmat2vec!(∇²ℓ_arg, ∇²ℓ_gi, ∇²gi_structure)
1535+
return fill_diffstructure!(∇²ℓ_arg, ∇²ℓ_gi, ∇²gi_structure)
15351536
end
15361537
gi_min = fill(-myInf, ngi)
15371538
gi_max = zeros(JNT, ngi)

src/general.jl

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,15 +84,28 @@ function fill_lowertriangle!(A::AbstractMatrix, B::AbstractMatrix)
8484
return A
8585
end
8686

87-
"Store the diff. matrix `A` in the vector `v` with list of nonzero indices `i_vec`"
88-
function diffmat2vec!(v::AbstractVector, A::AbstractMatrix, i_vec::Vector{Tuple{Int, Int}})
89-
for i in eachindex(v)
87+
"Store the diff. matrix `A` in the vector `v` with list of nonzero indices `i_vec`."
88+
function fill_diffstructure!(
89+
v::AbstractVector, A::AbstractMatrix, i_vec::Vector{Tuple{Int, Int}}
90+
)
91+
for i in eachindex(i_vec)
9092
i_A, j_A = i_vec[i]
9193
v[i] = A[i_A, j_A]
9294
end
9395
return v
9496
end
9597

98+
"Store the diff. matrix `A` in the matrix `T` with list of nonzero indices `i_vec`."
99+
function fill_diffstructure!(
100+
T::AbstractMatrix, A::AbstractMatrix, i_vec::Vector{Tuple{Int, Int}}
101+
)
102+
for i in eachindex(i_vec)
103+
i_A, j_A = i_vec[i]
104+
T[i_A, j_A] = A[i_A, j_A]
105+
end
106+
return T
107+
end
108+
96109
backend_str(backend::AbstractADType) = string(nameof(typeof(backend)))
97110
backend_str(backend::Nothing) = "nothing"
98111
function backend_str(backend::AutoSparse)

0 commit comments

Comments
 (0)