diff --git a/deps/src/onemkl.cpp b/deps/src/onemkl.cpp index 6fe8eb05..35b0d01e 100644 --- a/deps/src/onemkl.cpp +++ b/deps/src/onemkl.cpp @@ -1,6 +1,8 @@ #include "onemkl.h" #include "sycl.hpp" - +#include +#include +#include #include // This is a workaround to flush MKL submissions into Level-zero queue, using @@ -51,6 +53,90 @@ oneapi::mkl::side convert(onemklSide val) { } } +template +class gemmBatchInfo { + public: + int64_t *m_mbuf = nullptr; + int64_t *m_nbuf = nullptr; + int64_t *m_kbuf = nullptr; + int64_t *m_ldabuf = nullptr; + int64_t *m_ldbbuf = nullptr; + int64_t *m_ldcbuf = nullptr; + oneapi::mkl::transpose *m_transa = nullptr; + oneapi::mkl::transpose *m_transb = nullptr; + T *m_alphabuf = nullptr; + T *m_betabuf = nullptr; + int64_t *m_group_size = nullptr; + sycl::device m_device; + sycl::context m_context; + oneapi::mkl::transpose m_ta; + oneapi::mkl::transpose m_tb; + + // Constructor + gemmBatchInfo(syclQueue_t device_queue, + int64_t group_count, + onemklTranspose transa, + onemklTranspose transb, + int64_t m, int64_t n, int64_t k, + int64_t lda, int64_t ldb, int64_t ldc, + T alpha, T beta) { + // Get device and context info from device_queue + auto main_queue = device_queue->val; + m_device = main_queue.get_device(); + m_context = main_queue.get_context(); + try { + // Allocate uniform arrays of m,n,k,lda,ldb,ldc,alpha,beta + // group_size and transpose_a, transpose_b supporting oneMKL + // gemm_batch API + m_mbuf = (int64_t *) malloc_shared(group_count * sizeof(int64_t), m_device, m_context); + m_nbuf = (int64_t *) malloc_shared(group_count * sizeof(int64_t), m_device, m_context); + m_kbuf = (int64_t *) malloc_shared(group_count * sizeof(int64_t), m_device, m_context); + m_ldabuf = (int64_t *) malloc_shared(group_count * sizeof(int64_t), m_device, m_context); + m_ldbbuf = (int64_t *) malloc_shared(group_count * sizeof(int64_t), m_device, m_context); + m_ldcbuf = (int64_t *) malloc_shared(group_count * sizeof(int64_t), m_device, m_context); + m_alphabuf = (T *) malloc_shared(group_count * sizeof(T), m_device, m_context); + m_betabuf = (T *) malloc_shared(group_count * sizeof(T), m_device, m_context); + m_transa = (oneapi::mkl::transpose *) malloc_shared(group_count * sizeof(oneapi::mkl::transpose), + m_device, m_context); + m_transb = (oneapi::mkl::transpose *) malloc_shared(group_count * sizeof(oneapi::mkl::transpose), + m_device, m_context); + m_ta = convert(transa); + m_tb = convert(transb); + m_group_size = (int64_t *) malloc_shared(group_count * sizeof(int64_t), m_device, m_context); + } catch(const std::bad_alloc& e) { + std::cerr << "Error: " << e.what() << std::endl; + } + for (int i = 0; i < group_count; i++) { + m_mbuf[i] = m; + m_nbuf[i] = n; + m_kbuf[i] = k; + m_ldabuf[i] = lda; + m_ldbbuf[i] = ldb; + m_ldcbuf[i] = ldc; + m_alphabuf[i] = alpha; + m_betabuf[i] = beta; + m_transa[i] = m_ta; + m_transb[i] = m_tb; + m_group_size[i] = 1; + } + }; + + // Destructor + ~gemmBatchInfo() { + free(m_mbuf, m_context); + free(m_nbuf, m_context); + free(m_kbuf, m_context); + free(m_ldabuf, m_context); + free(m_ldbbuf, m_context); + free(m_ldcbuf, m_context); + free(m_alphabuf, m_context); + free(m_betabuf, m_context); + free(m_transa, m_context); + free(m_transb, m_context); + free(m_group_size, m_context); + } +}; + extern "C" int onemklHgemm(syclQueue_t device_queue, onemklTranspose transA, onemklTranspose transB, int64_t m, int64_t n, int64_t k, uint16_t alpha, const short *A, int64_t lda, @@ -122,6 +208,124 @@ extern "C" int onemklZgemm(syclQueue_t device_queue, onemklTranspose transA, return 0; } +extern "C" void onemklHgemmBatched(syclQueue_t device_queue, onemklTranspose transa, + onemklTranspose transb, int64_t m, + int64_t n, int64_t k, uint16_t alpha, + const short **a, int64_t lda, const short **b, + int64_t ldb, uint16_t beta, short **c, + int64_t ldc, int64_t group_count) { + gemmBatchInfo gemmInfo(device_queue, group_count, transa, transb, + m, n, k, lda, ldb, ldc, sycl::bit_cast(alpha), + sycl::bit_cast(beta)); + auto status = oneapi::mkl::blas::column_major::gemm_batch(device_queue->val, + &gemmInfo.m_transa[0], &gemmInfo.m_transb[0], + &gemmInfo.m_mbuf[0], &gemmInfo.m_nbuf[0], + &gemmInfo.m_kbuf[0], &gemmInfo.m_alphabuf[0], + reinterpret_cast(&a[0]), &gemmInfo.m_ldabuf[0], + reinterpret_cast(&b[0]), &gemmInfo.m_ldbbuf[0], + &gemmInfo.m_betabuf[0], reinterpret_cast(&c[0]), + &gemmInfo.m_ldcbuf[0], group_count, &gemmInfo.m_group_size[0]); + + __FORCE_MKL_FLUSH__(status); + +} + +extern "C" void onemklSgemmBatched(syclQueue_t device_queue, onemklTranspose transa, + onemklTranspose transb, int64_t m, + int64_t n, int64_t k, float alpha, + const float **a, int64_t lda, const float **b, + int64_t ldb, float beta, float **c, + int64_t ldc, int64_t group_count) { + gemmBatchInfo gemmInfo(device_queue, group_count, transa, transb, + m, n, k, lda, ldb, ldc, alpha, beta); + + auto status = oneapi::mkl::blas::column_major::gemm_batch(device_queue->val, + &gemmInfo.m_transa[0], &gemmInfo.m_transb[0], + &gemmInfo.m_mbuf[0], &gemmInfo.m_nbuf[0], + &gemmInfo.m_kbuf[0], &gemmInfo.m_alphabuf[0], + (const float **)&a[0], &gemmInfo.m_ldabuf[0], + (const float **)&b[0], &gemmInfo.m_ldbbuf[0], + &gemmInfo.m_betabuf[0], &c[0], &gemmInfo.m_ldcbuf[0], + group_count, &gemmInfo.m_group_size[0]); + + __FORCE_MKL_FLUSH__(status); + +} + +extern "C" void onemklDgemmBatched(syclQueue_t device_queue, onemklTranspose transa, + onemklTranspose transb, int64_t m, + int64_t n, int64_t k, double alpha, + const double **a, int64_t lda, const double **b, + int64_t ldb, double beta, double **c, + int64_t ldc, int64_t group_count) { + gemmBatchInfo gemmInfo(device_queue, group_count, transa, transb, + m, n, k, lda, ldb, ldc, alpha, beta); + + auto status = oneapi::mkl::blas::column_major::gemm_batch(device_queue->val, + &gemmInfo.m_transa[0], &gemmInfo.m_transb[0], + &gemmInfo.m_mbuf[0], &gemmInfo.m_nbuf[0], + &gemmInfo.m_kbuf[0], &gemmInfo.m_alphabuf[0], + (const double **)&a[0], &gemmInfo.m_ldabuf[0], + (const double **)&b[0], &gemmInfo.m_ldbbuf[0], + &gemmInfo.m_betabuf[0], &c[0], &gemmInfo.m_ldcbuf[0], + group_count, &gemmInfo.m_group_size[0]); + + __FORCE_MKL_FLUSH__(status); + +} + +extern "C" void onemklCgemmBatched(syclQueue_t device_queue, onemklTranspose transa, + onemklTranspose transb, int64_t m, + int64_t n, int64_t k, float _Complex alpha, + const float _Complex **a, int64_t lda, + const float _Complex **b, + int64_t ldb, float _Complex beta, float _Complex **c, + int64_t ldc, int64_t group_count) { + gemmBatchInfo> gemmInfo(device_queue, group_count, transa, transb, + m, n, k, lda, ldb, ldc, alpha, beta); + + auto status = oneapi::mkl::blas::column_major::gemm_batch(device_queue->val, + &gemmInfo.m_transa[0], &gemmInfo.m_transb[0], + &gemmInfo.m_mbuf[0], &gemmInfo.m_nbuf[0], + &gemmInfo.m_kbuf[0], &gemmInfo.m_alphabuf[0], + reinterpret_cast **>(&a[0]), + &gemmInfo.m_ldabuf[0], + reinterpret_cast **>(&b[0]), + &gemmInfo.m_ldbbuf[0], + &gemmInfo.m_betabuf[0], + reinterpret_cast **>(&c[0]), &gemmInfo.m_ldcbuf[0], + group_count, &gemmInfo.m_group_size[0]); + + __FORCE_MKL_FLUSH__(status); + +} + +extern "C" void onemklZgemmBatched(syclQueue_t device_queue, onemklTranspose transa, + onemklTranspose transb, int64_t m, + int64_t n, int64_t k, double _Complex alpha, + const double _Complex **a, int64_t lda, + const double _Complex **b, + int64_t ldb, double _Complex beta, + double _Complex **c, + int64_t ldc, int64_t group_count) { + gemmBatchInfo> gemmInfo(device_queue, group_count, transa, transb, + m, n, k, lda, ldb, ldc, alpha, beta); + + auto status = oneapi::mkl::blas::column_major::gemm_batch(device_queue->val, + &gemmInfo.m_transa[0], &gemmInfo.m_transb[0], + &gemmInfo.m_mbuf[0], &gemmInfo.m_nbuf[0], + &gemmInfo.m_kbuf[0], &gemmInfo.m_alphabuf[0], + reinterpret_cast **>(&a[0]), + &gemmInfo.m_ldabuf[0], + reinterpret_cast **>(&b[0]), + &gemmInfo.m_ldbbuf[0], + &gemmInfo.m_betabuf[0], + reinterpret_cast **>(&c[0]), &gemmInfo.m_ldcbuf[0], + group_count, &gemmInfo.m_group_size[0]); + + __FORCE_MKL_FLUSH__(status); +} + extern "C" void onemklSsymm(syclQueue_t device_queue, onemklSide left_right, onemklUplo upper_lower, int64_t m, int64_t n, float alpha, const float *a, int64_t lda, const float *b, diff --git a/deps/src/onemkl.h b/deps/src/onemkl.h index 9d2c9bec..2953ec7d 100644 --- a/deps/src/onemkl.h +++ b/deps/src/onemkl.h @@ -59,6 +59,45 @@ int onemklHgemm(syclQueue_t device_queue, onemklTranspose transA, const short *B, int64_t ldb, uint16_t beta, short *C, int64_t ldc); +void onemklHgemmBatched(syclQueue_t device_queue, onemklTranspose transa, + onemklTranspose transb, int64_t m, + int64_t n, int64_t k, uint16_t alpha, + const short **a, int64_t lda, const short **b, + int64_t ldb, uint16_t beta, short **c, + int64_t ldc, int64_t group_count); + +void onemklSgemmBatched(syclQueue_t device_queue, onemklTranspose transa, + onemklTranspose transb, int64_t m, + int64_t n, int64_t k, float alpha, + const float **a, int64_t lda, const float **b, + int64_t ldb, float beta, float **c, + int64_t ldc, int64_t group_count); + +void onemklDgemmBatched(syclQueue_t device_queue, onemklTranspose transa, + onemklTranspose transb, int64_t m, + int64_t n, int64_t k, double alpha, + const double **a, int64_t lda, const double **b, + int64_t ldb, double beta, double **c, + int64_t ldc, int64_t group_count); + +void onemklCgemmBatched(syclQueue_t device_queue, onemklTranspose transa, + onemklTranspose transb, int64_t m, + int64_t n, int64_t k, float _Complex alpha, + const float _Complex **a, int64_t lda, + const float _Complex **b, + int64_t ldb, float _Complex beta, + float _Complex **c, int64_t ldc, + int64_t group_count); + +void onemklZgemmBatched(syclQueue_t device_queue, onemklTranspose transa, + onemklTranspose transb, int64_t m, + int64_t n, int64_t k, double _Complex alpha, + const double _Complex **a, int64_t lda, + const double _Complex **b, + int64_t ldb, double _Complex beta, + double _Complex **c, int64_t ldc, + int64_t group_count); + void onemklSsymm(syclQueue_t device_queue, onemklSide left_right, onemklUplo upper_lower, int64_t m, int64_t n, float alpha, const float *a, int64_t lda, const float *b, diff --git a/lib/mkl/oneMKL.jl b/lib/mkl/oneMKL.jl index f985101a..f1827ebf 100644 --- a/lib/mkl/oneMKL.jl +++ b/lib/mkl/oneMKL.jl @@ -1,7 +1,7 @@ module oneMKL using ..oneAPI - +using ..oneAPI: unsafe_free! using ..oneL0 using ..Support diff --git a/lib/mkl/wrappers.jl b/lib/mkl/wrappers.jl index 9f47281a..ef9ae284 100644 --- a/lib/mkl/wrappers.jl +++ b/lib/mkl/wrappers.jl @@ -44,6 +44,69 @@ function Base.convert(::Type{onemklDiag}, diag::Char) end end +# create a batch of pointers in device memory from a batch of device arrays +@inline function unsafe_batch(batch::Vector{<:oneArray{T}}) where {T} + ptrs = pointer.(batch) + return oneArray(ptrs) +end + +## (GE) general matrix-matrix multiplication batched +for (fname, elty) in + ((:onemklDgemmBatched,:Float64), + (:onemklSgemmBatched,:Float32), + (:onemklHgemmBatched,:Float16), + (:onemklCgemmBatched,:ComplexF32), + (:onemklZgemmBatched,:ComplexF64)) + @eval begin + function gemm_batched!(transA::Char, + transB::Char, + alpha::Number, + A::Vector{<:oneStridedMatrix{$elty}}, + B::Vector{<:oneStridedMatrix{$elty}}, + beta::Number, + C::Vector{<:oneStridedMatrix{$elty}}) + if length(A) != length(B) || length(A) != length(C) + throw(DimensionMismatch("")) + end + for (As,Bs,Cs) in zip(A,B,C) + m = size(As, transA == 'N' ? 1 : 2) + k = size(As, transA == 'N' ? 2 : 1) + n = size(Bs, transB == 'N' ? 2 : 1) + if m != size(Cs,1) || n != size(Cs,2) || k != size(Bs, transB == 'N' ? 1 : 2) + throw(DimensionMismatch("")) + end + end + + m = size(A[1], transA == 'N' ? 1 : 2) + k = size(A[1], transA == 'N' ? 2 : 1) + n = size(B[1], transB == 'N' ? 2 : 1) + lda = max(1,stride(A[1],2)) + ldb = max(1,stride(B[1],2)) + ldc = max(1,stride(C[1],2)) + Aptrs = unsafe_batch(A) + Bptrs = unsafe_batch(B) + Cptrs = unsafe_batch(C) + queue = global_queue(context(A[1]), device(A[1])) + $fname(sycl_queue(queue), transA, transB, m, n, k, alpha, Aptrs, lda, Bptrs, + ldb, beta, Cptrs, ldc, length(A)) + unsafe_free!(Cptrs) + unsafe_free!(Bptrs) + unsafe_free!(Aptrs) + C + end + end +end + +function gemm_batched(transA::Char, transB::Char, alpha::Number, + A::Vector{<:oneStridedMatrix{T}}, B::Vector{<:oneStridedMatrix{T}}) where T + C = oneMatrix{T}[similar(B[1], (size(A[1], transA == 'N' ? 1 : 2),size(B[1], transB == 'N' ? 2 : 1))) for i in 1:length(A)] + gemm_batched!(transA, transB, alpha, A, B, zero(T), C ) +end +function gemm_batched(transA::Char, transB::Char, + A::Vector{<:oneStridedMatrix{T}}, B::Vector{<:oneStridedMatrix{T}}) where T + gemm_batched(transA, transB, one(T), A, B) +end + ## (L3: symm) symmetric matrix-matrix and matrix-vector multiplication for (fname, elty) in ((:onemklSsymm, :Float32), (:onemklDsymm, :Float64), diff --git a/lib/support/liboneapi_support.jl b/lib/support/liboneapi_support.jl index c895f0be..c9646c02 100644 --- a/lib/support/liboneapi_support.jl +++ b/lib/support/liboneapi_support.jl @@ -140,6 +140,66 @@ function onemklHgemm(device_queue, transA, transB, m, n, k, alpha, A, lda, B, ld beta::Float16, C::ZePtr{Float16}, ldc::Int64)::Cint end +function onemklHgemmBatched(device_queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, + beta, c, ldc, group_count) + @ccall liboneapi_support.onemklHgemmBatched(device_queue::syclQueue_t, + transa::onemklTranspose, + transb::onemklTranspose, m::Int64, n::Int64, + k::Int64, alpha::Float16, + a::ZePtr{Ptr{Float16}}, lda::Int64, + b::ZePtr{Ptr{Float16}}, ldb::Int64, + beta::Float16, c::ZePtr{Ptr{Float16}}, + ldc::Int64, group_count::Int64)::Cvoid +end + +function onemklSgemmBatched(device_queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, + beta, c, ldc, group_count) + @ccall liboneapi_support.onemklSgemmBatched(device_queue::syclQueue_t, + transa::onemklTranspose, + transb::onemklTranspose, m::Int64, n::Int64, + k::Int64, alpha::Cfloat, + a::ZePtr{Ptr{Cfloat}}, lda::Int64, + b::ZePtr{Ptr{Cfloat}}, ldb::Int64, + beta::Cfloat, c::ZePtr{Ptr{Cfloat}}, + ldc::Int64, group_count::Int64)::Cvoid +end + +function onemklDgemmBatched(device_queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, + beta, c, ldc, group_count) + @ccall liboneapi_support.onemklDgemmBatched(device_queue::syclQueue_t, + transa::onemklTranspose, + transb::onemklTranspose, m::Int64, n::Int64, + k::Int64, alpha::Cdouble, + a::ZePtr{Ptr{Cdouble}}, lda::Int64, + b::ZePtr{Ptr{Cdouble}}, ldb::Int64, + beta::Cdouble, c::ZePtr{Ptr{Cdouble}}, + ldc::Int64, group_count::Int64)::Cvoid +end + +function onemklCgemmBatched(device_queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, + beta, c, ldc, group_count) + @ccall liboneapi_support.onemklCgemmBatched(device_queue::syclQueue_t, + transa::onemklTranspose, + transb::onemklTranspose, m::Int64, n::Int64, + k::Int64, alpha::ComplexF32, + a::ZePtr{Ptr{ComplexF32}}, lda::Int64, + b::ZePtr{Ptr{ComplexF32}}, ldb::Int64, + beta::ComplexF32, c::ZePtr{Ptr{ComplexF32}}, + ldc::Int64, group_count::Int64)::Cvoid +end + +function onemklZgemmBatched(device_queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, + beta, c, ldc, group_count) + @ccall liboneapi_support.onemklZgemmBatched(device_queue::syclQueue_t, + transa::onemklTranspose, + transb::onemklTranspose, m::Int64, n::Int64, + k::Int64, alpha::ComplexF64, + a::ZePtr{Ptr{ComplexF64}}, lda::Int64, + b::ZePtr{Ptr{ComplexF64}}, ldb::Int64, + beta::ComplexF64, c::ZePtr{Ptr{ComplexF64}}, + ldc::Int64, group_count::Int64)::Cvoid +end + function onemklSsymm(device_queue, left_right, upper_lower, m, n, alpha, a, lda, b, ldb, beta, c, ldc) @ccall liboneapi_support.onemklSsymm(device_queue::syclQueue_t, left_right::onemklSide, diff --git a/test/onemkl.jl b/test/onemkl.jl index 536c7022..51d229e7 100644 --- a/test/onemkl.jl +++ b/test/onemkl.jl @@ -850,7 +850,6 @@ end end end end - @testset for T in intersect(eltypes, [Float16, Float32, Float64, ComplexF32, ComplexF64]) @testset "gemm!" begin alpha = rand(T) @@ -882,7 +881,7 @@ end @test C2 ≈ h_C2 @test_throws ArgumentError mul!(dhA, dhA, dsA) @test_throws DimensionMismatch mul!(d_C1, d_A, dsA) - + d_c = oneMKL.gemm('N', 'N', d_A, d_B) C = A * B C2 = d_A * d_B @@ -892,4 +891,52 @@ end @test C ≈ h_C2 end end +end + + +@testset "BLAS Extension" begin + @testset for T in [Float16, Float32, Float64, ComplexF32, ComplexF64] + alpha = rand(T) + beta = rand(T) + group_count = 20 + # generate matrices + bA = [rand(T,m,k) for i in 1:group_count] + bB = [rand(T,k,n) for i in 1:group_count] + bC = [rand(T,m,n) for i in 1:group_count] + # move to device + bd_A = oneArray{T, 2}[] + bd_B = oneArray{T, 2}[] + bd_C = oneArray{T, 2}[] + bd_bad = oneArray{T, 2}[] + for i in 1:length(bA) + push!(bd_A,oneArray(bA[i])) + push!(bd_B,oneArray(bB[i])) + push!(bd_C,oneArray(bC[i])) + if i < length(bA) - 2 + push!(bd_bad,oneArray(bC[i])) + end + end + + @testset "gemm_batched!" begin + # C = (alpha*A)*B + beta*C + oneMKL.gemm_batched!('N','N',alpha,bd_A,bd_B,beta,bd_C) + for i in 1:length(bd_C) + bC[i] = (alpha*bA[i])*bB[i] + beta*bC[i] + h_C = Array(bd_C[i]) + #compare + @test bC[i] ≈ h_C + end + @test_throws DimensionMismatch oneMKL.gemm_batched!('N','N',alpha,bd_A,bd_bad,beta,bd_C) + end + + @testset "gemm_batched" begin + bd_C = oneMKL.gemm_batched('N','N',bd_A,bd_B) + for i in 1:length(bA) + bC = bA[i]*bB[i] + h_C = Array(bd_C[i]) + @test bC ≈ h_C + end + @test_throws DimensionMismatch oneMKL.gemm_batched('N','N',alpha,bd_A,bd_bad) + end + end end \ No newline at end of file