Skip to content

Commit dccabf0

Browse files
authored
Merge pull request #832
Even more BLAS tests and fixes
2 parents eb928bc + e34093d commit dccabf0

File tree

2 files changed

+79
-17
lines changed

2 files changed

+79
-17
lines changed

src/blas/wrappers.jl

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -639,6 +639,7 @@ end
639639
"- length(B)=$(length(B))\n" *
640640
"- length(C)=$(length(C))\n"))
641641
end
642+
m, k, n = (-1, -1, -1)
642643
for (i, (As, Bs, Cs)) in enumerate(zip(A, B, C))
643644
m, k = size(As, transA == 'N' ? 1 : 2), size(As, transA == 'N' ? 2 : 1)
644645
n, g = size(Bs, transB == 'N' ? 2 : 1), size(Bs, transB == 'N' ? 1 : 2)
@@ -653,7 +654,7 @@ end
653654
lda = max(1, stride(A[1], 2))
654655
ldb = max(1, stride(B[1], 2))
655656
ldc = max(1, stride(C[1], 2))
656-
m, k, n, lda, ldb, ldc
657+
return m, k, n, lda, ldb, ldc
657658
end
658659

659660
## (GE) general matrix-matrix multiplication batched
@@ -666,15 +667,17 @@ for (fname, elty) in
666667
@eval begin
667668
function gemm_batched!(
668669
transA::Char, transB::Char,
669-
alpha::($elty), A::ROCArray{$elty, 3},
670-
B::ROCArray{$elty, 3}, beta::($elty), C::ROCArray{$elty, 3},
671-
)
670+
alpha::($elty), A::TA,
671+
B::TB, beta::($elty), C::TC,
672+
) where {TA<:Union{ROCArray{$elty, 3}, Vector{<:ROCMatrix{$elty}}},
673+
TB<:Union{ROCArray{$elty, 3}, Vector{<:ROCMatrix{$elty}}},
674+
TC<:Union{ROCArray{$elty, 3}, Vector{<:ROCMatrix{$elty}}}}
672675
m, k, n, lda, ldb, ldc = _check_gemm_batched_dims(
673676
transA, transB, A, B, C)
674677

675-
batch_count = size(C, 3)
676-
a_broadcast = (size(A, 3) == 1) && (batch_count > 1)
677-
b_broadcast = (size(B, 3) == 1) && (batch_count > 1)
678+
batch_count = C isa ROCArray ? size(C, 3) : length(C)
679+
a_broadcast = A isa ROCArray && (size(A, 3) == 1) && (batch_count > 1)
680+
b_broadcast = B isa ROCArray && (size(B, 3) == 1) && (batch_count > 1)
678681
Ab = a_broadcast ? device_batch(A, batch_count) : device_batch(A)
679682
Bb = b_broadcast ? device_batch(B, batch_count) : device_batch(B)
680683
Cb = device_batch(C)
@@ -684,18 +687,18 @@ for (fname, elty) in
684687
handle, transA, transB,
685688
m, n, k, Ref(alpha), Ab, lda, Bb, ldb, Ref(beta),
686689
Cb, ldc, batch_count)
687-
C
690+
return C
688691
end
689692
function gemm_batched(
690693
transA::Char, transB::Char, alpha::($elty), A::T, B::K,
691694
) where {
692-
T <: Union{AnyROCArray{$elty, 3}, Vector{ROCMatrix{$elty}}},
693-
K <: Union{AnyROCArray{$elty, 3}, Vector{ROCMatrix{$elty}}},
695+
T <: Union{AnyROCArray{$elty, 3}, Vector{<:ROCMatrix{$elty}}},
696+
K <: Union{AnyROCArray{$elty, 3}, Vector{<:ROCMatrix{$elty}}},
694697
}
695698
is_ab_vec = Int(T <: Vector) + Int(K <: Vector)
696699
(is_ab_vec != 0) && (is_ab_vec != 2) && throw(ArgumentError(
697700
"If `A` is a `Vector{ROCMatrix}`, then `B` must be too."))
698-
if T isa Vector
701+
if T <: Vector
699702
C = ROCMatrix{$elty}[similar(B[i], $elty, (
700703
size(A[i], transA == 'N' ? 1 : 2),
701704
size(B[i], transB == 'N' ? 2 : 1))) for i in 1:length(A)]
@@ -704,13 +707,15 @@ for (fname, elty) in
704707
k = size(B, transB == 'N' ? 2 : 1)
705708
C = similar(A, $elty, (m, k, max(size(A, 3), size(B, 3))))
706709
end
707-
gemm_batched!(transA, transB, alpha, A, B, zero($elty), C)
710+
m, k, n, lda, ldb, ldc = _check_gemm_batched_dims(
711+
transA, transB, A, B, C)
712+
return gemm_batched!(transA, transB, alpha, A, B, zero($elty), C)
708713
end
709714
function gemm_batched(transA::Char, transB::Char, A::T, B::K) where {
710-
T <: Union{AnyROCArray{$elty, 3}, Vector{ROCMatrix{$elty}}},
711-
K <: Union{AnyROCArray{$elty, 3}, Vector{ROCMatrix{$elty}}},
715+
T <: Union{AnyROCArray{$elty, 3}, Vector{<:ROCMatrix{$elty}}},
716+
K <: Union{AnyROCArray{$elty, 3}, Vector{<:ROCMatrix{$elty}}},
712717
}
713-
gemm_batched(transA, transB, one($elty), A, B)
718+
return gemm_batched(transA, transB, one($elty), A, B)
714719
end
715720
end
716721
end
@@ -1029,7 +1034,7 @@ for (fname, elty) in
10291034
@eval begin
10301035
function trsm_batched!(
10311036
side::Char, uplo::Char, transa::Char, diag::Char, alpha::($elty),
1032-
A::Array{ROCMatrix{$elty},1}, B::Array{ROCMatrix{$elty},1},
1037+
A::Array{<:ROCMatrix{$elty},1}, B::Array{<:ROCMatrix{$elty},1},
10331038
)
10341039
if( length(A) != length(B) )
10351040
throw(DimensionMismatch(""))
@@ -1051,7 +1056,7 @@ for (fname, elty) in
10511056
end
10521057
function trsm_batched(
10531058
side::Char, uplo::Char, transa::Char, diag::Char, alpha::($elty),
1054-
A::Array{ROCMatrix{$elty},1}, B::Array{ROCMatrix{$elty},1},
1059+
A::Array{<:ROCMatrix{$elty},1}, B::Array{<:ROCMatrix{$elty},1},
10551060
)
10561061
trsm_batched!(side, uplo, transa, diag, alpha, A, copy(B) )
10571062
end

