Skip to content

Commit 3cb1c19

Browse files
committed
type stable usage of hessian!
1 parent 8803495 commit 3cb1c19

File tree

3 files changed

+41
-27
lines changed

3 files changed

+41
-27
lines changed

src/ARCTR.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module ARCTR
33
# stdlib
44
using LinearAlgebra, SparseArrays
55
# JSO
6-
using Krylov, NLPModels, SparseMatricesCOO, SolverCore
6+
using Krylov, LinearOperators, NLPModels, SparseMatricesCOO, SolverCore
77
# Stopping
88
using Stopping, StoppingInterface
99

src/main.jl

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ function preprocess(
8686
α,
8787
)
8888
max_hprod = stp.meta.max_cntrs[:neval_hprod]
89-
Hx = stp.current_state.Hx
89+
Hx = workspace.Hstruct.H
9090
PData = preprocess(PData, Hx, ∇f, norm_∇f, neval_hprod(stp.pb), max_hprod, α)
9191
return PData
9292
end
@@ -101,7 +101,7 @@ function compute_direction(
101101
solve_model,
102102
)
103103
max_hprod = stp.meta.max_cntrs[:neval_hprod]
104-
Hx = stp.current_state.Hx
104+
Hx = workspace.Hstruct.H
105105
return solve_model(PData, Hx, ∇f, norm_∇f, neval_hprod(stp.pb), max_hprod, α)
106106
end
107107

@@ -162,11 +162,11 @@ end
162162
163163
Update `Δq = -(∇f + 0.5 * (Hx * d)) ⋅ d` in-place.
164164
"""
165-
function compute_Δq(workspace, Hx, d, ∇f) # -(∇f + 0.5 * (nlp_at_x.Hx * d)) ⋅ d
166-
mul!(workspace.Hd, Hx, d)
167-
workspace.Hd .*= 0.5
168-
workspace.Hd .+= ∇f
169-
return -dot(workspace.Hd, d)
165+
function compute_Δq(workspace, Hx, d, ∇f)
166+
mul!(workspace.Hd, Hx, d)
167+
workspace.Hd .*= 0.5
168+
workspace.Hd .+= ∇f
169+
return -dot(workspace.Hd, d)
170170
end
171171

172172
function TRARC(
@@ -180,8 +180,8 @@ function TRARC(
180180
kwargs...,
181181
) where {Pb,M,SRC,MStp,LoS,S,T,Hess,ParamData}
182182
nlp, nlp_at_x = nlp_stop.pb, nlp_stop.current_state
183-
xt, xtnext, d, ∇f, ∇fnext =
184-
workspace.xt, workspace.xtnext, workspace.d, workspace.∇f, workspace.∇fnext
183+
xt, xtnext, ∇f, ∇fnext = workspace.xt, workspace.xtnext, workspace.∇f, workspace.∇fnext
184+
d, Hx = workspace.d, workspace.Hstruct.H
185185

186186
α = TR.α₀
187187
max_unsuccinarow = TR.max_unsuccinarow
@@ -197,7 +197,8 @@ function TRARC(
197197

198198
Stopping._smart_update!(nlp_at_x, x = xt, fx = ft, gx = ∇f)
199199
OK = start!(nlp_stop)
200-
!OK && Stopping._smart_update!(nlp_at_x, Hx = hessian!(workspace, nlp, xt))
200+
Hx = hessian!(workspace, nlp, xt)
201+
!OK && Stopping._smart_update!(nlp_at_x, Hx = Hx)
201202

202203
iter = 0 # counter different than stop count
203204
succ, unsucc, verysucc, unsuccinarow = 0, 0, 0, 0
@@ -222,7 +223,7 @@ function TRARC(
222223
compute_direction(nlp_stop, PData, workspace, ∇f, norm_∇f, α, solve_model)
223224

224225
slope = ∇f d
225-
Δq = compute_Δq(workspace, nlp_at_x.Hx, d, ∇f) # -(∇f + 0.5 * (nlp_at_x.Hx * d)) ⋅ d
226+
Δq = compute_Δq(workspace, Hx, d, ∇f)
226227

227228
xtnext .= xt .+ d
228229
ftnext = obj(nlp, xtnext, workspace)
@@ -238,7 +239,8 @@ function TRARC(
238239
unsucc += 1
239240
unsuccinarow += 1
240241
η = (1 - acceptance_threshold) / 10 # ∈ (acceptance_threshold, 1)
241-
qksk = ft + slope + ((nlp_at_x.Hx * d) d) / 2
242+
mul!(workspace.Hd, Hx, d)
243+
qksk = ft + slope + (workspace.Hd d) / 2
242244
αbad = (1 - η) * slope / ((1 - η) * (ft + slope) + η * qksk - ftnext)
243245
α = min(decrease(PData, α, TR), max(TR.large_decrease_factor, αbad) * α)
244246
elseif r < acceptance_threshold # unsucessful
@@ -276,7 +278,8 @@ function TRARC(
276278
nlp_stop.meta.nb_of_stop = iter
277279
Stopping._smart_update!(nlp_at_x, x = xt, fx = ft, gx = ∇f)
278280
OK = stop!(nlp_stop)
279-
success && Stopping._smart_update!(nlp_at_x, Hx = hessian!(workspace, nlp, xt))
281+
Hx = hessian!(workspace, nlp, xt)
282+
success && Stopping._smart_update!(nlp_at_x, Hx = Hx)
280283
end # while !OK
281284

282285
return nlp_stop

src/utils/hessian_rep.jl

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,28 @@
22
HessDense(::AbstractNLPModel{T,S}, n)
33
Return a structure used for the evaluation of dense Hessian matrix.
44
"""
5-
struct HessDense
5+
struct HessDense{T}
6+
H::Matrix{T}
67
function HessDense(::AbstractNLPModel{T,S}, n) where {T,S}
7-
return new()
8+
H = Matrix{Float64}(undef, n, n)
9+
return new{T}(H)
810
end
911
end
1012

