Skip to content

Commit b616614

Browse files
author
Katharine Hyatt
committed
Fixes and tests for banded methods
1 parent c2f4472 commit b616614

File tree

3 files changed

+126
-3
lines changed

3 files changed

+126
-3
lines changed

src/blas/util.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,35 @@
1+
# convert matrix to band storage
2+
function band(A::AbstractMatrix,kl,ku)
3+
m, n = size(A)
4+
AB = zeros(eltype(A),kl+ku+1,n)
5+
for j = 1:n
6+
for i = max(1,j-ku):min(m,j+kl)
7+
AB[ku+1-j+i,j] = A[i,j]
8+
end
9+
end
10+
return AB
11+
end
12+
13+
# convert band storage to general matrix
14+
function unband(AB::AbstractMatrix,m,kl,ku)
15+
bm, n = size(AB)
16+
A = zeros(eltype(AB),m,n)
17+
for j = 1:n
18+
for i = max(1,j-ku):min(m,j+kl)
19+
A[i,j] = AB[ku+1-j+i,j]
20+
end
21+
end
22+
return A
23+
end
24+
25+
# zero out elements not on matrix bands
26+
function bandex(A::AbstractMatrix,kl,ku)
27+
m, n = size(A)
28+
AB = band(A,kl,ku)
29+
B = unband(AB,m,kl,ku)
30+
return B
31+
end
32+
133
const ROCBLASReal = Union{Float32, Float64}
234
const ROCBLASComplex = Union{ComplexF32, ComplexF64}
335
const ROCBLASFloat = Union{ROCBLASReal, ROCBLASComplex}

src/blas/wrappers.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -368,10 +368,10 @@ for (fname, elty) in ((:rocblas_stbmv,:Float32),
368368
x
369369
end
370370
function tbmv(
371-
uplo::Char, trans::Char, diag::Char,
371+
uplo::Char, trans::Char, diag::Char, k::Integer,
372372
A::ROCMatrix{$elty}, x::ROCVector{$elty},
373373
)
374-
tbmv!(uplo, trans, diag, A, copy(x))
374+
tbmv!(uplo, trans, diag, k, A, copy(x))
375375
end
376376
end
377377
end

test/rocarray/blas.jl

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
using AMDGPU.rocBLAS
44
using AMDGPU.HIP
5-
import .rocBLAS: rocblas_int
5+
import .rocBLAS: rocblas_int, bandex, band
66

77
m = 20
88
n = 35
@@ -231,6 +231,97 @@ end
231231
@test B hB
232232
end
233233
end
234+
@testset "Banded methods" begin
235+
ku = 3
236+
kl = 4
237+
A = rand(T, m, n) # generate banded matrix
238+
A = bandex(A, kl, ku)
239+
Ab = band(A, kl, ku) # get packed format
240+
d_Ab = ROCArray(Ab)
241+
x = rand(T, n)
242+
d_x = ROCArray(x)
243+
alpha = rand(T)
244+
beta = rand(T)
245+
y = rand(T, m)
246+
d_y = ROCArray(y)
247+
@testset "gbmv!" begin
248+
@testset for (op, da1, da2, ha1, ha2) in (('N',d_x,d_y,x,y), ('T',d_y,d_x,y,x), ('C',d_y,d_x,y,x))
249+
# test y = alpha*A*x + beta*y
250+
rocBLAS.gbmv!(op,m,kl,ku,alpha,d_Ab,da1,beta,da2)
251+
BLAS.gbmv!(op,m,kl,ku,alpha,Ab,ha1,beta,ha2)
252+
@test ha2 Array(da2)
253+
d_y = ROCArray(y)
254+
end
255+
end
256+
@testset "gbmv" begin
257+
d_x = ROCArray(x)
258+
d_y = rocBLAS.gbmv('N',m,kl,ku,alpha,d_Ab,d_x)
259+
y = BLAS.gbmv('N',m,kl,ku,alpha,Ab,x)
260+
@test y Array(d_y)
261+
# test alpha=1 version without y
262+
d_y = rocBLAS.gbmv('N',m,kl,ku,d_Ab,d_x)
263+
y = BLAS.gbmv('N',m,kl,ku,Ab,x)
264+
@test y Array(d_y)
265+
end
266+
A = rand(T, m, m)
267+
A = A + A'
268+
nbands = 3
269+
@test m >= 1 + nbands
270+
A = bandex(A,nbands,nbands)
271+
# convert to 'upper' banded storage format
272+
AB = band(A, 0, nbands)
273+
# construct x
274+
x = rand(T, m)
275+
d_AB = ROCArray(AB)
276+
d_x = ROCArray(x)
277+
y = rand(T, m)
278+
if T <: Real
279+
@testset "sbmv!" begin
280+
d_y = ROCArray(y)
281+
rocBLAS.sbmv!('U',nbands,alpha,d_AB,d_x,beta,d_y)
282+
@test alpha*(A*x) + beta*y Array(d_y)
283+
end
284+
@testset "sbmv" begin
285+
d_y = rocBLAS.sbmv('U',nbands,d_AB,d_x)
286+
@test A*x Array(d_y)
287+
end
288+
else
289+
@testset "hbmv!" begin
290+
d_y = ROCArray(y)
291+
rocBLAS.hbmv!('U',nbands,alpha,d_AB,d_x,beta,d_y)
292+
@test alpha*(A*x) + beta*y Array(d_y)
293+
end
294+
@testset "hbmv" begin
295+
d_y = rocBLAS.hbmv('U',nbands,d_AB,d_x)
296+
@test A*x Array(d_y)
297+
end
298+
end
299+
# generate banded triangular matrix
300+
A = rand(T, m, m)
301+
nbands = 3 # restrict to 3 bands
302+
@test m >= 1 + nbands
303+
A = bandex(A,0,nbands)
304+
AB = band(A,0,nbands)
305+
d_AB = ROCArray(AB)
306+
@testset "tbmv!" begin
307+
d_y = ROCArray(y)
308+
rocBLAS.tbmv!('U','N','N',nbands,d_AB,d_y)
309+
@test A*y Array(d_y)
310+
end
311+
@testset "tbmv" begin
312+
d_y = rocBLAS.tbmv('U','N','N',nbands,d_AB,d_x)
313+
@test A*x Array(d_y)
314+
end
315+
@testset "tbsv!" begin
316+
d_y = copy(d_x)
317+
rocBLAS.tbsv!('U','N','N',nbands,d_AB,d_y)
318+
@test A\x Array(d_y)
319+
end
320+
@testset "tbsv" begin
321+
d_y = rocBLAS.tbsv('U','N','N',nbands,d_AB,d_x)
322+
@test A\x Array(d_y)
323+
end
324+
end
234325
end
235326
end
236327

0 commit comments

Comments
 (0)