"""
    Model{S,T,CT,AT,JTB} <: NLPModels.AbstractNLPModel{T,Vector{T}}

`NLPModels` wrapper around an `LRO.Model{T,CT,AT}`.

# Fields
- `model`: the wrapped `LRO.Model`.
- `dim`: the `Dimensions{S}` built from `model` and the requested ranks.
- `meta`: the `NLPModels.NLPModelMeta`, built from `dim` and
  `LRO.cons_constant(model)`.
- `counters`: a fresh `NLPModels.Counters()`.
- `jtprod_buffer`: the value returned by `buffer_for_jtprod(model, dim)`; its
  concrete type is captured in the `JTB` type parameter.

The struct is mutable because `meta` and `jtprod_buffer` are reassigned when a
rank changes (see `set_rank!`).
"""
mutable struct Model{S,T,CT,AT,JTB} <: NLPModels.AbstractNLPModel{T,Vector{T}}
    model::LRO.Model{T,CT,AT}
    dim::Dimensions{S}
    meta::NLPModels.NLPModelMeta{T,Vector{T}}
    counters::NLPModels.Counters
    jtprod_buffer::JTB
    function Model{S}(model::LRO.Model{T,CT,AT}, ranks) where {S,T,CT,AT}
        dim = Dimensions{S}(model, ranks)
        # Build the buffer before calling `new` so that its concrete type can
        # be used as the last type parameter.
        jtprod_buffer = buffer_for_jtprod(model, dim)
        return new{S,T,CT,AT,typeof(jtprod_buffer)}(
            model,
            dim,
            meta(dim, LRO.cons_constant(model)),
            NLPModels.Counters(),
            jtprod_buffer,
        )
    end
