Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 22 additions & 17 deletions src/blas/wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,7 @@ end
"- length(B)=$(length(B))\n" *
"- length(C)=$(length(C))\n"))
end
m, k, n = (-1, -1, -1)
for (i, (As, Bs, Cs)) in enumerate(zip(A, B, C))
m, k = size(As, transA == 'N' ? 1 : 2), size(As, transA == 'N' ? 2 : 1)
n, g = size(Bs, transB == 'N' ? 2 : 1), size(Bs, transB == 'N' ? 1 : 2)
Expand All @@ -653,7 +654,7 @@ end
lda = max(1, stride(A[1], 2))
ldb = max(1, stride(B[1], 2))
ldc = max(1, stride(C[1], 2))
m, k, n, lda, ldb, ldc
return m, k, n, lda, ldb, ldc
end

## (GE) general matrix-matrix multiplication batched
Expand All @@ -666,15 +667,17 @@ for (fname, elty) in
@eval begin
function gemm_batched!(
transA::Char, transB::Char,
alpha::($elty), A::ROCArray{$elty, 3},
B::ROCArray{$elty, 3}, beta::($elty), C::ROCArray{$elty, 3},
)
alpha::($elty), A::TA,
B::TB, beta::($elty), C::TC,
) where {TA<:Union{ROCArray{$elty, 3}, Vector{<:ROCMatrix{$elty}}},
TB<:Union{ROCArray{$elty, 3}, Vector{<:ROCMatrix{$elty}}},
TC<:Union{ROCArray{$elty, 3}, Vector{<:ROCMatrix{$elty}}}}
m, k, n, lda, ldb, ldc = _check_gemm_batched_dims(
transA, transB, A, B, C)

batch_count = size(C, 3)
a_broadcast = (size(A, 3) == 1) && (batch_count > 1)
b_broadcast = (size(B, 3) == 1) && (batch_count > 1)
batch_count = C isa ROCArray ? size(C, 3) : length(C)
a_broadcast = A isa ROCArray && (size(A, 3) == 1) && (batch_count > 1)
b_broadcast = B isa ROCArray && (size(B, 3) == 1) && (batch_count > 1)
Ab = a_broadcast ? device_batch(A, batch_count) : device_batch(A)
Bb = b_broadcast ? device_batch(B, batch_count) : device_batch(B)
Cb = device_batch(C)
Expand All @@ -684,18 +687,18 @@ for (fname, elty) in
handle, transA, transB,
m, n, k, Ref(alpha), Ab, lda, Bb, ldb, Ref(beta),
Cb, ldc, batch_count)
C
return C
end
function gemm_batched(
transA::Char, transB::Char, alpha::($elty), A::T, B::K,
) where {
T <: Union{AnyROCArray{$elty, 3}, Vector{ROCMatrix{$elty}}},
K <: Union{AnyROCArray{$elty, 3}, Vector{ROCMatrix{$elty}}},
T <: Union{AnyROCArray{$elty, 3}, Vector{<:ROCMatrix{$elty}}},
K <: Union{AnyROCArray{$elty, 3}, Vector{<:ROCMatrix{$elty}}},
}
is_ab_vec = Int(T <: Vector) + Int(K <: Vector)
(is_ab_vec != 0) && (is_ab_vec != 2) && throw(ArgumentError(
"If `A` is a `Vector{ROCMatrix}`, then `B` must be too."))
if T isa Vector
if T <: Vector
C = ROCMatrix{$elty}[similar(B[i], $elty, (
size(A[i], transA == 'N' ? 1 : 2),
size(B[i], transB == 'N' ? 2 : 1))) for i in 1:length(A)]
Expand All @@ -704,13 +707,15 @@ for (fname, elty) in
k = size(B, transB == 'N' ? 2 : 1)
C = similar(A, $elty, (m, k, max(size(A, 3), size(B, 3))))
end
gemm_batched!(transA, transB, alpha, A, B, zero($elty), C)
m, k, n, lda, ldb, ldc = _check_gemm_batched_dims(
transA, transB, A, B, C)
return gemm_batched!(transA, transB, alpha, A, B, zero($elty), C)
end
function gemm_batched(transA::Char, transB::Char, A::T, B::K) where {
T <: Union{AnyROCArray{$elty, 3}, Vector{ROCMatrix{$elty}}},
K <: Union{AnyROCArray{$elty, 3}, Vector{ROCMatrix{$elty}}},
T <: Union{AnyROCArray{$elty, 3}, Vector{<:ROCMatrix{$elty}}},
K <: Union{AnyROCArray{$elty, 3}, Vector{<:ROCMatrix{$elty}}},
}
gemm_batched(transA, transB, one($elty), A, B)
return gemm_batched(transA, transB, one($elty), A, B)
end
end
end
Expand Down Expand Up @@ -1029,7 +1034,7 @@ for (fname, elty) in
@eval begin
function trsm_batched!(
side::Char, uplo::Char, transa::Char, diag::Char, alpha::($elty),
A::Array{ROCMatrix{$elty},1}, B::Array{ROCMatrix{$elty},1},
A::Array{<:ROCMatrix{$elty},1}, B::Array{<:ROCMatrix{$elty},1},
)
if( length(A) != length(B) )
throw(DimensionMismatch(""))
Expand All @@ -1051,7 +1056,7 @@ for (fname, elty) in
end
function trsm_batched(
side::Char, uplo::Char, transa::Char, diag::Char, alpha::($elty),
A::Array{ROCMatrix{$elty},1}, B::Array{ROCMatrix{$elty},1},
A::Array{<:ROCMatrix{$elty},1}, B::Array{<:ROCMatrix{$elty},1},
)
trsm_batched!(side, uplo, transa, diag, alpha, A, copy(B) )
end
Expand Down
57 changes: 57 additions & 0 deletions test/rocarray/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,14 @@ end
@test testf((a, b) -> f(TR(a)) * b, A, x)
@test testf((a, b) -> lmul!(f(TR(a)), b), A, copy(x))
end
@testset "trmv" begin
A = rand(T, m, m)
dA = ROCArray(A)
x = rand(T, m)
dx = ROCArray(x)
dy = rocBLAS.trmv('U', 'N', 'N', dA, dx)
@test collect(dy) ≈ triu(A) * x
end

