Skip to content
206 changes: 205 additions & 1 deletion deps/src/onemkl.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#include "onemkl.h"
#include "sycl.hpp"

#include <iostream>
#include <exception>
#include <memory>
#include <oneapi/mkl.hpp>

// This is a workaround to flush MKL submissions into Level-zero queue, using
Expand Down Expand Up @@ -51,6 +53,90 @@ oneapi::mkl::side convert(onemklSide val) {
}
}

template <typename T>
class gemmBatchInfo {
public:
int64_t *m_mbuf = nullptr;
int64_t *m_nbuf = nullptr;
int64_t *m_kbuf = nullptr;
int64_t *m_ldabuf = nullptr;
int64_t *m_ldbbuf = nullptr;
int64_t *m_ldcbuf = nullptr;
oneapi::mkl::transpose *m_transa = nullptr;
oneapi::mkl::transpose *m_transb = nullptr;
T *m_alphabuf = nullptr;
T *m_betabuf = nullptr;
int64_t *m_group_size = nullptr;
sycl::device m_device;
sycl::context m_context;
oneapi::mkl::transpose m_ta;
oneapi::mkl::transpose m_tb;

// Constructor
gemmBatchInfo(syclQueue_t device_queue,
int64_t group_count,
onemklTranspose transa,
onemklTranspose transb,
int64_t m, int64_t n, int64_t k,
int64_t lda, int64_t ldb, int64_t ldc,
T alpha, T beta) {
// Get device and context info from device_queue
auto main_queue = device_queue->val;
m_device = main_queue.get_device();
m_context = main_queue.get_context();
try {
// Allocate uniform arrays of m,n,k,lda,ldb,ldc,alpha,beta
// group_size and transpose_a, transpose_b supporting oneMKL
// gemm_batch API
m_mbuf = (int64_t *) malloc_shared(group_count * sizeof(int64_t), m_device, m_context);
m_nbuf = (int64_t *) malloc_shared(group_count * sizeof(int64_t), m_device, m_context);
m_kbuf = (int64_t *) malloc_shared(group_count * sizeof(int64_t), m_device, m_context);
m_ldabuf = (int64_t *) malloc_shared(group_count * sizeof(int64_t), m_device, m_context);
m_ldbbuf = (int64_t *) malloc_shared(group_count * sizeof(int64_t), m_device, m_context);
m_ldcbuf = (int64_t *) malloc_shared(group_count * sizeof(int64_t), m_device, m_context);
m_alphabuf = (T *) malloc_shared(group_count * sizeof(T), m_device, m_context);
m_betabuf = (T *) malloc_shared(group_count * sizeof(T), m_device, m_context);
m_transa = (oneapi::mkl::transpose *) malloc_shared(group_count * sizeof(oneapi::mkl::transpose),
m_device, m_context);
m_transb = (oneapi::mkl::transpose *) malloc_shared(group_count * sizeof(oneapi::mkl::transpose),
m_device, m_context);
m_ta = convert(transa);
m_tb = convert(transb);
m_group_size = (int64_t *) malloc_shared(group_count * sizeof(int64_t), m_device, m_context);
} catch(const std::bad_alloc& e) {
std::cerr << "Error: " << e.what() << std::endl;
}
for (int i = 0; i < group_count; i++) {
m_mbuf[i] = m;
m_nbuf[i] = n;
m_kbuf[i] = k;
m_ldabuf[i] = lda;
m_ldbbuf[i] = ldb;
m_ldcbuf[i] = ldc;
m_alphabuf[i] = alpha;
m_betabuf[i] = beta;
m_transa[i] = m_ta;
m_transb[i] = m_tb;
m_group_size[i] = 1;
}
};

// Destructor
~gemmBatchInfo() {
free(m_mbuf, m_context);
free(m_nbuf, m_context);
free(m_kbuf, m_context);
free(m_ldabuf, m_context);
free(m_ldbbuf, m_context);
free(m_ldcbuf, m_context);
free(m_alphabuf, m_context);
free(m_betabuf, m_context);
free(m_transa, m_context);
free(m_transb, m_context);
free(m_group_size, m_context);
}
};