end
@@ -42,6 +45,7 @@ function set_rank!(model::Model, i::LRO.MatrixIndex, r)
4245 set_rank!(model. dim, i, r)
4346 # `nvar` has changed so we need to reset `model.meta`
4447 model. meta = meta(model. dim, model. meta. lcon)
48+ model. jtprod_buffer = buffer_for_jtprod(model. model, model. dim)
4549 return
4650end
4751
@@ -71,7 +75,8 @@ function grad!(
7175 i:: LRO.MatrixIndex ,
7276)
7377 C = LRO. grad(model. model, i)
74- LinearAlgebra. mul!(G. factor, C, X. factor)
78+ buffer = _buffer(model. jtprod_buffer[i. value], C, X. factor)
79+ LRO. buffered_mul!(G. factor, C, X. factor, true , false , buffer)
7580 G. factor .*= 2
7681 return
7782end
@@ -152,16 +157,64 @@ function jtprod!(
152157 return JtV
153158end
154159
# Factorizations whose second type parameter is a vector type.
const _RankOne = LRO.AbstractFactorization{T,<:AbstractVector{T}} where {T}
# Factorizations whose second type parameter is a matrix type.
const _LowRank = LRO.AbstractFactorization{T,<:AbstractMatrix{T}} where {T}
162+
"""
    buffer_for_jtprod(model::LRO.Model{T}, dim::Dimensions, i::LRO.MatrixIndex)

Allocate the work buffer for matrix index `i`, based on the entries of
`model.A[i.value, :]` and on `model.C[i.value]`:

- if any of them is a `_LowRank`, return a zero matrix with `dim.ranks[i.value]`
  rows and as many columns as the largest `LRO.max_rank` among the `_LowRank`
  ones;
- otherwise, if any of them is a `_RankOne`, return a zero vector of length
  `dim.ranks[i.value]`;
- otherwise return `nothing`.
"""
function buffer_for_jtprod(
    model::LRO.Model{T},
    dim::Dimensions,
    i::LRO.MatrixIndex,
) where {T}
    idx = i.value
    Ai = view(model.A, idx, :)
    Ci = model.C[idx]
    if any(A -> A isa _LowRank, Ai) || Ci isa _LowRank
        # Entries that are not `_LowRank` contribute a width of zero.
        width = maximum(A -> A isa _LowRank ? LRO.max_rank(A) : 0, Ai; init = 0)
        if Ci isa _LowRank
            width = max(width, LRO.max_rank(Ci))
        end
        return zeros(T, dim.ranks[idx], width)
    end
    if any(A -> A isa _RankOne, Ai) || Ci isa _RankOne
        return zeros(T, dim.ranks[idx])
    end
    return nothing
end
187+
"""
    buffer_for_jtprod(model::LRO.Model, dim::Dimensions)

Return the collection of buffers obtained by calling
`buffer_for_jtprod(model, dim, i)` for every `i` in `LRO.matrix_indices(model)`.
"""
function buffer_for_jtprod(model::LRO.Model, dim::Dimensions)
    # `Ref` makes `model` and `dim` behave as scalars in the broadcast, so this
    # does not depend on those types defining `Base.broadcastable` themselves.
    return buffer_for_jtprod.(Ref(model), Ref(dim), LRO.matrix_indices(model))
end
191+
# `_buffer(storage, A, X)` selects the part of the preallocated `storage` to
# hand to `LRO.buffered_mul!` when multiplying `A` with `X`.

# Fallback for any `AbstractMatrix`: no buffer is provided.
function _buffer(::Any, ::AbstractMatrix, ::Any)
    return nothing
end

# Vector storage with a `_RankOne` matrix: the storage is used as is.
function _buffer(buffer::AbstractVector, ::_RankOne, ::AbstractMatrix)
    return buffer
end

# Matrix storage with a `_LowRank` matrix: keep the first `LRO.max_rank(A)`
# columns of the storage.
function _buffer(buffer::AbstractMatrix, A::_LowRank, ::AbstractMatrix)
    # FIXME Returning `buffer` itself when `size(buffer, 2) == LRO.max_rank(A)`
    # and a `view` otherwise gives a small `Union` return type, but the
    # compiler seems to forget about this `Union` and allocates later.
    # Always returning the `view` instead, `AllocCheck` now sees possible
    # allocations but `@allocated` sees none.
    cols = Base.OneTo(LRO.max_rank(A))
    return view(buffer, :, cols)
end
205+
"""
    add_jtprod!(model::Model, X, y, JtV, i::LRO.MatrixIndex, α = 2)

For every `j` in `eachindex(y)`, add the contribution of the matrix
`LRO.jac(model.model, j, i)` applied to `X.factor`, scaled by `α * y[j]`, to
`JtV.factor` in place.

`α` defaults to `2`, which reproduces the previously hard-coded `2y[j]` scaling.
The work buffer is selected from `model.jtprod_buffer[i.value]` by `_buffer`.
Returns `nothing`.
"""
function add_jtprod!(
    model::Model,
    X::LRO.Factorization,
    y::AbstractVector,
    JtV::LRO.Factorization,
    i::LRO.MatrixIndex,
    α = 2,
)
    for j in eachindex(y)
        A = LRO.jac(model.model, j, i)
        # The buffer depends on the type of `A`, so it is selected per `j`.
        buffer = _buffer(model.jtprod_buffer[i.value], A, X.factor)
        # NOTE(review): presumably the same semantics as the 5-argument
        # `LinearAlgebra.mul!(C, A, B, α, β)` it replaces, with `β = true` so
        # that the product is accumulated into `JtV.factor` — confirm in `LRO`.
        LRO.buffered_mul!(JtV.factor, A, X.factor, α * y[j], true, buffer)
    end
    return
end
167220
@@ -185,8 +238,10 @@ function NLPModels.jtprod!(
185238 X = Solution(x, model. dim)
186239 JtV = Solution(Jtv, model. dim)
187240 jtprod!(model, X, y, LRO. left_factor(JtV, LRO. ScalarIndex), LRO. ScalarIndex)
188- for i in LRO. matrix_indices(model. model)
189- jtprod!(model, X[i], y, JtV[i], i)
241+ for i:: LRO.MatrixIndex in LRO. matrix_indices(model. model)
242+ Xi = X[i]
243+ JtVi = JtV[i]
244+ jtprod!(model, Xi, y, JtVi, i)
190245 end
191246 return Jtv
192247end
@@ -248,14 +303,11 @@ function NLPModels.hprod!(
248303 obj_weight,
249304 )
250305 for i in LRO. matrix_indices(model. model)
251- Vi = V[i]. factor
252- C = LRO. grad(model. model, i)
253- Hvi = HV[i]. factor
254- LinearAlgebra. mul!(Hvi, C, Vi, 2 obj_weight, false )
255- for j in 1 : model. meta. ncon
256- A = LRO. jac(model. model, j, i)
257- LinearAlgebra. mul!(Hvi, A, Vi, - 2 y[j], true )
258- end
306+ Vi = V[i]
307+ Hvi = HV[i]
308+ grad!(model, Vi, Hvi, i)
309+ Hvi. factor .*= obj_weight
310+ add_jtprod!(model, Vi, y, Hvi, i, - 2 )
259311 end
260312 return Hv
261313end
0 commit comments