test/rocarray/blas.jl

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,14 @@ end
147147
@test testf((a, b) -> f(TR(a)) * b, A, x)
148148
@test testf((a, b) -> lmul!(f(TR(a)), b), A, copy(x))
149149
end
150+
@testset "trmv" begin
151+
A = rand(T, m, m)
152+
dA = ROCArray(A)
153+
x = rand(T, m)
154+
dx = ROCArray(x)
155+
dy = rocBLAS.trmv('U', 'N', 'N', dA, dx)
156+
@test collect(dy) triu(A) * x
157+
end
150158

151159
A, x = rand(T, m, m), rand(T, m)
152160
@testset "Triangular ldiv" for TR in (
@@ -157,6 +165,14 @@ end
157165
@test testf((a, b) -> f(TR(a)) \ b, A, x)
158166
@test testf((a, b) -> ldiv!(f(TR(a)), b), A, copy(x))
159167
end
168+
@testset "trsv" begin
169+
A = rand(T, m, m)
170+
dA = ROCArray(A)
171+
x = rand(T, m)
172+
dx = ROCArray(x)
173+
dy = rocBLAS.trsv('U', 'N', 'N', dA, dx)
174+
@test collect(dy) triu(A) \ x
175+
end
160176

161177
x = rand(T, m, m)
162178
@testset "inv($TR)" for TR in (
@@ -372,6 +388,25 @@ end
372388
(a, b) -> b / adjtype(uplotype(a)),
373389
triu(rand(T, m, m)), rand(T, n, m))
374390
end
391+
@testset "trsm" begin
392+
A = rand(T, m, m)
393+
dA = ROCArray(A)
394+
B = rand(T, m, m)
395+
dB = ROCArray(B)
396+
dC = rocBLAS.trsm('L', 'U', 'N', 'N', one(T), dA, dB)
397+
@test collect(dC) triu(A) \ B
398+
end
399+
@testset "trsm_batched" begin
400+
batch_count = 3
401+
A = [rand(T, m, m) for ix in 1:batch_count]
402+
dA = [ROCArray(A_) for A_ in A]
403+
B = [rand(T, m, m) for ix in 1:batch_count]
404+
dB = [ROCArray(B_) for B_ in B]
405+
dC = rocBLAS.trsm_batched('L', 'U', 'N', 'N', one(T), dA, dB)
406+
for ix in 1:batch_count
407+
@test collect(dC[ix]) triu(A[ix]) \ B[ix]
408+
end
409+
end
375410

376411
@testset "triangular-dense mul ($T, $adjtype, $uplotype)" for adjtype in (
377412
identity, adjoint, transpose,
@@ -389,6 +424,14 @@ end
389424
(c, a, b) -> mul!(c, b, adjtype(uplotype(a))),
390425
zeros(T, n, m), A, rand(T, n, m))
391426
end
427+
@testset "trmm" begin
428+
A = rand(T, m, m)
429+
dA = ROCArray(A)
430+
B = rand(T, m, m)
431+
dB = ROCArray(B)
432+
dC = rocBLAS.trmm('L', 'U', 'N', 'N', one(T), dA, dB)
433+
@test collect(dC) triu(A) * B
434+
end
392435

393436
@testset "triangular-triangular mul" for (TRa, ta, TRb, tb) in (
394437
(UpperTriangular, identity, LowerTriangular, identity),
@@ -452,6 +495,20 @@ end
452495
(bt == 'T' ? transpose(B[:, :, i]) : B[:, :, i])
453496
@test C[:, :, i] c
454497
end
498+
A = [rand(T, 4, 4) for ix in 1:batch_count]
499+
B = [rand(T, 4, 4) for ix in 1:batch_count]
500+
RA = [ROCArray(A_) for A_ in A]
501+
RB = [ROCArray(B_) for B_ in B]
502+
503+
RC = rocBLAS.gemm_batched(at, bt, RA, RB)
504+
@test length(RC) == batch_count
505+
C = [Array(RC_) for RC_ in RC]
506+
for i in 1:batch_count
507+
c =
508+
(at == 'T' ? transpose(A[i]) : A[i]) *
509+
(bt == 'T' ? transpose(B[i]) : B[i])
510+
@test C[i] c
511+
end
455512
end
456513
end
457514
end

0 commit comments

Comments
 (0)