A, x = rand(T, m, m), rand(T, m)
@testset "Triangular ldiv" for TR in (
Expand All @@ -157,6 +165,14 @@ end
@test testf((a, b) -> f(TR(a)) \ b, A, x)
@test testf((a, b) -> ldiv!(f(TR(a)), b), A, copy(x))
end
@testset "trsv" begin
A = rand(T, m, m)
dA = ROCArray(A)
x = rand(T, m)
dx = ROCArray(x)
dy = rocBLAS.trsv('U', 'N', 'N', dA, dx)
@test collect(dy) ≈ triu(A) \ x
end

x = rand(T, m, m)
@testset "inv($TR)" for TR in (
Expand Down Expand Up @@ -372,6 +388,25 @@ end
(a, b) -> b / adjtype(uplotype(a)),
triu(rand(T, m, m)), rand(T, n, m))
end
@testset "trsm" begin
A = rand(T, m, m)
dA = ROCArray(A)
B = rand(T, m, m)
dB = ROCArray(B)
dC = rocBLAS.trsm('L', 'U', 'N', 'N', one(T), dA, dB)
@test collect(dC) ≈ triu(A) \ B
end
@testset "trsm_batched" begin
batch_count = 3
A = [rand(T, m, m) for ix in 1:batch_count]
dA = [ROCArray(A_) for A_ in A]
B = [rand(T, m, m) for ix in 1:batch_count]
dB = [ROCArray(B_) for B_ in B]
dC = rocBLAS.trsm_batched('L', 'U', 'N', 'N', one(T), dA, dB)
for ix in 1:batch_count
@test collect(dC[ix]) ≈ triu(A[ix]) \ B[ix]
end
end

@testset "triangular-dense mul ($T, $adjtype, $uplotype)" for adjtype in (
identity, adjoint, transpose,
Expand All @@ -389,6 +424,14 @@ end
(c, a, b) -> mul!(c, b, adjtype(uplotype(a))),
zeros(T, n, m), A, rand(T, n, m))
end
@testset "trmm" begin
A = rand(T, m, m)
dA = ROCArray(A)
B = rand(T, m, m)
dB = ROCArray(B)
dC = rocBLAS.trmm('L', 'U', 'N', 'N', one(T), dA, dB)
@test collect(dC) ≈ triu(A) * B
end

@testset "triangular-triangular mul" for (TRa, ta, TRb, tb) in (
(UpperTriangular, identity, LowerTriangular, identity),
Expand Down Expand Up @@ -452,6 +495,20 @@ end
(bt == 'T' ? transpose(B[:, :, i]) : B[:, :, i])
@test C[:, :, i] ≈ c
end
A = [rand(T, 4, 4) for ix in 1:batch_count]
B = [rand(T, 4, 4) for ix in 1:batch_count]
RA = [ROCArray(A_) for A_ in A]
RB = [ROCArray(B_) for B_ in B]

RC = rocBLAS.gemm_batched(at, bt, RA, RB)
@test length(RC) == batch_count
C = [Array(RC_) for RC_ in RC]
for i in 1:batch_count
c =
(at == 'T' ? transpose(A[i]) : A[i]) *
(bt == 'T' ? transpose(B[i]) : B[i])
@test C[i] ≈ c
end
end
end
end
Expand Down