1113
"""
1214
HessSparse(::AbstractNLPModel{T,S}, n)
1315
Return a structure used for the evaluation of sparse Hessian matrix.
1416
"""
15-
struct HessSparse{S,Vi}
17+
struct HessSparse{T,S,Vi,It<:Integer}
1618
rows::Vi
1719
cols::Vi
1820
vals::S
21+
H::Symmetric{T, SparseMatrixCSC{T, It}}
1922
function HessSparse(nlp::AbstractNLPModel{T,S}, n) where {T,S}
2023
rows, cols = hess_structure(nlp)
2124
vals = S(undef, nlp.meta.nnzh)
22-
return new{S,typeof(rows)}(rows, cols, vals)
25+
H = Symmetric(spzeros(T, n, n), :L)
26+
return new{T,S,typeof(rows),eltype(rows)}(rows, cols, vals, H)
2327
end
2428
end
2529

@@ -42,22 +46,26 @@ end
4246
HessOp(::AbstractNLPModel{T,S}, n)
4347
Return a structure used for the evaluation of the Hessian matrix as an operator.
4448
"""
45-
struct HessOp{S}
49+
mutable struct HessOp{S}
4650
Hv::S
51+
H
4752
function HessOp(::AbstractNLPModel{T,S}, n) where {T,S}
48-
return new{S}(S(undef, n))
53+
H = LinearOperator{T}(n, n, true, true, v -> v, v -> v, v -> v)
54+
return new{S}(S(undef, n), H)
4955
end
5056
end
5157

5258
"""
5359
HessGaussNewtonOp(::AbstractNLSModel{T,S}, n)
5460
Return a structure used for the evaluation of the Hessian matrix as an operator.
5561
"""
56-
struct HessGaussNewtonOp{S}
62+
mutable struct HessGaussNewtonOp{S}
5763
Jv::S
5864
Jtv::S
65+
H
5966
function HessGaussNewtonOp(nls::AbstractNLSModel{T,S}, n) where {T,S}
60-
return new{S}(S(undef, nls.nls_meta.nequ), S(undef, n))
67+
Jx = LinearOperator{T}(nls.nls_meta.nequ, n, false, false, v -> v, v -> v, v -> v)
68+
return new{S}(S(undef, nls.nls_meta.nequ), S(undef, n), Jx' * Jx)
6169
end
6270
end
6371

@@ -81,23 +89,26 @@ Return the Hessian matrix of `nlp` at `x` in-place with memory update of `worksp
8189
function hessian! end
8290

8391
function hessian!(workspace::HessDense, nlp, x)
84-
H = Matrix(hess(nlp, x))
85-
return H
92+
workspace.H .= Matrix(hess(nlp, x))
93+
return workspace.H
8694
end
8795

8896
function hessian!(workspace::HessOp, nlp, x)
89-
return hess_op!(nlp, x, workspace.Hv)
97+
workspace.H = hess_op!(nlp, x, workspace.Hv)
98+
return workspace.H
9099
end
91100

92101
function hessian!(workspace::HessGaussNewtonOp, nlp, x)
93102
Jx = jac_op_residual!(nlp, x, workspace.Jv, workspace.Jtv)
94-
return Jx' * Jx
103+
workspace.H = Jx' * Jx
104+
return workspace.H
95105
end
96106

97107
function hessian!(workspace::HessSparse, nlp, x)
98108
hess_coord!(nlp, x, workspace.vals)
99109
n = nlp.meta.nvar
100-
return Symmetric(sparse(workspace.rows, workspace.cols, workspace.vals, n, n), :L)
110+
workspace.H .= Symmetric(sparse(workspace.rows, workspace.cols, workspace.vals, n, n), :L)
111+
return workspace.H
101112
end
102113

103114
function hessian!(workspace::HessSparseCOO, nlp, x)

0 commit comments

Comments
 (0)