Skip to content

Commit 6c0b3c8

Browse files
authored
Reuse the buffer for LAPACK routines on CPU (#940)
1 parent 13a96b5 commit 6c0b3c8

File tree

7 files changed

+191
-20
lines changed

7 files changed

+191
-20
lines changed

src/Krylov.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module Krylov
22

33
using LinearAlgebra, SparseArrays, Printf
4+
import LinearAlgebra.BLAS: BlasInt, @blasfunc, libblastrampoline
45

56
include("krylov_stats.jl")
67

src/block_gmres.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ kwargs_block_gmres = (:M, :N, :ldiv, :restart, :reorthogonalization, :atol, :rto
130130
ΔX, X, W, V, Z = workspace.ΔX, workspace.X, workspace.W, workspace.V, workspace.Z
131131
C, D, R, H, τ, stats = workspace.C, workspace.D, workspace.R, workspace.H, workspace.τ, workspace.stats
132132
Ψtmp = C
133+
buffer = workspace.buffer
133134
warm_start = workspace.warm_start
134135
RNorms = stats.residuals
135136
reset!(stats)
@@ -209,7 +210,7 @@ kwargs_block_gmres = (:M, :N, :ldiv, :restart, :reorthogonalization, :atol, :rto
209210

210211
# Initial Γ and V₁
211212
copyto!(V[1], R₀)
212-
householder!(V[1], Z[1], τ[1])
213+
householder!(V[1], Z[1], τ[1], buffer)
213214

214215
npass = npass + 1
215216
inner_iter = 0
@@ -249,27 +250,27 @@ kwargs_block_gmres = (:M, :N, :ldiv, :restart, :reorthogonalization, :atol, :rto
249250
end
250251

251252
# Vₖ₊₁ and Ψₖ₊₁.ₖ are stored in Q and C.
252-
householder!(Q, C, τ[inner_iter])
253+
householder!(Q, C, τ[inner_iter], buffer)
253254

254255
# Update the QR factorization of Hₖ₊₁.ₖ.
255256
# Apply previous Householder reflections Ωᵢ.
256257
for i = 1 : inner_iter-1
257258
D1 .= R[nr+i]
258259
D2 .= R[nr+i+1]
259-
kormqr!('L', trans, H[i], τ[i], D)
260+
kormqr!('L', trans, H[i], τ[i], D, buffer)
260261
R[nr+i] .= D1
261262
R[nr+i+1] .= D2
262263
end
263264

264265
# Compute and apply current Householder reflection Ωₖ.
265266
H[inner_iter][1:p,:] .= R[nr+inner_iter]
266267
H[inner_iter][p+1:2p,:] .= C
267-
householder!(H[inner_iter], R[nr+inner_iter], τ[inner_iter], compact=true)
268+
householder!(H[inner_iter], R[nr+inner_iter], τ[inner_iter], buffer, compact=true)
268269

269270
# Update Zₖ = (Qₖ)ᴴΓE₁ = (Λ₁, ..., Λₖ, Λbarₖ₊₁)
270271
D1 .= Z[inner_iter]
271272
D2 .= zero(FC)
272-
kormqr!('L', trans, H[inner_iter], τ[inner_iter], D)
273+
kormqr!('L', trans, H[inner_iter], τ[inner_iter], D, buffer)
273274
Z[inner_iter] .= D1
274275

275276
# Update residual norm estimate.

src/block_krylov_utils.jl

Lines changed: 103 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ end
181181
# Output :
182182
# Q an n-by-k orthonormal matrix: QᴴQ = Iₖ
183183
# R an k-by-k upper triangular matrix: QR = A
184-
function householder(A::AbstractMatrix{FC}; compact::Bool=false) where FC <: FloatOrComplex
184+
function householder(A::Matrix{FC}; compact::Bool=false) where FC <: FloatOrComplex
185185
n, k = size(A)
186186
Q = copy(A)
187187
τ = zeros(FC, k)
@@ -197,3 +197,105 @@ function householder!(Q::AbstractMatrix{FC}, R::AbstractMatrix{FC}, τ::Abstract
197197
!compact && korgqr!(Q, τ)
198198
return Q, R
199199
end
200+
201+
function householder!(Q::AbstractMatrix{FC}, R::AbstractMatrix{FC}, τ::AbstractVector{FC}, buffer::AbstractVector{FC}; compact::Bool=false) where FC <: FloatOrComplex
202+
n, k = size(Q)
203+
kfill!(R, zero(FC))
204+
kgeqrf!(Q, τ, buffer)
205+
copy_triangle(Q, R, k)
206+
!compact && korgqr!(Q, τ, buffer)
207+
return Q, R
208+
end
209+
210+
for (Xgeqrf, Xorgqr, Xormqr, T) in ((:sgeqrf_, :sorgqr_, :sormqr_, :Float32 ),
211+
(:dgeqrf_, :dorgqr_, :dormqr_, :Float64 ),
212+
(:cgeqrf_, :cungqr_, :cunmqr_, :ComplexF32),
213+
(:zgeqrf_, :zungqr_, :zunmqr_, :ComplexF64))
214+
@eval begin
215+
function $Xgeqrf(m, n, a, lda, tau, work, lwork, info)
216+
return ccall((@blasfunc($Xgeqrf), libblastrampoline), Cvoid,
217+
(Ref{BlasInt}, Ref{BlasInt}, Ptr{$T}, Ref{BlasInt},
218+
Ptr{$T}, Ptr{$T}, Ref{BlasInt}, Ref{BlasInt}),
219+
m, n, a, lda, tau, work, lwork, info)
220+
end
221+
222+
function kgeqrf_buffer!(A::Matrix{$T}, tau::Vector{$T})
223+
m, n = size(A)
224+
work = Ref{$T}(0)
225+
lda = max(1, stride(A, 2))
226+
$Xgeqrf(m, n, A, lda, tau, work, -1, 0)
227+
return work[] |> BlasInt
228+
end
229+
230+
function kgeqrf!(A::Matrix{$T}, tau::Vector{$T}, work::Vector{$T})
231+
m, n = size(A)
232+
lwork = length(work)
233+
lda = max(1, stride(A, 2))
234+
$Xgeqrf(m, n, A, lda, tau, work, lwork, 0)
235+
return nothing
236+
end
237+
238+
function $Xorgqr(m, n, k, a, lda, tau, work, lwork, info)
239+
return ccall((@blasfunc($Xorgqr), libblastrampoline), Cvoid,
240+
(Ref{BlasInt}, Ref{BlasInt}, Ref{BlasInt}, Ptr{$T},
241+
Ref{BlasInt}, Ptr{$T}, Ptr{$T}, Ref{BlasInt}, Ref{BlasInt}),
242+
m, n, k, a, lda, tau, work, lwork, info)
243+
end
244+
245+
function korgqr_buffer!(A::Matrix{$T}, tau::Vector{$T})
246+
m, n = size(A)
247+
k = length(tau)
248+
work = Ref{$T}(0)
249+
lda = max(1, stride(A, 2))
250+
$Xorgqr(m, n, k, A, lda, tau, work, -1, 0)
251+
return work[] |> BlasInt
252+
end
253+
254+
function korgqr!(A::Matrix{$T}, tau::Vector{$T}, work::Vector{$T})
255+
symb = @blasfunc($Xorgqr)
256+
m, n = size(A)
257+
k = length(tau)
258+
lwork = length(work)
259+
lda = max(1, stride(A, 2))
260+
$Xorgqr(m, n, k, A, lda, tau, work, lwork, 0)
261+
return nothing
262+
end
263+
264+
function $Xormqr(side, trans, m, n, k, a, lda, tau, c, ldc, work, lwork, info)
265+
return ccall((@blasfunc($Xormqr), libblastrampoline), Cvoid,
266+
(Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ref{BlasInt}, Ptr{$T},
267+
Ref{BlasInt}, Ptr{$T}, Ptr{$T}, Ref{BlasInt}, Ptr{$T}, Ref{BlasInt},
268+
Ref{BlasInt}, Clong, Clong),
269+
side, trans, m, n, k, a, lda, tau, c, ldc, work, lwork, info, 1, 1)
270+
end
271+
272+
function kormqr_buffer!(side::Char, trans::Char, A::Matrix{$T}, tau::Vector{$T}, C::Matrix{$T})
273+
m, n = size(A)
274+
k = length(tau)
275+
work = Ref{$T}(0)
276+
lda = max(1, stride(A, 2))
277+
ldc = max(1, stride(C, 2))
278+
$Xormqr(side, trans, m, n, k, A, lda, tau, C, ldc, work, -1, 0)
279+
return work[] |> BlasInt
280+
end
281+
282+
function kormqr!(side::Char, trans::Char, A::Matrix{$T}, tau::Vector{$T}, C::Matrix{$T}, work::Vector{$T})
283+
m, n = size(A)
284+
k = length(tau)
285+
lwork = length(work)
286+
lda = max(1, stride(A, 2))
287+
ldc = max(1, stride(C, 2))
288+
$Xormqr(side, trans, m, n, k, A, lda, tau, C, ldc, work, lwork, 0)
289+
return nothing
290+
end
291+
end
292+
end
293+
294+
kgeqrf!(A :: AbstractMatrix{T}, tau :: AbstractVector{T}) where T <: BLAS.BlasFloat = LAPACK.geqrf!(A, tau)
295+
kgeqrf!(A :: AbstractMatrix{T}, tau :: AbstractVector{T}, buffer:: AbstractVector{T}) where T <: BLAS.BlasFloat = LAPACK.geqrf!(A, tau)
296+
297+
korgqr!(A :: AbstractMatrix{T}, tau :: AbstractVector{T}) where T <: BLAS.BlasFloat = LAPACK.orgqr!(A, tau)
298+
korgqr!(A :: AbstractMatrix{T}, tau :: AbstractVector{T}, buffer:: AbstractVector{T}) where T <: BLAS.BlasFloat = LAPACK.orgqr!(A, tau)
299+
300+
kormqr!(side :: Char, trans :: Char, A :: AbstractMatrix{T}, tau :: AbstractVector{T}, C :: AbstractMatrix{T}) where T <: BLAS.BlasFloat = LAPACK.ormqr!(side, trans, A, tau, C)
301+
kormqr!(side :: Char, trans :: Char, A :: AbstractMatrix{T}, tau :: AbstractVector{T}, C :: AbstractMatrix{T}, buffer:: AbstractVector{T}) where T <: BLAS.BlasFloat = LAPACK.ormqr!(side, trans, A, tau, C)

src/block_krylov_workspaces.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ mutable struct BlockMinresWorkspace{T,FC,SV,SM} <: BlockKrylovWorkspace{T,FC,SV,
3030
Hₖ₋₁ :: SM
3131
τₖ₋₂ :: SV
3232
τₖ₋₁ :: SV
33+
buffer :: Vector{FC}
3334
warm_start :: Bool
3435
stats :: SimpleStats{T}
3536
end
@@ -55,7 +56,11 @@ function BlockMinresWorkspace(m::Integer, n::Integer, p::Integer, SV::Type, SM::
5556
SV = isconcretetype(SV) ? SV : typeof(τₖ₋₁)
5657
SM = isconcretetype(SM) ? SM : typeof(X)
5758
stats = SimpleStats(0, false, false, false, 0, T[], T[], T[], 0.0, "unknown")
58-
workspace = BlockMinresWorkspace{T,FC,SV,SM}(m, n, p, ΔX, X, P, Q, C, D, Φ, Vₖ₋₁, Vₖ, wₖ₋₂, wₖ₋₁, Hₖ₋₂, Hₖ₋₁, τₖ₋₂, τₖ₋₁, false, stats)
59+
size_buffer = C isa Matrix ? max(kgeqrf_buffer!(Vₖ, τₖ₋₁), kgeqrf_buffer!(Hₖ₋₁, τₖ₋₁),
60+
korgqr_buffer!(Vₖ, τₖ₋₁), korgqr_buffer!(Hₖ₋₁, τₖ₋₁),
61+
kormqr_buffer!('L', FC <: AbstractFloat ? 'T' : 'C', Hₖ₋₁, τₖ₋₁, D)) : 0
62+
buffer = SV(undef, size_buffer)
63+
workspace = BlockMinresWorkspace{T,FC,SV,SM}(m, n, p, ΔX, X, P, Q, C, D, Φ, Vₖ₋₁, Vₖ, wₖ₋₂, wₖ₋₁, Hₖ₋₂, Hₖ₋₁, τₖ₋₂, τₖ₋₁, buffer, false, stats)
5964
return workspace
6065
end
6166

@@ -93,6 +98,7 @@ mutable struct BlockGmresWorkspace{T,FC,SV,SM} <: BlockKrylovWorkspace{T,FC,SV,S
9398
R :: Vector{SM}
9499
H :: Vector{SM}
95100
τ :: Vector{SV}
101+
buffer :: Vector{FC}
96102
warm_start :: Bool
97103
stats :: SimpleStats{T}
98104
end
@@ -115,8 +121,12 @@ function BlockGmresWorkspace(m::Integer, n::Integer, p::Integer, SV::Type, SM::T
115121
τ = SV[SV(undef, p) for i = 1 : memory]
116122
SV = isconcretetype(SV) ? SV : typeof(τ)
117123
SM = isconcretetype(SM) ? SM : typeof(X)
124+
size_buffer = C isa Matrix ? max(kgeqrf_buffer!(V[1], τ[1]), kgeqrf_buffer!(H[1], τ[1]),
125+
korgqr_buffer!(V[1], τ[1]), korgqr_buffer!(H[1], τ[1]),
126+
kormqr_buffer!('L', FC <: AbstractFloat ? 'T' : 'C', H[1], τ[1], D)) : 0
127+
buffer = SV(undef, size_buffer)
118128
stats = SimpleStats(0, false, false, false, 0, T[], T[], T[], 0.0, "unknown")
119-
workspace = BlockGmresWorkspace{T,FC,SV,SM}(m, n, p, ΔX, X, W, P, Q, C, D, V, Z, R, H, τ, false, stats)
129+
workspace = BlockGmresWorkspace{T,FC,SV,SM}(m, n, p, ΔX, X, W, P, Q, C, D, V, Z, R, H, τ, buffer, false, stats)
120130
return workspace
121131
end
122132

src/block_minres.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ kwargs_block_minres = (:M, :ldiv, :atol, :rtol, :itmax, :timemax, :verbose, :his
119119
wₖ₋₂, wₖ₋₁ = workspace.wₖ₋₂, workspace.wₖ₋₁
120120
Hₖ₋₂, Hₖ₋₁ = workspace.Hₖ₋₂, workspace.Hₖ₋₁
121121
τₖ₋₂, τₖ₋₁ = workspace.τₖ₋₂, workspace.τₖ₋₁
122+
buffer = workspace.buffer
122123
warm_start = workspace.warm_start
123124
RNorms = stats.residuals
124125
reset!(stats)
@@ -173,7 +174,7 @@ kwargs_block_minres = (:M, :ldiv, :atol, :rtol, :itmax, :timemax, :verbose, :his
173174
# Initial Ψ₁ and V₁
174175
τ = τₖ₋₂
175176
copyto!(Vₖ, R₀)
176-
householder!(Vₖ, Φbarₖ, τ)
177+
householder!(Vₖ, Φbarₖ, τ, buffer)
177178

178179
while !(solved || tired || user_requested_exit || overtimed)
179180
# Update iteration index.
@@ -205,7 +206,7 @@ kwargs_block_minres = (:M, :ldiv, :atol, :rtol, :itmax, :timemax, :verbose, :his
205206
if iter 3
206207
D1 .= zero(T)
207208
D2 .= Ψₖ'
208-
kormqr!('L', trans, Hₖ₋₂, τₖ₋₂, D)
209+
kormqr!('L', trans, Hₖ₋₂, τₖ₋₂, D, buffer)
209210
Πₖ₋₂ .= D1
210211
Γbarₖ₋₁ .= D2
211212
end
@@ -215,27 +216,27 @@ kwargs_block_minres = (:M, :ldiv, :atol, :rtol, :itmax, :timemax, :verbose, :his
215216
(iter == 2) && (Γbarₖ₋₁ .= Ψₖ')
216217
D1 .= Γbarₖ₋₁
217218
D2 .= Ωₖ
218-
kormqr!('L', trans, Hₖ₋₁, τₖ₋₁, D)
219+
kormqr!('L', trans, Hₖ₋₁, τₖ₋₁, D, buffer)
219220
Γₖ₋₁ .= D1
220221
Λbarₖ .= D2
221222
end
222223

223224
# Vₖ₊₁ and Ψₖ₊₁ are stored in Q and Ψₖ₊₁.
224225
τ = τₖ₋₂
225-
householder!(Q, Ψₖ₊₁, τ)
226+
householder!(Q, Ψₖ₊₁, τ, buffer)
226227

227228
# Compute and apply current Householder reflection θₖ.
228229
Hₖ = Hₖ₋₂
229230
τₖ = τₖ₋₂
230231
(iter == 1) && (Λbarₖ .= Ωₖ)
231232
Hₖ[1:p,:] .= Λbarₖ
232233
Hₖ[p+1:2p,:] .= Ψₖ₊₁
233-
householder!(Hₖ, Λₖ, τₖ, compact=true)
234+
householder!(Hₖ, Λₖ, τₖ, buffer, compact=true)
234235

235236
# Update Zₖ = (Qₖ)ᴴΨ₁E₁ = (Φ₁, ..., Φₖ, Φbarₖ₊₁)
236237
D1 .= Φbarₖ
237238
D2 .= zero(FC)
238-
kormqr!('L', trans, Hₖ, τₖ, D)
239+
kormqr!('L', trans, Hₖ, τₖ, D, buffer)
239240
Φₖ .= D1
240241

241242
# Compute the directions Wₖ, the last columns of Wₖ = Vₖ(Rₖ)⁻¹ ⟷ (Rₖ)ᵀ(Wₖ)ᵀ = (Vₖ)ᵀ

src/krylov_utils.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -348,10 +348,6 @@ kfill!(x :: AbstractArray{T}, val :: T) where T <: FloatOrComplex = fill!(x, val
348348

349349
kref!(n, x, y, c, s) = reflect!(x, y, c, s)
350350

351-
kgeqrf!(A :: AbstractMatrix{T}, tau :: AbstractVector{T}) where T <: BLAS.BlasFloat = LAPACK.geqrf!(A, tau)
352-
korgqr!(A :: AbstractMatrix{T}, tau :: AbstractVector{T}) where T <: BLAS.BlasFloat = LAPACK.orgqr!(A, tau)
353-
kormqr!(side :: Char, trans :: Char, A :: AbstractMatrix{T}, tau :: AbstractVector{T}, C :: AbstractMatrix{T}) where T <: BLAS.BlasFloat = LAPACK.ormqr!(side, trans, A, tau, C)
354-
355351
macro kswap!(x, y)
356352
quote
357353
local tmp = $(esc(x))

test/test_allocations.jl

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@
55

66
A = FC.(get_div_grad(18, 18, 18)) # Dimension m x n
77
m,n = size(A)
8+
p = 5
89
k = div(n, 2)
910
Au = A[1:k,:] # Dimension k x n
1011
Ao = A[:,1:k] # Dimension m x k
1112
b = Ao * ones(FC, k) # Dimension m
1213
c = Au * ones(FC, n) # Dimension k
14+
B = A * Matrix{FC}(I, m, p) # Dimension m × p
1315
mem = 200
1416

1517
T = real(FC)
@@ -697,8 +699,8 @@
697699
storage_gpmr_bytes(mem, m, n) = nbits_FC * ((mem + 2) * (n + m) + mem * (2 * mem + 7)) + nbits_T * 4 * mem
698700

699701
expected_gpmr_bytes = storage_gpmr_bytes(mem, m, k)
700-
gpmr(Ao, Au, b, c, memory=mem, itmax=mem) # warmup
701-
actual_gpmr_bytes = @allocated gpmr(Ao, Au, b, c, memory=mem, itmax=mem)
702+
gpmr(Ao, Au, b, c; memory=mem, itmax=mem) # warmup
703+
actual_gpmr_bytes = @allocated gpmr(Ao, Au, b, c; memory=mem, itmax=mem)
702704
if VERSION < v"1.11.5" || !Sys.isapple()
703705
@test expected_gpmr_bytes actual_gpmr_bytes 1.02 * expected_gpmr_bytes
704706
end
@@ -708,6 +710,64 @@
708710
inplace_gpmr_bytes = @allocated gpmr!(workspace, Ao, Au, b, c)
709711
@test inplace_gpmr_bytes == 0
710712
end
713+
714+
@testset "BLOCK-GMRES" begin
715+
# BLOCK-GMRES needs:
716+
# - 2 (n*p)-matrices: X, W
717+
# - 1 (p*p)-matrix: C
718+
# - 1 (2p*p)-matrix: D
719+
# - mem p-vectors: τ
720+
# - mem (n*p)-matrices: V
721+
# - mem (p*p)-matrices: Z
722+
# - mem*(mem+1)/2 (p*p)-matrices: R
723+
# - mem (2p*p)-matrices: H
724+
# - lwork-vector: buffer
725+
function storage_block_gmres_bytes(mem, n, p)
726+
res = (2*n*p + p*p + 2p*p + mem*p + mem*n*p + mem*p*p + mem*(mem+1)*p*p/2 + mem*2p*p)
727+
return nbits_FC * res
728+
end
729+
730+
expected_block_gmres_bytes = storage_block_gmres_bytes(mem, n, p)
731+
block_gmres(A, B; memory=mem, itmax=mem) # warmup
732+
actual_block_gmres_bytes = @allocated block_gmres(A, B; memory=mem, itmax=mem)
733+
if VERSION < v"1.11.5" || !Sys.isapple()
734+
@test expected_block_gmres_bytes actual_block_gmres_bytes 1.08 * expected_block_gmres_bytes
735+
end
736+
737+
workspace = BlockGmresWorkspace(A, B; memory=mem)
738+
block_gmres!(workspace, A, B) # warmup
739+
inplace_block_gmres_bytes = @allocated block_gmres!(workspace, A, B)
740+
@test inplace_block_gmres_bytes == 0
741+
end
742+
743+
@testset "BLOCK-MINRES" begin
744+
# BLOCK-MINRES needs:
745+
# - 2 (n*p)-matrices: X, W
746+
# - 1 (p*p)-matrix: C
747+
# - 1 (2p*p)-matrix: D
748+
# - mem p-vectors: τ
749+
# - mem (n*p)-matrices: V
750+
# - mem (p*p)-matrices: Z
751+
# - mem*(mem+1)/2 (p*p)-matrices: R
752+
# - mem (2p*p)-matrices: H
753+
# - lwork-vector: buffer
754+
function storage_block_minres_bytes(mem, n, p)
755+
res = (2*n*p + p*p + 2p*p + mem*p + mem*n*p + mem*p*p + mem*(mem+1)*p*p/2 + mem*2p*p)
756+
return nbits_FC * res
757+
end
758+
759+
expected_block_minres_bytes = storage_block_minres_bytes(mem, n, p)
760+
block_minres(A, B) # warmup
761+
# actual_block_minres_bytes = @allocated block_minres(A, B)
762+
# if VERSION < v"1.11.5" || !Sys.isapple()
763+
# @test expected_block_minres_bytes ≤ actual_block_minres_bytes ≤ 1.08 * expected_block_minres_bytes
764+
# end
765+
766+
# Workspace = BlockMinresWorkspace(A, B)
767+
# block_minres!(Workspace, A, B) # warmup
768+
# inplace_block_minres_bytes = @allocated block_minres!(Workspace, A, B)
769+
# @test inplace_block_minres_bytes == 0
770+
end
711771
end
712772
end
713773
end

0 commit comments

Comments
 (0)