extern "C" int onemklHgemm(syclQueue_t device_queue, onemklTranspose transA,
onemklTranspose transB, int64_t m, int64_t n,
int64_t k, uint16_t alpha, const short *A, int64_t lda,
Expand Down Expand Up @@ -122,6 +208,124 @@ extern "C" int onemklZgemm(syclQueue_t device_queue, onemklTranspose transA,
return 0;
}

extern "C" void onemklHgemmBatched(syclQueue_t device_queue, onemklTranspose transa,
onemklTranspose transb, int64_t m,
int64_t n, int64_t k, uint16_t alpha,
const short **a, int64_t lda, const short **b,
int64_t ldb, uint16_t beta, short **c,
int64_t ldc, int64_t group_count) {
gemmBatchInfo<sycl::half> gemmInfo(device_queue, group_count, transa, transb,
m, n, k, lda, ldb, ldc, sycl::bit_cast<sycl::half>(alpha),
sycl::bit_cast<sycl::half>(beta));
auto status = oneapi::mkl::blas::column_major::gemm_batch(device_queue->val,
&gemmInfo.m_transa[0], &gemmInfo.m_transb[0],
&gemmInfo.m_mbuf[0], &gemmInfo.m_nbuf[0],
&gemmInfo.m_kbuf[0], &gemmInfo.m_alphabuf[0],
reinterpret_cast<const sycl::half **>(&a[0]), &gemmInfo.m_ldabuf[0],
reinterpret_cast<const sycl::half **>(&b[0]), &gemmInfo.m_ldbbuf[0],
&gemmInfo.m_betabuf[0], reinterpret_cast<sycl::half **>(&c[0]),
&gemmInfo.m_ldcbuf[0], group_count, &gemmInfo.m_group_size[0]);

__FORCE_MKL_FLUSH__(status);

}

extern "C" void onemklSgemmBatched(syclQueue_t device_queue, onemklTranspose transa,
onemklTranspose transb, int64_t m,
int64_t n, int64_t k, float alpha,
const float **a, int64_t lda, const float **b,
int64_t ldb, float beta, float **c,
int64_t ldc, int64_t group_count) {
gemmBatchInfo<float> gemmInfo(device_queue, group_count, transa, transb,
m, n, k, lda, ldb, ldc, alpha, beta);

auto status = oneapi::mkl::blas::column_major::gemm_batch(device_queue->val,
&gemmInfo.m_transa[0], &gemmInfo.m_transb[0],
&gemmInfo.m_mbuf[0], &gemmInfo.m_nbuf[0],
&gemmInfo.m_kbuf[0], &gemmInfo.m_alphabuf[0],
(const float **)&a[0], &gemmInfo.m_ldabuf[0],
(const float **)&b[0], &gemmInfo.m_ldbbuf[0],
&gemmInfo.m_betabuf[0], &c[0], &gemmInfo.m_ldcbuf[0],
group_count, &gemmInfo.m_group_size[0]);

__FORCE_MKL_FLUSH__(status);

}

extern "C" void onemklDgemmBatched(syclQueue_t device_queue, onemklTranspose transa,
onemklTranspose transb, int64_t m,
int64_t n, int64_t k, double alpha,
const double **a, int64_t lda, const double **b,
int64_t ldb, double beta, double **c,
int64_t ldc, int64_t group_count) {
gemmBatchInfo<double> gemmInfo(device_queue, group_count, transa, transb,
m, n, k, lda, ldb, ldc, alpha, beta);

auto status = oneapi::mkl::blas::column_major::gemm_batch(device_queue->val,
&gemmInfo.m_transa[0], &gemmInfo.m_transb[0],
&gemmInfo.m_mbuf[0], &gemmInfo.m_nbuf[0],
&gemmInfo.m_kbuf[0], &gemmInfo.m_alphabuf[0],
(const double **)&a[0], &gemmInfo.m_ldabuf[0],
(const double **)&b[0], &gemmInfo.m_ldbbuf[0],
&gemmInfo.m_betabuf[0], &c[0], &gemmInfo.m_ldcbuf[0],
group_count, &gemmInfo.m_group_size[0]);

__FORCE_MKL_FLUSH__(status);

}

extern "C" void onemklCgemmBatched(syclQueue_t device_queue, onemklTranspose transa,
onemklTranspose transb, int64_t m,
int64_t n, int64_t k, float _Complex alpha,
const float _Complex **a, int64_t lda,
const float _Complex **b,
int64_t ldb, float _Complex beta, float _Complex **c,
int64_t ldc, int64_t group_count) {
gemmBatchInfo<std::complex<float>> gemmInfo(device_queue, group_count, transa, transb,
m, n, k, lda, ldb, ldc, alpha, beta);

auto status = oneapi::mkl::blas::column_major::gemm_batch(device_queue->val,
&gemmInfo.m_transa[0], &gemmInfo.m_transb[0],
&gemmInfo.m_mbuf[0], &gemmInfo.m_nbuf[0],
&gemmInfo.m_kbuf[0], &gemmInfo.m_alphabuf[0],
reinterpret_cast<const std::complex<float> **>(&a[0]),
&gemmInfo.m_ldabuf[0],
reinterpret_cast<const std::complex<float> **>(&b[0]),
&gemmInfo.m_ldbbuf[0],
&gemmInfo.m_betabuf[0],
reinterpret_cast<std::complex<float> **>(&c[0]), &gemmInfo.m_ldcbuf[0],
group_count, &gemmInfo.m_group_size[0]);

__FORCE_MKL_FLUSH__(status);

}

extern "C" void onemklZgemmBatched(syclQueue_t device_queue, onemklTranspose transa,
onemklTranspose transb, int64_t m,
int64_t n, int64_t k, double _Complex alpha,
const double _Complex **a, int64_t lda,
const double _Complex **b,
int64_t ldb, double _Complex beta,
double _Complex **c,
int64_t ldc, int64_t group_count) {
gemmBatchInfo<std::complex<double>> gemmInfo(device_queue, group_count, transa, transb,
m, n, k, lda, ldb, ldc, alpha, beta);

auto status = oneapi::mkl::blas::column_major::gemm_batch(device_queue->val,
&gemmInfo.m_transa[0], &gemmInfo.m_transb[0],
&gemmInfo.m_mbuf[0], &gemmInfo.m_nbuf[0],
&gemmInfo.m_kbuf[0], &gemmInfo.m_alphabuf[0],
reinterpret_cast<const std::complex<double> **>(&a[0]),
&gemmInfo.m_ldabuf[0],
reinterpret_cast<const std::complex<double> **>(&b[0]),
&gemmInfo.m_ldbbuf[0],
&gemmInfo.m_betabuf[0],
reinterpret_cast<std::complex<double> **>(&c[0]), &gemmInfo.m_ldcbuf[0],
group_count, &gemmInfo.m_group_size[0]);

__FORCE_MKL_FLUSH__(status);
}

extern "C" void onemklSsymm(syclQueue_t device_queue, onemklSide left_right,
onemklUplo upper_lower, int64_t m, int64_t n,
float alpha, const float *a, int64_t lda, const float *b,
Expand Down
39 changes: 39 additions & 0 deletions deps/src/onemkl.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,45 @@ int onemklHgemm(syclQueue_t device_queue, onemklTranspose transA,
const short *B, int64_t ldb, uint16_t beta, short *C,
int64_t ldc);

void onemklHgemmBatched(syclQueue_t device_queue, onemklTranspose transa,
onemklTranspose transb, int64_t m,
int64_t n, int64_t k, uint16_t alpha,
const short **a, int64_t lda, const short **b,
int64_t ldb, uint16_t beta, short **c,
int64_t ldc, int64_t group_count);

void onemklSgemmBatched(syclQueue_t device_queue, onemklTranspose transa,
onemklTranspose transb, int64_t m,
int64_t n, int64_t k, float alpha,
const float **a, int64_t lda, const float **b,
int64_t ldb, float beta, float **c,
int64_t ldc, int64_t group_count);

void onemklDgemmBatched(syclQueue_t device_queue, onemklTranspose transa,
onemklTranspose transb, int64_t m,
int64_t n, int64_t k, double alpha,
const double **a, int64_t lda, const double **b,
int64_t ldb, double beta, double **c,
int64_t ldc, int64_t group_count);

void onemklCgemmBatched(syclQueue_t device_queue, onemklTranspose transa,
onemklTranspose transb, int64_t m,
int64_t n, int64_t k, float _Complex alpha,
const float _Complex **a, int64_t lda,
const float _Complex **b,
int64_t ldb, float _Complex beta,
float _Complex **c, int64_t ldc,
int64_t group_count);

void onemklZgemmBatched(syclQueue_t device_queue, onemklTranspose transa,
onemklTranspose transb, int64_t m,
int64_t n, int64_t k, double _Complex alpha,
const double _Complex **a, int64_t lda,
const double _Complex **b,
int64_t ldb, double _Complex beta,
double _Complex **c, int64_t ldc,
int64_t group_count);

void onemklSsymm(syclQueue_t device_queue, onemklSide left_right,
onemklUplo upper_lower, int64_t m, int64_t n,
float alpha, const float *a, int64_t lda, const float *b,
Expand Down
2 changes: 1 addition & 1 deletion lib/mkl/oneMKL.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module oneMKL

using ..oneAPI

using ..oneAPI: unsafe_free!
using ..oneL0

using ..Support
Expand Down
63 changes: 63 additions & 0 deletions lib/mkl/wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,69 @@ function Base.convert(::Type{onemklDiag}, diag::Char)
end
end

# create a batch of pointers in device memory from a batch of device arrays
@inline function unsafe_batch(batch::Vector{<:oneArray{T}}) where {T}
ptrs = pointer.(batch)
return oneArray(ptrs)
end

## (GE) general matrix-matrix multiplication batched
for (fname, elty) in
((:onemklDgemmBatched,:Float64),
(:onemklSgemmBatched,:Float32),
(:onemklHgemmBatched,:Float16),
(:onemklCgemmBatched,:ComplexF32),
(:onemklZgemmBatched,:ComplexF64))
@eval begin
function gemm_batched!(transA::Char,
transB::Char,
alpha::Number,
A::Vector{<:oneStridedMatrix{$elty}},
B::Vector{<:oneStridedMatrix{$elty}},
beta::Number,
C::Vector{<:oneStridedMatrix{$elty}})
if length(A) != length(B) || length(A) != length(C)
throw(DimensionMismatch(""))
end
for (As,Bs,Cs) in zip(A,B,C)
m = size(As, transA == 'N' ? 1 : 2)
k = size(As, transA == 'N' ? 2 : 1)
n = size(Bs, transB == 'N' ? 2 : 1)
if m != size(Cs,1) || n != size(Cs,2) || k != size(Bs, transB == 'N' ? 1 : 2)
throw(DimensionMismatch(""))
end
end

m = size(A[1], transA == 'N' ? 1 : 2)
k = size(A[1], transA == 'N' ? 2 : 1)
n = size(B[1], transB == 'N' ? 2 : 1)
lda = max(1,stride(A[1],2))
ldb = max(1,stride(B[1],2))
ldc = max(1,stride(C[1],2))
Aptrs = unsafe_batch(A)
Bptrs = unsafe_batch(B)
Cptrs = unsafe_batch(C)
queue = global_queue(context(A[1]), device(A[1]))
$fname(sycl_queue(queue), transA, transB, m, n, k, alpha, Aptrs, lda, Bptrs,
ldb, beta, Cptrs, ldc, length(A))
unsafe_free!(Cptrs)
unsafe_free!(Bptrs)
unsafe_free!(Aptrs)
C
end
end
end

function gemm_batched(transA::Char, transB::Char, alpha::Number,
A::Vector{<:oneStridedMatrix{T}}, B::Vector{<:oneStridedMatrix{T}}) where T
C = oneMatrix{T}[similar(B[1], (size(A[1], transA == 'N' ? 1 : 2),size(B[1], transB == 'N' ? 2 : 1))) for i in 1:length(A)]
gemm_batched!(transA, transB, alpha, A, B, zero(T), C )
end
function gemm_batched(transA::Char, transB::Char,
A::Vector{<:oneStridedMatrix{T}}, B::Vector{<:oneStridedMatrix{T}}) where T
gemm_batched(transA, transB, one(T), A, B)
end

## (L3: symm) symmetric matrix-matrix and matrix-vector multiplication
for (fname, elty) in ((:onemklSsymm, :Float32),
(:onemklDsymm, :Float64),
Expand Down
Loading