Skip to content

Commit 3220bcf

Browse files
authored
Merge pull request #831
A few more blas tests and fix
2 parents 4cb0eec + b616614 commit 3220bcf

File tree

3 files changed

+384
-16
lines changed

3 files changed

+384
-16
lines changed

src/blas/util.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,35 @@
1+
# convert matrix to band storage
2+
function band(A::AbstractMatrix,kl,ku)
3+
m, n = size(A)
4+
AB = zeros(eltype(A),kl+ku+1,n)
5+
for j = 1:n
6+
for i = max(1,j-ku):min(m,j+kl)
7+
AB[ku+1-j+i,j] = A[i,j]
8+
end
9+
end
10+
return AB
11+
end
12+
13+
# convert band storage to general matrix
14+
function unband(AB::AbstractMatrix,m,kl,ku)
15+
bm, n = size(AB)
16+
A = zeros(eltype(AB),m,n)
17+
for j = 1:n
18+
for i = max(1,j-ku):min(m,j+kl)
19+
A[i,j] = AB[ku+1-j+i,j]
20+
end
21+
end
22+
return A
23+
end
24+
25+
# zero out elements not on matrix bands
26+
function bandex(A::AbstractMatrix,kl,ku)
27+
m, n = size(A)
28+
AB = band(A,kl,ku)
29+
B = unband(AB,m,kl,ku)
30+
return B
31+
end
32+
133
const ROCBLASReal = Union{Float32, Float64}
234
const ROCBLASComplex = Union{ComplexF32, ComplexF64}
335
const ROCBLASFloat = Union{ROCBLASReal, ROCBLASComplex}

src/blas/wrappers.jl

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -368,10 +368,10 @@ for (fname, elty) in ((:rocblas_stbmv,:Float32),
368368
x
369369
end
370370
function tbmv(
371-
uplo::Char, trans::Char, diag::Char,
371+
uplo::Char, trans::Char, diag::Char, k::Integer,
372372
A::ROCMatrix{$elty}, x::ROCVector{$elty},
373373
)
374-
tbmv!(uplo, trans, diag, A, copy(x))
374+
tbmv!(uplo, trans, diag, k, A, copy(x))
375375
end
376376
end
377377
end
@@ -496,10 +496,10 @@ for (fname, elty) in ((:rocblas_dsyr,:Float64),
496496
end
497497

498498
### her
499-
for (fname, elty) in ((:rocblas_zher,:ComplexF64),
500-
(:rocblas_cher,:ComplexF32))
499+
for (fname, elty, relty) in ((:rocblas_zher,:ComplexF64,:Float64),
500+
(:rocblas_cher,:ComplexF32,:Float32))
501501
@eval begin
502-
function her!(uplo::Char, alpha::$elty, x::ROCVector{$elty}, A::ROCMatrix{$elty})
502+
function her!(uplo::Char, alpha::$relty, x::ROCVector{$elty}, A::ROCMatrix{$elty})
503503
m, n = size(A)
504504
m == n || throw(DimensionMismatch("Matrix A is $m by $n but must be square"))
505505
length(x) == n || throw(DimensionMismatch("Length of vector must be the same as the matrix dimensions"))
@@ -863,12 +863,12 @@ for (fname, elty) in ((:rocblas_zhemm,:ComplexF64),
863863
end
864864

865865
## herk
866-
for (fname, elty) in ((:rocblas_zherk,:ComplexF64),
867-
(:rocblas_cherk,:ComplexF32))
866+
for (fname, elty, relty) in ((:rocblas_zherk,:ComplexF64,:Float64),
867+
(:rocblas_cherk,:ComplexF32,:Float32))
868868
@eval begin
869869
function herk!(
870-
uplo::Char, trans::Char, alpha::($elty), A::ROCVecOrMat{$elty},
871-
beta::($elty), C::ROCMatrix{$elty},
870+
uplo::Char, trans::Char, alpha::($relty), A::ROCVecOrMat{$elty},
871+
beta::($relty), C::ROCMatrix{$elty},
872872
)
873873
mC, n = size(C)
874874
if mC != n throw(DimensionMismatch("C must be square")) end
@@ -881,12 +881,12 @@ for (fname, elty) in ((:rocblas_zherk,:ComplexF64),
881881
$(fname)(handle, uplo, trans, n, k, Ref(alpha), A, lda, Ref(beta), C, ldc)
882882
C
883883
end
884-
function herk(uplo::Char, trans::Char, alpha::($elty), A::ROCVecOrMat{$elty})
884+
function herk(uplo::Char, trans::Char, alpha::($relty), A::ROCVecOrMat{$elty})
885885
n = size(A, trans == 'N' ? 1 : 2)
886-
herk!(uplo, trans, alpha, A, zero($elty), similar(A, $elty, (n,n)))
886+
herk!(uplo, trans, alpha, A, zero($relty), similar(A, $elty, (n,n)))
887887
end
888888
herk(uplo::Char, trans::Char, A::ROCVecOrMat{$elty}) =
889-
herk(uplo, trans, one($elty), A)
889+
herk(uplo, trans, one($relty), A)
890890
end
891891
end
892892

@@ -1092,13 +1092,13 @@ for (fname, elty) in ((:rocblas_dgeam,:Float64),
10921092
)
10931093
m,n = size(B)
10941094
if ((transb == 'T' || transb == 'C'))
1095-
geam!( transa, transb, alpha, A, beta, B, similar(B, $elty, (n,m) ) )
1095+
return geam!( transa, transb, alpha, A, beta, B, similar(B, $elty, (n,m) ) )
10961096
end
10971097
if (transb == 'N')
1098-
geam!( transa, transb, alpha, A, beta, B, similar(B, $elty, (m,n) ) )
1098+
return geam!( transa, transb, alpha, A, beta, B, similar(B, $elty, (m,n) ) )
10991099
end
11001100
end
1101-
geam( uplo::Char, trans::Char, A::ROCMatrix{$elty}, B::ROCMatrix{$elty}) = geam( uplo, trans, one($elty), A, one($elty), B)
1101+
geam( transa::Char, transb::Char, A::ROCMatrix{$elty}, B::ROCMatrix{$elty}) = geam( transa, transb, one($elty), A, one($elty), B)
11021102
end
11031103
end
11041104

0 commit comments

Comments
 (0)