diff --git a/src/blas/wrappers.jl b/src/blas/wrappers.jl index 8258816c3..54ed17f16 100644 --- a/src/blas/wrappers.jl +++ b/src/blas/wrappers.jl @@ -639,6 +639,7 @@ end "- length(B)=$(length(B))\n" * "- length(C)=$(length(C))\n")) end + m, k, n = (-1, -1, -1) for (i, (As, Bs, Cs)) in enumerate(zip(A, B, C)) m, k = size(As, transA == 'N' ? 1 : 2), size(As, transA == 'N' ? 2 : 1) n, g = size(Bs, transB == 'N' ? 2 : 1), size(Bs, transB == 'N' ? 1 : 2) @@ -653,7 +654,7 @@ end lda = max(1, stride(A[1], 2)) ldb = max(1, stride(B[1], 2)) ldc = max(1, stride(C[1], 2)) - m, k, n, lda, ldb, ldc + return m, k, n, lda, ldb, ldc end ## (GE) general matrix-matrix multiplication batched @@ -666,15 +667,17 @@ for (fname, elty) in @eval begin function gemm_batched!( transA::Char, transB::Char, - alpha::($elty), A::ROCArray{$elty, 3}, - B::ROCArray{$elty, 3}, beta::($elty), C::ROCArray{$elty, 3}, - ) + alpha::($elty), A::TA, + B::TB, beta::($elty), C::TC, + ) where {TA<:Union{ROCArray{$elty, 3}, Vector{<:ROCMatrix{$elty}}}, + TB<:Union{ROCArray{$elty, 3}, Vector{<:ROCMatrix{$elty}}}, + TC<:Union{ROCArray{$elty, 3}, Vector{<:ROCMatrix{$elty}}}} m, k, n, lda, ldb, ldc = _check_gemm_batched_dims( transA, transB, A, B, C) - batch_count = size(C, 3) - a_broadcast = (size(A, 3) == 1) && (batch_count > 1) - b_broadcast = (size(B, 3) == 1) && (batch_count > 1) + batch_count = C isa ROCArray ? size(C, 3) : length(C) + a_broadcast = A isa ROCArray && (size(A, 3) == 1) && (batch_count > 1) + b_broadcast = B isa ROCArray && (size(B, 3) == 1) && (batch_count > 1) Ab = a_broadcast ? device_batch(A, batch_count) : device_batch(A) Bb = b_broadcast ? device_batch(B, batch_count) : device_batch(B) Cb = device_batch(C) @@ -684,18 +687,18 @@ for (fname, elty) in handle, transA, transB, m, n, k, Ref(alpha), Ab, lda, Bb, ldb, Ref(beta), Cb, ldc, batch_count) - C + return C end function gemm_batched( transA::Char, transB::Char, alpha::($elty), A::T, B::K, ) where { - T <: Union{AnyROCArray{$elty, 3}, Vector{ROCMatrix{$elty}}}, - K <: Union{AnyROCArray{$elty, 3}, Vector{ROCMatrix{$elty}}}, + T <: Union{AnyROCArray{$elty, 3}, Vector{<:ROCMatrix{$elty}}}, + K <: Union{AnyROCArray{$elty, 3}, Vector{<:ROCMatrix{$elty}}}, } is_ab_vec = Int(T <: Vector) + Int(K <: Vector) (is_ab_vec != 0) && (is_ab_vec != 2) && throw(ArgumentError( "If `A` is a `Vector{ROCMatrix}`, then `B` must be too.")) - if T isa Vector + if T <: Vector C = ROCMatrix{$elty}[similar(B[i], $elty, ( size(A[i], transA == 'N' ? 1 : 2), size(B[i], transB == 'N' ? 2 : 1))) for i in 1:length(A)] @@ -704,13 +707,15 @@ for (fname, elty) in k = size(B, transB == 'N' ? 2 : 1) C = similar(A, $elty, (m, k, max(size(A, 3), size(B, 3)))) end - gemm_batched!(transA, transB, alpha, A, B, zero($elty), C) + m, k, n, lda, ldb, ldc = _check_gemm_batched_dims( + transA, transB, A, B, C) + return gemm_batched!(transA, transB, alpha, A, B, zero($elty), C) end function gemm_batched(transA::Char, transB::Char, A::T, B::K) where { - T <: Union{AnyROCArray{$elty, 3}, Vector{ROCMatrix{$elty}}}, - K <: Union{AnyROCArray{$elty, 3}, Vector{ROCMatrix{$elty}}}, + T <: Union{AnyROCArray{$elty, 3}, Vector{<:ROCMatrix{$elty}}}, + K <: Union{AnyROCArray{$elty, 3}, Vector{<:ROCMatrix{$elty}}}, } - gemm_batched(transA, transB, one($elty), A, B) + return gemm_batched(transA, transB, one($elty), A, B) end end end @@ -1029,7 +1034,7 @@ for (fname, elty) in @eval begin function trsm_batched!( side::Char, uplo::Char, transa::Char, diag::Char, alpha::($elty), - A::Array{ROCMatrix{$elty},1}, B::Array{ROCMatrix{$elty},1}, + A::Array{<:ROCMatrix{$elty},1}, B::Array{<:ROCMatrix{$elty},1}, ) if( length(A) != length(B) ) throw(DimensionMismatch("")) @@ -1051,7 +1056,7 @@ for (fname, elty) in end function trsm_batched( side::Char, uplo::Char, transa::Char, diag::Char, alpha::($elty), - A::Array{ROCMatrix{$elty},1}, B::Array{ROCMatrix{$elty},1}, + A::Array{<:ROCMatrix{$elty},1}, B::Array{<:ROCMatrix{$elty},1}, ) trsm_batched!(side, uplo, transa, diag, alpha, A, copy(B) ) end diff --git a/test/rocarray/blas.jl b/test/rocarray/blas.jl index e2f9d92d0..d22923b6f 100644 --- a/test/rocarray/blas.jl +++ b/test/rocarray/blas.jl @@ -147,6 +147,14 @@ end @test testf((a, b) -> f(TR(a)) * b, A, x) @test testf((a, b) -> lmul!(f(TR(a)), b), A, copy(x)) end + @testset "trmv" begin + A = rand(T, m, m) + dA = ROCArray(A) + x = rand(T, m) + dx = ROCArray(x) + dy = rocBLAS.trmv('U', 'N', 'N', dA, dx) + @test collect(dy) ≈ triu(A) * x + end A, x = rand(T, m, m), rand(T, m) @testset "Triangular ldiv" for TR in ( @@ -157,6 +165,14 @@ end @test testf((a, b) -> f(TR(a)) \ b, A, x) @test testf((a, b) -> ldiv!(f(TR(a)), b), A, copy(x)) end + @testset "trsv" begin + A = rand(T, m, m) + dA = ROCArray(A) + x = rand(T, m) + dx = ROCArray(x) + dy = rocBLAS.trsv('U', 'N', 'N', dA, dx) + @test collect(dy) ≈ triu(A) \ x + end x = rand(T, m, m) @testset "inv($TR)" for TR in ( @@ -372,6 +388,25 @@ end (a, b) -> b / adjtype(uplotype(a)), triu(rand(T, m, m)), rand(T, n, m)) end + @testset "trsm" begin + A = rand(T, m, m) + dA = ROCArray(A) + B = rand(T, m, m) + dB = ROCArray(B) + dC = rocBLAS.trsm('L', 'U', 'N', 'N', one(T), dA, dB) + @test collect(dC) ≈ triu(A) \ B + end + @testset "trsm_batched" begin + batch_count = 3 + A = [rand(T, m, m) for ix in 1:batch_count] + dA = [ROCArray(A_) for A_ in A] + B = [rand(T, m, m) for ix in 1:batch_count] + dB = [ROCArray(B_) for B_ in B] + dC = rocBLAS.trsm_batched('L', 'U', 'N', 'N', one(T), dA, dB) + for ix in 1:batch_count + @test collect(dC[ix]) ≈ triu(A[ix]) \ B[ix] + end + end @testset "triangular-dense mul ($T, $adjtype, $uplotype)" for adjtype in ( identity, adjoint, transpose, @@ -389,6 +424,14 @@ end (c, a, b) -> mul!(c, b, adjtype(uplotype(a))), zeros(T, n, m), A, rand(T, n, m)) end + @testset "trmm" begin + A = rand(T, m, m) + dA = ROCArray(A) + B = rand(T, m, m) + dB = ROCArray(B) + dC = rocBLAS.trmm('L', 'U', 'N', 'N', one(T), dA, dB) + @test collect(dC) ≈ triu(A) * B + end @testset "triangular-triangular mul" for (TRa, ta, TRb, tb) in ( (UpperTriangular, identity, LowerTriangular, identity), @@ -452,6 +495,20 @@ end (bt == 'T' ? transpose(B[:, :, i]) : B[:, :, i]) @test C[:, :, i] ≈ c end + A = [rand(T, 4, 4) for ix in 1:batch_count] + B = [rand(T, 4, 4) for ix in 1:batch_count] + RA = [ROCArray(A_) for A_ in A] + RB = [ROCArray(B_) for B_ in B] + + RC = rocBLAS.gemm_batched(at, bt, RA, RB) + @test length(RC) == batch_count + C = [Array(RC_) for RC_ in RC] + for i in 1:batch_count + c = + (at == 'T' ? transpose(A[i]) : A[i]) * + (bt == 'T' ? transpose(B[i]) : B[i]) + @test C[i] ≈ c + end end end end