From 2e8f8d256943890875bf24942deb7e0d39d4f523 Mon Sep 17 00:00:00 2001 From: sbalint98 Date: Mon, 22 Mar 2021 13:59:20 +0100 Subject: [PATCH 1/3] [half]Add ifdefs around usages of half Use DISABLE_HALF_RUTINES, use cmake config to set --- include/oneapi/mkl/blas.hxx | 4 ++-- .../oneapi/mkl/blas/detail/blas_ct_backends.hpp | 1 + .../oneapi/mkl/blas/detail/blas_ct_backends.hxx | 4 ++-- include/oneapi/mkl/blas/detail/blas_loader.hpp | 1 + include/oneapi/mkl/blas/detail/blas_loader.hxx | 2 ++ .../oneapi/mkl/blas/detail/cublas/blas_ct.hpp | 1 + .../oneapi/mkl/blas/detail/cublas/blas_ct.hxx | 4 ++-- .../blas/detail/cublas/onemkl_blas_cublas.hpp | 1 + .../blas/detail/cublas/onemkl_blas_cublas.hxx | 4 ++-- .../oneapi/mkl/blas/detail/mklcpu/blas_ct.hpp | 1 + .../oneapi/mkl/blas/detail/mklcpu/blas_ct.hxx | 4 ++-- .../oneapi/mkl/blas/detail/mklgpu/blas_ct.hxx | 4 ++-- .../oneapi/mkl/blas/detail/netlib/blas_ct.hxx | 4 ++-- .../mkl/blas/detail/onemkl_blas_backends.hxx | 4 ++-- include/oneapi/mkl/blas/predicates.hpp | 1 + include/oneapi/mkl/blas/predicates.hxx | 7 ++++--- include/oneapi/mkl/detail/backend_selector.hpp | 2 ++ src/blas/backends/cublas/cublas_level3.cpp | 16 ++++++++-------- .../backends/cublas/mkl_blas_cublas_wrappers.cpp | 4 ++++ src/blas/backends/mklcpu/mklcpu_level3.cxx | 5 ++--- src/blas/backends/mklcpu/mklcpu_wrappers.cpp | 4 ++++ src/blas/backends/mklgpu/mklgpu_common.hpp | 4 ++-- src/blas/backends/mklgpu/mklgpu_level3.cxx | 4 ++-- src/blas/backends/mklgpu/mklgpu_wrappers.cpp | 4 ++++ src/blas/blas_loader.cpp | 8 ++++---- src/blas/function_table.hpp | 6 ++++++ src/config.hpp.in | 1 + src/rng/backends/mklcpu/mrg32k3a.cpp | 2 ++ src/rng/backends/mklcpu/philox4x32x10.cpp | 2 ++ 29 files changed, 71 insertions(+), 38 deletions(-) diff --git a/include/oneapi/mkl/blas.hxx b/include/oneapi/mkl/blas.hxx index f5bee1232..9291b0f0f 100644 --- a/include/oneapi/mkl/blas.hxx +++ b/include/oneapi/mkl/blas.hxx @@ -261,7 +261,7 @@ static inline void gemm(cl::sycl::queue 
&queue, transpose transa, transpose tran c, ldc); gemm_postcondition(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } - +#ifndef DISABLE_HALF_RUTINES static inline void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, half alpha, cl::sycl::buffer &a, std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, half beta, @@ -281,7 +281,7 @@ static inline void gemm(cl::sycl::queue &queue, transpose transa, transpose tran c, ldc); gemm_postcondition(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } - +#endif static inline void gemm_batch(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, cl::sycl::buffer &a, std::int64_t lda, diff --git a/include/oneapi/mkl/blas/detail/blas_ct_backends.hpp b/include/oneapi/mkl/blas/detail/blas_ct_backends.hpp index aad01181e..2eeef6fae 100644 --- a/include/oneapi/mkl/blas/detail/blas_ct_backends.hpp +++ b/include/oneapi/mkl/blas/detail/blas_ct_backends.hpp @@ -25,6 +25,7 @@ #include #include "oneapi/mkl/types.hpp" +#include "oneapi/mkl/detail/config.hpp" #include "oneapi/mkl/detail/backend_selector.hpp" namespace oneapi { diff --git a/include/oneapi/mkl/blas/detail/blas_ct_backends.hxx b/include/oneapi/mkl/blas/detail/blas_ct_backends.hxx index dea57bb71..e4270b1b6 100644 --- a/include/oneapi/mkl/blas/detail/blas_ct_backends.hxx +++ b/include/oneapi/mkl/blas/detail/blas_ct_backends.hxx @@ -413,7 +413,7 @@ static inline void gemm(backend_selector selector, transpose t std::int64_t lda, cl::sycl::buffer, 1> &b, std::int64_t ldb, std::complex beta, cl::sycl::buffer, 1> &c, std::int64_t ldc); - +#ifndef DISABLE_HALF_RUTINES static inline void gemm(backend_selector selector, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, half alpha, cl::sycl::buffer &a, std::int64_t lda, @@ -425,7 +425,7 @@ static inline void gemm(backend_selector 
selector, transpose t float alpha, cl::sycl::buffer &a, std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, float beta, cl::sycl::buffer &c, std::int64_t ldc); - +#endif static inline void herk(backend_selector selector, uplo upper_lower, transpose trans, std::int64_t n, std::int64_t k, float alpha, cl::sycl::buffer, 1> &a, std::int64_t lda, float beta, diff --git a/include/oneapi/mkl/blas/detail/blas_loader.hpp b/include/oneapi/mkl/blas/detail/blas_loader.hpp index fe4bd892e..17abdbc1d 100644 --- a/include/oneapi/mkl/blas/detail/blas_loader.hpp +++ b/include/oneapi/mkl/blas/detail/blas_loader.hpp @@ -24,6 +24,7 @@ #include #include +#include "oneapi/mkl/detail/config.hpp" #include "oneapi/mkl/types.hpp" #include "oneapi/mkl/detail/export.hpp" diff --git a/include/oneapi/mkl/blas/detail/blas_loader.hxx b/include/oneapi/mkl/blas/detail/blas_loader.hxx index 60183d179..94b393a1c 100644 --- a/include/oneapi/mkl/blas/detail/blas_loader.hxx +++ b/include/oneapi/mkl/blas/detail/blas_loader.hxx @@ -396,6 +396,7 @@ ONEMKL_EXPORT void gemm(oneapi::mkl::device libkey, cl::sycl::queue &queue, tran std::int64_t lda, cl::sycl::buffer, 1> &b, std::int64_t ldb, std::complex beta, cl::sycl::buffer, 1> &c, std::int64_t ldc); +#ifndef DISABLE_HALF_RUTINES ONEMKL_EXPORT void gemm(oneapi::mkl::device libkey, cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, half alpha, cl::sycl::buffer &a, std::int64_t lda, @@ -406,6 +407,7 @@ ONEMKL_EXPORT void gemm(oneapi::mkl::device libkey, cl::sycl::queue &queue, tran float alpha, cl::sycl::buffer &a, std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, float beta, cl::sycl::buffer &c, std::int64_t ldc); +#endif ONEMKL_EXPORT void syr2(oneapi::mkl::device libkey, cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, float alpha, cl::sycl::buffer &x, diff --git a/include/oneapi/mkl/blas/detail/cublas/blas_ct.hpp b/include/oneapi/mkl/blas/detail/cublas/blas_ct.hpp index 
be19fe05d..6644cdc98 100644 --- a/include/oneapi/mkl/blas/detail/cublas/blas_ct.hpp +++ b/include/oneapi/mkl/blas/detail/cublas/blas_ct.hpp @@ -26,6 +26,7 @@ #include "oneapi/mkl/types.hpp" #include "oneapi/mkl/detail/backend_selector.hpp" +#include "oneapi/mkl/detail/config.hpp" #include "oneapi/mkl/blas/detail/cublas/onemkl_blas_cublas.hpp" #include "oneapi/mkl/blas/detail/blas_ct_backends.hpp" diff --git a/include/oneapi/mkl/blas/detail/cublas/blas_ct.hxx b/include/oneapi/mkl/blas/detail/cublas/blas_ct.hxx index bd74e9715..ff090fce0 100644 --- a/include/oneapi/mkl/blas/detail/cublas/blas_ct.hxx +++ b/include/oneapi/mkl/blas/detail/cublas/blas_ct.hxx @@ -743,7 +743,7 @@ void gemm(backend_selector selector, transpose transa, transpos gemm_postcondition(selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } - +#ifndef DISABLE_HALF_RUTINES void gemm(backend_selector selector, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, half alpha, cl::sycl::buffer &a, std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, half beta, @@ -767,7 +767,7 @@ void gemm(backend_selector selector, transpose transa, transpos gemm_postcondition(selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } - +#endif void syr2(backend_selector selector, uplo upper_lower, std::int64_t n, float alpha, cl::sycl::buffer &x, std::int64_t incx, cl::sycl::buffer &y, std::int64_t incy, cl::sycl::buffer &a, std::int64_t lda) { diff --git a/include/oneapi/mkl/blas/detail/cublas/onemkl_blas_cublas.hpp b/include/oneapi/mkl/blas/detail/cublas/onemkl_blas_cublas.hpp index 0b3fc2fb6..3948625a1 100644 --- a/include/oneapi/mkl/blas/detail/cublas/onemkl_blas_cublas.hpp +++ b/include/oneapi/mkl/blas/detail/cublas/onemkl_blas_cublas.hpp @@ -23,6 +23,7 @@ #include #include #include "oneapi/mkl/types.hpp" +#include "oneapi/mkl/detail/config.hpp" namespace oneapi { namespace mkl { diff --git 
a/include/oneapi/mkl/blas/detail/cublas/onemkl_blas_cublas.hxx b/include/oneapi/mkl/blas/detail/cublas/onemkl_blas_cublas.hxx index bf2f502b0..6484c6f6b 100644 --- a/include/oneapi/mkl/blas/detail/cublas/onemkl_blas_cublas.hxx +++ b/include/oneapi/mkl/blas/detail/cublas/onemkl_blas_cublas.hxx @@ -496,7 +496,7 @@ void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64 cl::sycl::buffer, 1> &a, std::int64_t lda, cl::sycl::buffer, 1> &b, std::int64_t ldb, std::complex beta, cl::sycl::buffer, 1> &c, std::int64_t ldc); - +#ifndef DISABLE_HALF_RUTINES void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, half alpha, cl::sycl::buffer &a, std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, half beta, @@ -506,7 +506,7 @@ void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64 std::int64_t n, std::int64_t k, float alpha, cl::sycl::buffer &a, std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, float beta, cl::sycl::buffer &c, std::int64_t ldc); - +#endif void hemm(cl::sycl::queue &queue, side left_right, uplo upper_lower, std::int64_t m, std::int64_t n, std::complex alpha, cl::sycl::buffer, 1> &a, std::int64_t lda, cl::sycl::buffer, 1> &b, std::int64_t ldb, std::complex beta, diff --git a/include/oneapi/mkl/blas/detail/mklcpu/blas_ct.hpp b/include/oneapi/mkl/blas/detail/mklcpu/blas_ct.hpp index c35069646..de032e8f9 100644 --- a/include/oneapi/mkl/blas/detail/mklcpu/blas_ct.hpp +++ b/include/oneapi/mkl/blas/detail/mklcpu/blas_ct.hpp @@ -25,6 +25,7 @@ #include #include "oneapi/mkl/types.hpp" +#include "oneapi/mkl/detail/config.hpp" #include "oneapi/mkl/detail/backend_selector.hpp" #include "oneapi/mkl/blas/detail/blas_ct_backends.hpp" diff --git a/include/oneapi/mkl/blas/detail/mklcpu/blas_ct.hxx b/include/oneapi/mkl/blas/detail/mklcpu/blas_ct.hxx index 8956f15fc..f158b80a0 100644 --- a/include/oneapi/mkl/blas/detail/mklcpu/blas_ct.hxx +++ 
b/include/oneapi/mkl/blas/detail/mklcpu/blas_ct.hxx @@ -743,7 +743,7 @@ void gemm(backend_selector selector, transpose transa, transpos gemm_postcondition(selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } - +#ifndef DISABLE_HALF_RUTINES void gemm(backend_selector selector, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, half alpha, cl::sycl::buffer &a, std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, half beta, @@ -767,7 +767,7 @@ void gemm(backend_selector selector, transpose transa, transpos gemm_postcondition(selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } - +#endif void syr2(backend_selector selector, uplo upper_lower, std::int64_t n, float alpha, cl::sycl::buffer &x, std::int64_t incx, cl::sycl::buffer &y, std::int64_t incy, cl::sycl::buffer &a, std::int64_t lda) { diff --git a/include/oneapi/mkl/blas/detail/mklgpu/blas_ct.hxx b/include/oneapi/mkl/blas/detail/mklgpu/blas_ct.hxx index 1dfee1647..e0889c1a6 100644 --- a/include/oneapi/mkl/blas/detail/mklgpu/blas_ct.hxx +++ b/include/oneapi/mkl/blas/detail/mklgpu/blas_ct.hxx @@ -743,7 +743,7 @@ void gemm(backend_selector selector, transpose transa, transpos gemm_postcondition(selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } - +#ifndef DISABLE_HALF_RUTINES void gemm(backend_selector selector, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, half alpha, cl::sycl::buffer &a, std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, half beta, @@ -767,7 +767,7 @@ void gemm(backend_selector selector, transpose transa, transpos gemm_postcondition(selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } - +#endif void syr2(backend_selector selector, uplo upper_lower, std::int64_t n, float alpha, cl::sycl::buffer &x, std::int64_t incx, cl::sycl::buffer &y, std::int64_t incy, cl::sycl::buffer &a, 
std::int64_t lda) { diff --git a/include/oneapi/mkl/blas/detail/netlib/blas_ct.hxx b/include/oneapi/mkl/blas/detail/netlib/blas_ct.hxx index 935ecc59d..9e283c7e2 100644 --- a/include/oneapi/mkl/blas/detail/netlib/blas_ct.hxx +++ b/include/oneapi/mkl/blas/detail/netlib/blas_ct.hxx @@ -743,7 +743,7 @@ void gemm(backend_selector selector, transpose transa, transpos gemm_postcondition(selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } - +#ifndef DISABLE_HALF_RUTINES void gemm(backend_selector selector, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, half alpha, cl::sycl::buffer &a, std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, half beta, @@ -767,7 +767,7 @@ void gemm(backend_selector selector, transpose transa, transpos gemm_postcondition(selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } - +#endif void syr2(backend_selector selector, uplo upper_lower, std::int64_t n, float alpha, cl::sycl::buffer &x, std::int64_t incx, cl::sycl::buffer &y, std::int64_t incy, cl::sycl::buffer &a, std::int64_t lda) { diff --git a/include/oneapi/mkl/blas/detail/onemkl_blas_backends.hxx b/include/oneapi/mkl/blas/detail/onemkl_blas_backends.hxx index 866aab302..eb1f1f19a 100644 --- a/include/oneapi/mkl/blas/detail/onemkl_blas_backends.hxx +++ b/include/oneapi/mkl/blas/detail/onemkl_blas_backends.hxx @@ -46,7 +46,7 @@ ONEMKL_EXPORT void gemm(cl::sycl::queue &queue, oneapi::mkl::transpose transa, cl::sycl::buffer, 1> &b, std::int64_t ldb, std::complex beta, cl::sycl::buffer, 1> &c, std::int64_t ldc); - +#ifndef DISABLE_HALF_RUTINES ONEMKL_EXPORT void gemm(cl::sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, half alpha, cl::sycl::buffer &a, std::int64_t lda, @@ -58,7 +58,7 @@ ONEMKL_EXPORT void gemm(cl::sycl::queue &queue, oneapi::mkl::transpose transa, std::int64_t k, float alpha, 
cl::sycl::buffer &a, std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, float beta, cl::sycl::buffer &c, std::int64_t ldc); - +#endif ONEMKL_EXPORT void symm(cl::sycl::queue &queue, oneapi::mkl::side left_right, oneapi::mkl::uplo upper_lower, std::int64_t m, std::int64_t n, float alpha, cl::sycl::buffer &a, std::int64_t lda, diff --git a/include/oneapi/mkl/blas/predicates.hpp b/include/oneapi/mkl/blas/predicates.hpp index 5a668be89..ebaddc02f 100644 --- a/include/oneapi/mkl/blas/predicates.hpp +++ b/include/oneapi/mkl/blas/predicates.hpp @@ -26,6 +26,7 @@ #include "oneapi/mkl/exceptions.hpp" #include "oneapi/mkl/types.hpp" +#include "oneapi/mkl/detail/config.hpp" namespace oneapi { namespace mkl { diff --git a/include/oneapi/mkl/blas/predicates.hxx b/include/oneapi/mkl/blas/predicates.hxx index 19f7a3b76..635cd885e 100644 --- a/include/oneapi/mkl/blas/predicates.hxx +++ b/include/oneapi/mkl/blas/predicates.hxx @@ -1518,6 +1518,7 @@ inline void gemm_postcondition(cl::sycl::queue &queue, transpose transa, transpo #endif } +#ifndef DISABLE_HALF_RUTINES inline void gemm_precondition(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, half alpha, cl::sycl::buffer &a, std::int64_t lda, @@ -1557,7 +1558,7 @@ inline void gemm_postcondition(cl::sycl::queue &queue, transpose transa, transpo /* add postchecks to queue here for input args. */ #endif } - +#endif inline void syr2_precondition(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, float alpha, cl::sycl::buffer &x, std::int64_t incx, cl::sycl::buffer &y, std::int64_t incy, @@ -4749,7 +4750,7 @@ inline void gemm_postcondition(cl::sycl::queue &queue, transpose transa, transpo /* add postchecks to queue here for input args. 
*/ #endif } - +#ifndef DISABLE_HALF_RUTINES inline void gemm_precondition(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, half alpha, const half *a, std::int64_t lda, const half *b, std::int64_t ldb, @@ -4769,7 +4770,7 @@ inline void gemm_postcondition(cl::sycl::queue &queue, transpose transa, transpo /* add postchecks to queue here for input args. */ #endif } - +#endif inline void syr2_precondition(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, float alpha, const float *x, std::int64_t incx, const float *y, std::int64_t incy, float *a, std::int64_t lda, diff --git a/include/oneapi/mkl/detail/backend_selector.hpp b/include/oneapi/mkl/detail/backend_selector.hpp index b0c763ae0..9b5aef3c4 100644 --- a/include/oneapi/mkl/detail/backend_selector.hpp +++ b/include/oneapi/mkl/detail/backend_selector.hpp @@ -27,6 +27,8 @@ namespace oneapi { namespace mkl { +using namespace cl; + template class backend_selector { public: diff --git a/src/blas/backends/cublas/cublas_level3.cpp b/src/blas/backends/cublas/cublas_level3.cpp index 671a15ea7..67f10e955 100644 --- a/src/blas/backends/cublas/cublas_level3.cpp +++ b/src/blas/backends/cublas/cublas_level3.cpp @@ -109,10 +109,10 @@ inline void gemm(Func func, DATATYPE_A DT_A, DATATYPE_B DT_B, DATATYPE_C DT_C, gemm(CUBLAS_ROUTINE, CUDADATATYPE_A, CUDADATATYPE_B, CUDADATATYPE_C, queue, transa, \ transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); \ } - +#ifndef DISABLE_HALF_RUTINES GEMM_EX_LAUNCHER(half, half, float, cublasGemmEx, CUDA_R_16F, CUDA_R_16F, CUDA_R_32F) GEMM_EX_LAUNCHER(half, half, half, cublasGemmEx, CUDA_R_16F, CUDA_R_16F, CUDA_R_16F) - +#endif #undef GEMM_EX_LAUNCHER template @@ -465,14 +465,14 @@ GEMM_LAUNCHER_USM(std::complex, cublasCgemm) GEMM_LAUNCHER_USM(std::complex, cublasZgemm) #undef GEMM_LAUNCHER_USM - +#ifndef DISABLE_HALF_RUTINES cl::sycl::event gemm(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, 
std::int64_t n, std::int64_t k, half alpha, const half *a, std::int64_t lda, const half *b, std::int64_t ldb, half beta, half *c, std::int64_t ldc, const cl::sycl::vector_class &dependencies) { throw unimplemented("blas", "gemm", "for column_major layout"); } - +#endif template inline cl::sycl::event symm(Func func, cl::sycl::queue &queue, side left_right, uplo upper_lower, int64_t m, int64_t n, T alpha, const T *a, int64_t lda, const T *b, @@ -860,10 +860,10 @@ inline void gemm(Func func, DATATYPE_A DT_A, DATATYPE_B DT_B, DATATYPE_C DT_C, gemm(CUBLAS_ROUTINE, CUDADATATYPE_A, CUDADATATYPE_B, CUDADATATYPE_C, queue, transa, \ transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); \ } - +#ifndef DISABLE_HALF_RUTINES GEMM_EX_LAUNCHER(half, half, float, cublasGemmEx, CUDA_R_16F, CUDA_R_16F, CUDA_R_32F) GEMM_EX_LAUNCHER(half, half, half, cublasGemmEx, CUDA_R_16F, CUDA_R_16F, CUDA_R_16F) - +#endif #undef GEMM_EX_LAUNCHER template @@ -1065,14 +1065,14 @@ GEMM_LAUNCHER_USM(std::complex, cublasCgemm) GEMM_LAUNCHER_USM(std::complex, cublasZgemm) #undef GEMM_LAUNCHER_USM - +#ifndef DISABLE_HALF_RUTINES cl::sycl::event gemm(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, half alpha, const half *a, std::int64_t lda, const half *b, std::int64_t ldb, half beta, half *c, std::int64_t ldc, const cl::sycl::vector_class &dependencies) { throw unimplemented("blas", "gemm", "for row_major layout"); } - +#endif template inline cl::sycl::event symm(Func func, cl::sycl::queue &queue, side left_right, uplo upper_lower, int64_t m, int64_t n, T alpha, const T *a, int64_t lda, const T *b, diff --git a/src/blas/backends/cublas/mkl_blas_cublas_wrappers.cpp b/src/blas/backends/cublas/mkl_blas_cublas_wrappers.cpp index 23554ea2a..b523a6571 100644 --- a/src/blas/backends/cublas/mkl_blas_cublas_wrappers.cpp +++ b/src/blas/backends/cublas/mkl_blas_cublas_wrappers.cpp @@ -147,8 +147,10 @@ extern "C" blas_function_table_t mkl_blas_table = { 
oneapi::mkl::blas::cublas::column_major::gemm, oneapi::mkl::blas::cublas::column_major::gemm, oneapi::mkl::blas::cublas::column_major::gemm, +#ifndef DISABLE_HALF_RUTINES oneapi::mkl::blas::cublas::column_major::gemm, oneapi::mkl::blas::cublas::column_major::gemm, +#endif oneapi::mkl::blas::cublas::column_major::hemm, oneapi::mkl::blas::cublas::column_major::hemm, oneapi::mkl::blas::cublas::column_major::herk, @@ -478,8 +480,10 @@ extern "C" blas_function_table_t mkl_blas_table = { oneapi::mkl::blas::cublas::row_major::gemm, oneapi::mkl::blas::cublas::row_major::gemm, oneapi::mkl::blas::cublas::row_major::gemm, +#ifndef DISABLE_HALF_RUTINES oneapi::mkl::blas::cublas::row_major::gemm, oneapi::mkl::blas::cublas::row_major::gemm, +#endif oneapi::mkl::blas::cublas::row_major::hemm, oneapi::mkl::blas::cublas::row_major::hemm, oneapi::mkl::blas::cublas::row_major::herk, diff --git a/src/blas/backends/mklcpu/mklcpu_level3.cxx b/src/blas/backends/mklcpu/mklcpu_level3.cxx index 7917be1ca..ed6a2961c 100644 --- a/src/blas/backends/mklcpu/mklcpu_level3.cxx +++ b/src/blas/backends/mklcpu/mklcpu_level3.cxx @@ -96,7 +96,7 @@ void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, int64_t m, }); }); } - +#ifndef DISABLE_HALF_RUTINES void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, int64_t k, half alpha, cl::sycl::buffer &a, int64_t lda, cl::sycl::buffer &b, int64_t ldb, half beta, cl::sycl::buffer &c, @@ -140,7 +140,6 @@ void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, int64_t m, }); }); } - void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, int64_t k, float alpha, cl::sycl::buffer &a, int64_t lda, cl::sycl::buffer &b, int64_t ldb, float beta, cl::sycl::buffer &c, @@ -173,7 +172,7 @@ void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, int64_t m, }); }); } - +#endif void hemm(cl::sycl::queue &queue, side left_right, uplo upper_lower, int64_t m, 
int64_t n, std::complex alpha, cl::sycl::buffer, 1> &a, int64_t lda, cl::sycl::buffer, 1> &b, int64_t ldb, std::complex beta, diff --git a/src/blas/backends/mklcpu/mklcpu_wrappers.cpp b/src/blas/backends/mklcpu/mklcpu_wrappers.cpp index 6bb07f11f..996a8c57c 100644 --- a/src/blas/backends/mklcpu/mklcpu_wrappers.cpp +++ b/src/blas/backends/mklcpu/mklcpu_wrappers.cpp @@ -148,8 +148,10 @@ extern "C" ONEMKL_EXPORT blas_function_table_t mkl_blas_table = { oneapi::mkl::blas::mklcpu::column_major::gemm, oneapi::mkl::blas::mklcpu::column_major::gemm, oneapi::mkl::blas::mklcpu::column_major::gemm, + #ifndef DISABLE_HALF_RUTINES oneapi::mkl::blas::mklcpu::column_major::gemm, oneapi::mkl::blas::mklcpu::column_major::gemm, + #endif oneapi::mkl::blas::mklcpu::column_major::hemm, oneapi::mkl::blas::mklcpu::column_major::hemm, oneapi::mkl::blas::mklcpu::column_major::herk, @@ -479,8 +481,10 @@ extern "C" ONEMKL_EXPORT blas_function_table_t mkl_blas_table = { oneapi::mkl::blas::mklcpu::row_major::gemm, oneapi::mkl::blas::mklcpu::row_major::gemm, oneapi::mkl::blas::mklcpu::row_major::gemm, + #ifndef DISABLE_HALF_RUTINES oneapi::mkl::blas::mklcpu::row_major::gemm, oneapi::mkl::blas::mklcpu::row_major::gemm, + #endif oneapi::mkl::blas::mklcpu::row_major::hemm, oneapi::mkl::blas::mklcpu::row_major::hemm, oneapi::mkl::blas::mklcpu::row_major::herk, diff --git a/src/blas/backends/mklgpu/mklgpu_common.hpp b/src/blas/backends/mklgpu/mklgpu_common.hpp index 86c4512e3..0cf15e056 100644 --- a/src/blas/backends/mklgpu/mklgpu_common.hpp +++ b/src/blas/backends/mklgpu/mklgpu_common.hpp @@ -779,7 +779,7 @@ void cgemmt(cl::sycl::queue &queue, MKL_LAYOUT layout, MKL_UPLO upper_lower, MKL cl::sycl::buffer, 1> &a, int64_t lda, cl::sycl::buffer, 1> &b, int64_t ldb, std::complex beta, cl::sycl::buffer, 1> &c, int64_t ldc); - +#ifndef DISABLE_HALF_RUTINES void hgemm(cl::sycl::queue &queue, MKL_LAYOUT layout, MKL_TRANSPOSE transa, MKL_TRANSPOSE transb, int64_t m, int64_t n, int64_t k, half alpha, 
cl::sycl::buffer &a, int64_t lda, cl::sycl::buffer &b, int64_t ldb, half beta, cl::sycl::buffer &c, @@ -789,7 +789,7 @@ void gemm_f16f16f32(cl::sycl::queue &queue, MKL_LAYOUT layout, MKL_TRANSPOSE tra MKL_TRANSPOSE transb, int64_t m, int64_t n, int64_t k, float alpha, cl::sycl::buffer &a, int64_t lda, cl::sycl::buffer &b, int64_t ldb, float beta, cl::sycl::buffer &c, int64_t ldc); - +#endif cl::sycl::event gemm_s8u8s32_sycl(cl::sycl::queue *queue, MKL_LAYOUT layout, MKL_TRANSPOSE transa, MKL_TRANSPOSE transb, CBLAS_OFFSET offsetc, int64_t m, int64_t n, int64_t k, float alpha, cl::sycl::buffer *a, diff --git a/src/blas/backends/mklgpu/mklgpu_level3.cxx b/src/blas/backends/mklgpu/mklgpu_level3.cxx index 594b33b5d..a091de3b5 100644 --- a/src/blas/backends/mklgpu/mklgpu_level3.cxx +++ b/src/blas/backends/mklgpu/mklgpu_level3.cxx @@ -56,7 +56,7 @@ void gemm(cl::sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::tr ::mkl::cblas_convert(transb), m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } - +#ifndef DISABLE_HALF_RUTINES void gemm(cl::sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, half alpha, cl::sycl::buffer &a, std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, half beta, @@ -74,7 +74,7 @@ void gemm(cl::sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::tr ::mkl::cblas_convert(transb), m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } - +#endif void symm(cl::sycl::queue &queue, oneapi::mkl::side left_right, oneapi::mkl::uplo upper_lower, std::int64_t m, std::int64_t n, float alpha, cl::sycl::buffer &a, std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, float beta, diff --git a/src/blas/backends/mklgpu/mklgpu_wrappers.cpp b/src/blas/backends/mklgpu/mklgpu_wrappers.cpp index b32c32ab9..f5959a3f5 100644 --- a/src/blas/backends/mklgpu/mklgpu_wrappers.cpp +++ b/src/blas/backends/mklgpu/mklgpu_wrappers.cpp @@ -148,8 +148,10 @@ extern "C" ONEMKL_EXPORT 
blas_function_table_t mkl_blas_table = { oneapi::mkl::blas::mklgpu::column_major::gemm, oneapi::mkl::blas::mklgpu::column_major::gemm, oneapi::mkl::blas::mklgpu::column_major::gemm, +#ifndef DISABLE_HALF_RUTINES oneapi::mkl::blas::mklgpu::column_major::gemm, oneapi::mkl::blas::mklgpu::column_major::gemm, +#endif oneapi::mkl::blas::mklgpu::column_major::hemm, oneapi::mkl::blas::mklgpu::column_major::hemm, oneapi::mkl::blas::mklgpu::column_major::herk, @@ -479,8 +481,10 @@ extern "C" ONEMKL_EXPORT blas_function_table_t mkl_blas_table = { oneapi::mkl::blas::mklgpu::row_major::gemm, oneapi::mkl::blas::mklgpu::row_major::gemm, oneapi::mkl::blas::mklgpu::row_major::gemm, +#ifndef DISABLE_HALF_RUTINES oneapi::mkl::blas::mklgpu::row_major::gemm, oneapi::mkl::blas::mklgpu::row_major::gemm, +#endif oneapi::mkl::blas::mklgpu::row_major::hemm, oneapi::mkl::blas::mklgpu::row_major::hemm, oneapi::mkl::blas::mklgpu::row_major::herk, diff --git a/src/blas/blas_loader.cpp b/src/blas/blas_loader.cpp index 206cb5846..77b78f48e 100644 --- a/src/blas/blas_loader.cpp +++ b/src/blas/blas_loader.cpp @@ -877,7 +877,7 @@ void gemm(oneapi::mkl::device libkey, cl::sycl::queue &queue, transpose transa, function_tables[libkey].column_major_zgemm_sycl(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } - +#ifndef DISABLE_HALF_RUTINES void gemm(oneapi::mkl::device libkey, cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, half alpha, cl::sycl::buffer &a, std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, half beta, @@ -893,7 +893,7 @@ void gemm(oneapi::mkl::device libkey, cl::sycl::queue &queue, transpose transa, function_tables[libkey].column_major_gemm_f16f16f32_sycl(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } - +#endif void hemm(oneapi::mkl::device libkey, cl::sycl::queue &queue, side left_right, uplo upper_lower, std::int64_t m, std::int64_t n, std::complex alpha, cl::sycl::buffer, 
1> &a, std::int64_t lda, @@ -3495,7 +3495,7 @@ void gemm(oneapi::mkl::device libkey, cl::sycl::queue &queue, transpose transa, function_tables[libkey].row_major_zgemm_sycl(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } - +#ifndef DISABLE_HALF_RUTINES void gemm(oneapi::mkl::device libkey, cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, half alpha, cl::sycl::buffer &a, std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, half beta, @@ -3511,7 +3511,7 @@ void gemm(oneapi::mkl::device libkey, cl::sycl::queue &queue, transpose transa, function_tables[libkey].row_major_gemm_f16f16f32_sycl(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } - +#endif void hemm(oneapi::mkl::device libkey, cl::sycl::queue &queue, side left_right, uplo upper_lower, std::int64_t m, std::int64_t n, std::complex alpha, cl::sycl::buffer, 1> &a, std::int64_t lda, diff --git a/src/blas/function_table.hpp b/src/blas/function_table.hpp index 48c04f29c..0b367eab4 100644 --- a/src/blas/function_table.hpp +++ b/src/blas/function_table.hpp @@ -24,6 +24,7 @@ #include #include #include "oneapi/mkl/types.hpp" +#include "oneapi/mkl/detail/config.hpp" typedef struct { int version; @@ -567,12 +568,14 @@ typedef struct { cl::sycl::buffer, 1> &b, std::int64_t ldb, std::complex beta, cl::sycl::buffer, 1> &c, std::int64_t ldc); +#ifndef DISABLE_HALF_RUTINES void (*column_major_hgemm_sycl)(cl::sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, half alpha, cl::sycl::buffer &a, std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, half beta, cl::sycl::buffer &c, std::int64_t ldc); + void (*column_major_gemm_f16f16f32_sycl)(cl::sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, @@ -580,6 +583,7 @@ typedef struct { cl::sycl::buffer &b, 
std::int64_t ldb, float beta, cl::sycl::buffer &c, std::int64_t ldc); +#endif void (*column_major_chemm_sycl)(cl::sycl::queue &queue, oneapi::mkl::side left_right, oneapi::mkl::uplo upper_lower, std::int64_t m, std::int64_t n, std::complex alpha, @@ -2086,6 +2090,7 @@ typedef struct { cl::sycl::buffer, 1> &b, std::int64_t ldb, std::complex beta, cl::sycl::buffer, 1> &c, std::int64_t ldc); +#ifndef DISABLE_HALF_RUTINES void (*row_major_hgemm_sycl)(cl::sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, half alpha, cl::sycl::buffer &a, @@ -2098,6 +2103,7 @@ typedef struct { cl::sycl::buffer &b, std::int64_t ldb, float beta, cl::sycl::buffer &c, std::int64_t ldc); +#endif void (*row_major_chemm_sycl)(cl::sycl::queue &queue, oneapi::mkl::side left_right, oneapi::mkl::uplo upper_lower, std::int64_t m, std::int64_t n, std::complex alpha, diff --git a/src/config.hpp.in b/src/config.hpp.in index 957820cc9..bf48e47f1 100644 --- a/src/config.hpp.in +++ b/src/config.hpp.in @@ -25,6 +25,7 @@ #cmakedefine ENABLE_MKLCPU_BACKEND #cmakedefine ENABLE_MKLGPU_BACKEND #cmakedefine ENABLE_NETLIB_BACKEND +#cmakedefine DISABLE_HALF_RUTINES #cmakedefine BUILD_SHARED_LIBS #endif diff --git a/src/rng/backends/mklcpu/mrg32k3a.cpp b/src/rng/backends/mklcpu/mrg32k3a.cpp index 6d38d78d6..163767121 100755 --- a/src/rng/backends/mklcpu/mrg32k3a.cpp +++ b/src/rng/backends/mklcpu/mrg32k3a.cpp @@ -33,6 +33,8 @@ namespace mkl { namespace rng { namespace mklcpu { +using namespace cl; + class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { public: mrg32k3a_impl(cl::sycl::queue queue, std::uint32_t seed) diff --git a/src/rng/backends/mklcpu/philox4x32x10.cpp b/src/rng/backends/mklcpu/philox4x32x10.cpp index f204912f4..b65d375d6 100644 --- a/src/rng/backends/mklcpu/philox4x32x10.cpp +++ b/src/rng/backends/mklcpu/philox4x32x10.cpp @@ -33,6 +33,8 @@ namespace mkl { namespace rng { namespace mklcpu { +using 
namespace cl; + class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { public: philox4x32x10_impl(cl::sycl::queue queue, std::uint64_t seed) From 2ff8f9f73c857c0c4d040e88516b43e7b19d12eb Mon Sep 17 00:00:00 2001 From: sbalint98 Date: Fri, 21 May 2021 21:50:50 +0200 Subject: [PATCH 2/3] :[cublas] correct typo, use ENABLE_HALF_ROUTINES instead of DISABLE --- include/oneapi/mkl/blas.hxx | 2 +- include/oneapi/mkl/blas/detail/blas_ct_backends.hxx | 2 +- include/oneapi/mkl/blas/detail/blas_loader.hxx | 2 +- include/oneapi/mkl/blas/detail/cublas/blas_ct.hxx | 2 +- .../oneapi/mkl/blas/detail/cublas/onemkl_blas_cublas.hxx | 2 +- include/oneapi/mkl/blas/detail/mklcpu/blas_ct.hxx | 2 +- include/oneapi/mkl/blas/detail/mklgpu/blas_ct.hxx | 2 +- include/oneapi/mkl/blas/detail/netlib/blas_ct.hxx | 2 +- include/oneapi/mkl/blas/detail/onemkl_blas_backends.hxx | 2 +- include/oneapi/mkl/blas/predicates.hxx | 4 ++-- src/blas/backends/cublas/cublas_level3.cpp | 8 ++++---- src/blas/backends/cublas/mkl_blas_cublas_wrappers.cpp | 4 ++-- src/blas/backends/mklcpu/mklcpu_level3.cxx | 2 +- src/blas/backends/mklcpu/mklcpu_wrappers.cpp | 4 ++-- src/blas/backends/mklgpu/mklgpu_common.hpp | 2 +- src/blas/backends/mklgpu/mklgpu_level3.cxx | 2 +- src/blas/backends/mklgpu/mklgpu_wrappers.cpp | 4 ++-- src/blas/blas_loader.cpp | 4 ++-- src/blas/function_table.hpp | 4 ++-- src/config.hpp.in | 2 +- 20 files changed, 29 insertions(+), 29 deletions(-) diff --git a/include/oneapi/mkl/blas.hxx b/include/oneapi/mkl/blas.hxx index 9291b0f0f..5a1bce850 100644 --- a/include/oneapi/mkl/blas.hxx +++ b/include/oneapi/mkl/blas.hxx @@ -261,7 +261,7 @@ static inline void gemm(cl::sycl::queue &queue, transpose transa, transpose tran c, ldc); gemm_postcondition(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } -#ifndef DISABLE_HALF_RUTINES +#ifdef ENABLE_HALF_ROUTINES static inline void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t 
n, std::int64_t k, half alpha, cl::sycl::buffer &a, std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, half beta, diff --git a/include/oneapi/mkl/blas/detail/blas_ct_backends.hxx b/include/oneapi/mkl/blas/detail/blas_ct_backends.hxx index e4270b1b6..cce984419 100644 --- a/include/oneapi/mkl/blas/detail/blas_ct_backends.hxx +++ b/include/oneapi/mkl/blas/detail/blas_ct_backends.hxx @@ -413,7 +413,7 @@ static inline void gemm(backend_selector selector, transpose t std::int64_t lda, cl::sycl::buffer, 1> &b, std::int64_t ldb, std::complex beta, cl::sycl::buffer, 1> &c, std::int64_t ldc); -#ifndef DISABLE_HALF_RUTINES +#ifdef ENABLE_HALF_ROUTINES static inline void gemm(backend_selector selector, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, half alpha, cl::sycl::buffer &a, std::int64_t lda, diff --git a/include/oneapi/mkl/blas/detail/blas_loader.hxx b/include/oneapi/mkl/blas/detail/blas_loader.hxx index 94b393a1c..f8288ad01 100644 --- a/include/oneapi/mkl/blas/detail/blas_loader.hxx +++ b/include/oneapi/mkl/blas/detail/blas_loader.hxx @@ -396,7 +396,7 @@ ONEMKL_EXPORT void gemm(oneapi::mkl::device libkey, cl::sycl::queue &queue, tran std::int64_t lda, cl::sycl::buffer, 1> &b, std::int64_t ldb, std::complex beta, cl::sycl::buffer, 1> &c, std::int64_t ldc); -#ifndef DISABLE_HALF_RUTINES +#ifdef ENABLE_HALF_ROUTINES ONEMKL_EXPORT void gemm(oneapi::mkl::device libkey, cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, half alpha, cl::sycl::buffer &a, std::int64_t lda, diff --git a/include/oneapi/mkl/blas/detail/cublas/blas_ct.hxx b/include/oneapi/mkl/blas/detail/cublas/blas_ct.hxx index ff090fce0..81964ab05 100644 --- a/include/oneapi/mkl/blas/detail/cublas/blas_ct.hxx +++ b/include/oneapi/mkl/blas/detail/cublas/blas_ct.hxx @@ -743,7 +743,7 @@ void gemm(backend_selector selector, transpose transa, transpos gemm_postcondition(selector.get_queue(), transa, transb, m, n, k, 
alpha, a, lda, b, ldb, beta, c, ldc); } -#ifndef DISABLE_HALF_RUTINES +#ifdef ENABLE_HALF_ROUTINES void gemm(backend_selector selector, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, half alpha, cl::sycl::buffer &a, std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, half beta, diff --git a/include/oneapi/mkl/blas/detail/cublas/onemkl_blas_cublas.hxx b/include/oneapi/mkl/blas/detail/cublas/onemkl_blas_cublas.hxx index 6484c6f6b..18b6cb9b8 100644 --- a/include/oneapi/mkl/blas/detail/cublas/onemkl_blas_cublas.hxx +++ b/include/oneapi/mkl/blas/detail/cublas/onemkl_blas_cublas.hxx @@ -496,7 +496,7 @@ void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64 cl::sycl::buffer, 1> &a, std::int64_t lda, cl::sycl::buffer, 1> &b, std::int64_t ldb, std::complex beta, cl::sycl::buffer, 1> &c, std::int64_t ldc); -#ifndef DISABLE_HALF_RUTINES +#ifdef ENABLE_HALF_ROUTINES void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, half alpha, cl::sycl::buffer &a, std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, half beta, diff --git a/include/oneapi/mkl/blas/detail/mklcpu/blas_ct.hxx b/include/oneapi/mkl/blas/detail/mklcpu/blas_ct.hxx index f158b80a0..24c41a7c9 100644 --- a/include/oneapi/mkl/blas/detail/mklcpu/blas_ct.hxx +++ b/include/oneapi/mkl/blas/detail/mklcpu/blas_ct.hxx @@ -743,7 +743,7 @@ void gemm(backend_selector selector, transpose transa, transpos gemm_postcondition(selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } -#ifndef DISABLE_HALF_RUTINES +#ifdef ENABLE_HALF_ROUTINES void gemm(backend_selector selector, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, half alpha, cl::sycl::buffer &a, std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, half beta, diff --git a/include/oneapi/mkl/blas/detail/mklgpu/blas_ct.hxx 
b/include/oneapi/mkl/blas/detail/mklgpu/blas_ct.hxx index e0889c1a6..92db5a9b1 100644 --- a/include/oneapi/mkl/blas/detail/mklgpu/blas_ct.hxx +++ b/include/oneapi/mkl/blas/detail/mklgpu/blas_ct.hxx @@ -743,7 +743,7 @@ void gemm(backend_selector selector, transpose transa, transpos gemm_postcondition(selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } -#ifndef DISABLE_HALF_RUTINES +#ifdef ENABLE_HALF_ROUTINES void gemm(backend_selector selector, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, half alpha, cl::sycl::buffer &a, std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, half beta, diff --git a/include/oneapi/mkl/blas/detail/netlib/blas_ct.hxx b/include/oneapi/mkl/blas/detail/netlib/blas_ct.hxx index 9e283c7e2..5f03c7656 100644 --- a/include/oneapi/mkl/blas/detail/netlib/blas_ct.hxx +++ b/include/oneapi/mkl/blas/detail/netlib/blas_ct.hxx @@ -743,7 +743,7 @@ void gemm(backend_selector selector, transpose transa, transpos gemm_postcondition(selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } -#ifndef DISABLE_HALF_RUTINES +#ifdef ENABLE_HALF_ROUTINES void gemm(backend_selector selector, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, half alpha, cl::sycl::buffer &a, std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, half beta, diff --git a/include/oneapi/mkl/blas/detail/onemkl_blas_backends.hxx b/include/oneapi/mkl/blas/detail/onemkl_blas_backends.hxx index eb1f1f19a..f996dd890 100644 --- a/include/oneapi/mkl/blas/detail/onemkl_blas_backends.hxx +++ b/include/oneapi/mkl/blas/detail/onemkl_blas_backends.hxx @@ -46,7 +46,7 @@ ONEMKL_EXPORT void gemm(cl::sycl::queue &queue, oneapi::mkl::transpose transa, cl::sycl::buffer, 1> &b, std::int64_t ldb, std::complex beta, cl::sycl::buffer, 1> &c, std::int64_t ldc); -#ifndef DISABLE_HALF_RUTINES +#ifdef ENABLE_HALF_ROUTINES ONEMKL_EXPORT void gemm(cl::sycl::queue &queue, 
oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, half alpha, cl::sycl::buffer &a, std::int64_t lda, diff --git a/include/oneapi/mkl/blas/predicates.hxx b/include/oneapi/mkl/blas/predicates.hxx index 635cd885e..a5a95fe31 100644 --- a/include/oneapi/mkl/blas/predicates.hxx +++ b/include/oneapi/mkl/blas/predicates.hxx @@ -1518,7 +1518,7 @@ inline void gemm_postcondition(cl::sycl::queue &queue, transpose transa, transpo #endif } -#ifndef DISABLE_HALF_RUTINES +#ifdef ENABLE_HALF_ROUTINES inline void gemm_precondition(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, half alpha, cl::sycl::buffer &a, std::int64_t lda, @@ -4750,7 +4750,7 @@ inline void gemm_postcondition(cl::sycl::queue &queue, transpose transa, transpo /* add postchecks to queue here for input args. */ #endif } -#ifndef DISABLE_HALF_RUTINES +#ifdef ENABLE_HALF_ROUTINES inline void gemm_precondition(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, half alpha, const half *a, std::int64_t lda, const half *b, std::int64_t ldb, diff --git a/src/blas/backends/cublas/cublas_level3.cpp b/src/blas/backends/cublas/cublas_level3.cpp index 67f10e955..c8a701f94 100644 --- a/src/blas/backends/cublas/cublas_level3.cpp +++ b/src/blas/backends/cublas/cublas_level3.cpp @@ -109,7 +109,7 @@ inline void gemm(Func func, DATATYPE_A DT_A, DATATYPE_B DT_B, DATATYPE_C DT_C, gemm(CUBLAS_ROUTINE, CUDADATATYPE_A, CUDADATATYPE_B, CUDADATATYPE_C, queue, transa, \ transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); \ } -#ifndef DISABLE_HALF_RUTINES +#ifdef ENABLE_HALF_ROUTINES GEMM_EX_LAUNCHER(half, half, float, cublasGemmEx, CUDA_R_16F, CUDA_R_16F, CUDA_R_32F) GEMM_EX_LAUNCHER(half, half, half, cublasGemmEx, CUDA_R_16F, CUDA_R_16F, CUDA_R_16F) #endif @@ -465,7 +465,7 @@ GEMM_LAUNCHER_USM(std::complex, cublasCgemm) GEMM_LAUNCHER_USM(std::complex, cublasZgemm) 
#undef GEMM_LAUNCHER_USM -#ifndef DISABLE_HALF_RUTINES +#ifdef ENABLE_HALF_ROUTINES cl::sycl::event gemm(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, half alpha, const half *a, std::int64_t lda, const half *b, std::int64_t ldb, half beta, half *c, std::int64_t ldc, @@ -860,7 +860,7 @@ inline void gemm(Func func, DATATYPE_A DT_A, DATATYPE_B DT_B, DATATYPE_C DT_C, gemm(CUBLAS_ROUTINE, CUDADATATYPE_A, CUDADATATYPE_B, CUDADATATYPE_C, queue, transa, \ transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); \ } -#ifndef DISABLE_HALF_RUTINES +#ifdef ENABLE_HALF_ROUTINES GEMM_EX_LAUNCHER(half, half, float, cublasGemmEx, CUDA_R_16F, CUDA_R_16F, CUDA_R_32F) GEMM_EX_LAUNCHER(half, half, half, cublasGemmEx, CUDA_R_16F, CUDA_R_16F, CUDA_R_16F) #endif @@ -1065,7 +1065,7 @@ GEMM_LAUNCHER_USM(std::complex, cublasCgemm) GEMM_LAUNCHER_USM(std::complex, cublasZgemm) #undef GEMM_LAUNCHER_USM -#ifndef DISABLE_HALF_RUTINES +#ifdef ENABLE_HALF_ROUTINES cl::sycl::event gemm(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, half alpha, const half *a, std::int64_t lda, const half *b, std::int64_t ldb, half beta, half *c, std::int64_t ldc, diff --git a/src/blas/backends/cublas/mkl_blas_cublas_wrappers.cpp b/src/blas/backends/cublas/mkl_blas_cublas_wrappers.cpp index b523a6571..f8cc31ffa 100644 --- a/src/blas/backends/cublas/mkl_blas_cublas_wrappers.cpp +++ b/src/blas/backends/cublas/mkl_blas_cublas_wrappers.cpp @@ -147,7 +147,7 @@ extern "C" blas_function_table_t mkl_blas_table = { oneapi::mkl::blas::cublas::column_major::gemm, oneapi::mkl::blas::cublas::column_major::gemm, oneapi::mkl::blas::cublas::column_major::gemm, -#ifndef DISABLE_HALF_RUTINES +#ifdef ENABLE_HALF_ROUTINES oneapi::mkl::blas::cublas::column_major::gemm, oneapi::mkl::blas::cublas::column_major::gemm, #endif @@ -480,7 +480,7 @@ extern "C" blas_function_table_t mkl_blas_table = { 
oneapi::mkl::blas::cublas::row_major::gemm, oneapi::mkl::blas::cublas::row_major::gemm, oneapi::mkl::blas::cublas::row_major::gemm, -#ifndef DISABLE_HALF_RUTINES +#ifdef ENABLE_HALF_ROUTINES oneapi::mkl::blas::cublas::row_major::gemm, oneapi::mkl::blas::cublas::row_major::gemm, #endif diff --git a/src/blas/backends/mklcpu/mklcpu_level3.cxx b/src/blas/backends/mklcpu/mklcpu_level3.cxx index ed6a2961c..9e6dcbd7d 100644 --- a/src/blas/backends/mklcpu/mklcpu_level3.cxx +++ b/src/blas/backends/mklcpu/mklcpu_level3.cxx @@ -96,7 +96,7 @@ void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, int64_t m, }); }); } -#ifndef DISABLE_HALF_RUTINES +#ifdef ENABLE_HALF_ROUTINES void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, int64_t k, half alpha, cl::sycl::buffer &a, int64_t lda, cl::sycl::buffer &b, int64_t ldb, half beta, cl::sycl::buffer &c, diff --git a/src/blas/backends/mklcpu/mklcpu_wrappers.cpp b/src/blas/backends/mklcpu/mklcpu_wrappers.cpp index 996a8c57c..d65c8bac2 100644 --- a/src/blas/backends/mklcpu/mklcpu_wrappers.cpp +++ b/src/blas/backends/mklcpu/mklcpu_wrappers.cpp @@ -148,7 +148,7 @@ extern "C" ONEMKL_EXPORT blas_function_table_t mkl_blas_table = { oneapi::mkl::blas::mklcpu::column_major::gemm, oneapi::mkl::blas::mklcpu::column_major::gemm, oneapi::mkl::blas::mklcpu::column_major::gemm, - #ifndef DISABLE_HALF_RUTINES + #ifdef ENABLE_HALF_ROUTINES oneapi::mkl::blas::mklcpu::column_major::gemm, oneapi::mkl::blas::mklcpu::column_major::gemm, #endif @@ -481,7 +481,7 @@ extern "C" ONEMKL_EXPORT blas_function_table_t mkl_blas_table = { oneapi::mkl::blas::mklcpu::row_major::gemm, oneapi::mkl::blas::mklcpu::row_major::gemm, oneapi::mkl::blas::mklcpu::row_major::gemm, - #ifndef DISABLE_HALF_RUTINES + #ifdef ENABLE_HALF_ROUTINES oneapi::mkl::blas::mklcpu::row_major::gemm, oneapi::mkl::blas::mklcpu::row_major::gemm, #endif diff --git a/src/blas/backends/mklgpu/mklgpu_common.hpp 
b/src/blas/backends/mklgpu/mklgpu_common.hpp index 0cf15e056..79b4e80f7 100644 --- a/src/blas/backends/mklgpu/mklgpu_common.hpp +++ b/src/blas/backends/mklgpu/mklgpu_common.hpp @@ -779,7 +779,7 @@ void cgemmt(cl::sycl::queue &queue, MKL_LAYOUT layout, MKL_UPLO upper_lower, MKL cl::sycl::buffer, 1> &a, int64_t lda, cl::sycl::buffer, 1> &b, int64_t ldb, std::complex beta, cl::sycl::buffer, 1> &c, int64_t ldc); -#ifndef DISABLE_HALF_RUTINES +#ifdef ENABLE_HALF_ROUTINES void hgemm(cl::sycl::queue &queue, MKL_LAYOUT layout, MKL_TRANSPOSE transa, MKL_TRANSPOSE transb, int64_t m, int64_t n, int64_t k, half alpha, cl::sycl::buffer &a, int64_t lda, cl::sycl::buffer &b, int64_t ldb, half beta, cl::sycl::buffer &c, diff --git a/src/blas/backends/mklgpu/mklgpu_level3.cxx b/src/blas/backends/mklgpu/mklgpu_level3.cxx index a091de3b5..20314374d 100644 --- a/src/blas/backends/mklgpu/mklgpu_level3.cxx +++ b/src/blas/backends/mklgpu/mklgpu_level3.cxx @@ -56,7 +56,7 @@ void gemm(cl::sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::tr ::mkl::cblas_convert(transb), m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } -#ifndef DISABLE_HALF_RUTINES +#ifdef ENABLE_HALF_ROUTINES void gemm(cl::sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, half alpha, cl::sycl::buffer &a, std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, half beta, diff --git a/src/blas/backends/mklgpu/mklgpu_wrappers.cpp b/src/blas/backends/mklgpu/mklgpu_wrappers.cpp index f5959a3f5..068efa886 100644 --- a/src/blas/backends/mklgpu/mklgpu_wrappers.cpp +++ b/src/blas/backends/mklgpu/mklgpu_wrappers.cpp @@ -148,7 +148,7 @@ extern "C" ONEMKL_EXPORT blas_function_table_t mkl_blas_table = { oneapi::mkl::blas::mklgpu::column_major::gemm, oneapi::mkl::blas::mklgpu::column_major::gemm, oneapi::mkl::blas::mklgpu::column_major::gemm, -#ifndef DISABLE_HALF_RUTINES +#ifdef ENABLE_HALF_ROUTINES 
oneapi::mkl::blas::mklgpu::column_major::gemm, oneapi::mkl::blas::mklgpu::column_major::gemm, #endif @@ -481,7 +481,7 @@ extern "C" ONEMKL_EXPORT blas_function_table_t mkl_blas_table = { oneapi::mkl::blas::mklgpu::row_major::gemm, oneapi::mkl::blas::mklgpu::row_major::gemm, oneapi::mkl::blas::mklgpu::row_major::gemm, -#ifndef DISABLE_HALF_RUTINES +#ifdef ENABLE_HALF_ROUTINES oneapi::mkl::blas::mklgpu::row_major::gemm, oneapi::mkl::blas::mklgpu::row_major::gemm, #endif diff --git a/src/blas/blas_loader.cpp b/src/blas/blas_loader.cpp index 77b78f48e..316767d6d 100644 --- a/src/blas/blas_loader.cpp +++ b/src/blas/blas_loader.cpp @@ -877,7 +877,7 @@ void gemm(oneapi::mkl::device libkey, cl::sycl::queue &queue, transpose transa, function_tables[libkey].column_major_zgemm_sycl(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } -#ifndef DISABLE_HALF_RUTINES +#ifdef ENABLE_HALF_ROUTINES void gemm(oneapi::mkl::device libkey, cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, half alpha, cl::sycl::buffer &a, std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, half beta, @@ -3495,7 +3495,7 @@ void gemm(oneapi::mkl::device libkey, cl::sycl::queue &queue, transpose transa, function_tables[libkey].row_major_zgemm_sycl(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } -#ifndef DISABLE_HALF_RUTINES +#ifdef ENABLE_HALF_ROUTINES void gemm(oneapi::mkl::device libkey, cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, half alpha, cl::sycl::buffer &a, std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, half beta, diff --git a/src/blas/function_table.hpp b/src/blas/function_table.hpp index 0b367eab4..afb70a938 100644 --- a/src/blas/function_table.hpp +++ b/src/blas/function_table.hpp @@ -568,7 +568,7 @@ typedef struct { cl::sycl::buffer, 1> &b, std::int64_t ldb, std::complex beta, cl::sycl::buffer, 1> &c, std::int64_t 
ldc); -#ifndef DISABLE_HALF_RUTINES +#ifdef ENABLE_HALF_ROUTINES void (*column_major_hgemm_sycl)(cl::sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, half alpha, cl::sycl::buffer &a, @@ -2090,7 +2090,7 @@ typedef struct { cl::sycl::buffer, 1> &b, std::int64_t ldb, std::complex beta, cl::sycl::buffer, 1> &c, std::int64_t ldc); -#ifndef DISABLE_HALF_RUTINES +#ifdef ENABLE_HALF_ROUTINES void (*row_major_hgemm_sycl)(cl::sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, half alpha, cl::sycl::buffer &a, diff --git a/src/config.hpp.in b/src/config.hpp.in index bf48e47f1..e7784c489 100644 --- a/src/config.hpp.in +++ b/src/config.hpp.in @@ -25,7 +25,7 @@ #cmakedefine ENABLE_MKLCPU_BACKEND #cmakedefine ENABLE_MKLGPU_BACKEND #cmakedefine ENABLE_NETLIB_BACKEND -#cmakedefine DISABLE_HALF_RUTINES +#cmakedefine ENABLE_HALF_ROUTINES #cmakedefine BUILD_SHARED_LIBS #endif From 4aa199a3569ea01d49c28ecd207e6f4c07713c9f Mon Sep 17 00:00:00 2001 From: sbalint98 Date: Wed, 2 Jun 2021 18:46:55 +0200 Subject: [PATCH 3/3] Removed unnecessary ifdefs and replaced half with cl::sycl::half --- .gitignore | 2 +- include/oneapi/mkl/blas.hxx | 12 +++--- .../mkl/blas/detail/blas_ct_backends.hxx | 12 +++--- .../oneapi/mkl/blas/detail/blas_loader.hxx | 12 +++--- .../oneapi/mkl/blas/detail/cublas/blas_ct.hxx | 12 +++--- .../blas/detail/cublas/onemkl_blas_cublas.hxx | 12 +++--- .../oneapi/mkl/blas/detail/mklcpu/blas_ct.hxx | 12 +++--- .../oneapi/mkl/blas/detail/mklgpu/blas_ct.hxx | 12 +++--- .../oneapi/mkl/blas/detail/netlib/blas_ct.hxx | 12 +++--- .../mkl/blas/detail/onemkl_blas_backends.hxx | 12 +++--- include/oneapi/mkl/blas/predicates.hxx | 40 +++++++++---------- .../rng/detail/curand/onemkl_rng_curand.hpp | 2 +- src/blas/backends/cublas/cublas_level3.cpp | 34 +++++++++------- .../cublas/mkl_blas_cublas_wrappers.cpp | 4 -- 
src/blas/backends/mklcpu/mklcpu_level3.cpp | 1 + src/blas/backends/mklcpu/mklcpu_level3.cxx | 21 ++++++---- src/blas/backends/mklcpu/mklcpu_wrappers.cpp | 4 -- src/blas/backends/mklgpu/mklgpu_common.hpp | 9 ++--- src/blas/backends/mklgpu/mklgpu_level3.cxx | 12 +++--- src/blas/backends/mklgpu/mklgpu_wrappers.cpp | 4 -- src/blas/backends/netlib/netlib_level3.cxx | 10 ++--- src/blas/blas_loader.cpp | 24 +++++------ src/blas/function_table.hpp | 24 +++++------ 23 files changed, 133 insertions(+), 166 deletions(-) diff --git a/.gitignore b/.gitignore index 2a3a157c9..631826c77 100644 --- a/.gitignore +++ b/.gitignore @@ -16,4 +16,4 @@ .git/ # Build -build/ +build*/ diff --git a/include/oneapi/mkl/blas.hxx b/include/oneapi/mkl/blas.hxx index 5a1bce850..6bc4f0d0d 100644 --- a/include/oneapi/mkl/blas.hxx +++ b/include/oneapi/mkl/blas.hxx @@ -261,11 +261,10 @@ static inline void gemm(cl::sycl::queue &queue, transpose transa, transpose tran c, ldc); gemm_postcondition(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } -#ifdef ENABLE_HALF_ROUTINES static inline void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, - std::int64_t n, std::int64_t k, half alpha, cl::sycl::buffer &a, - std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, half beta, - cl::sycl::buffer &c, std::int64_t ldc) { + std::int64_t n, std::int64_t k, cl::sycl::half alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, cl::sycl::half beta, + cl::sycl::buffer &c, std::int64_t ldc) { gemm_precondition(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); detail::gemm(get_device_id(queue), queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); @@ -273,15 +272,14 @@ static inline void gemm(cl::sycl::queue &queue, transpose transa, transpose tran } static inline void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, - std::int64_t n, std::int64_t k, float alpha, 
cl::sycl::buffer &a, - std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, + std::int64_t n, std::int64_t k, float alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, float beta, cl::sycl::buffer &c, std::int64_t ldc) { gemm_precondition(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); detail::gemm(get_device_id(queue), queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); gemm_postcondition(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } -#endif static inline void gemm_batch(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, cl::sycl::buffer &a, std::int64_t lda, diff --git a/include/oneapi/mkl/blas/detail/blas_ct_backends.hxx b/include/oneapi/mkl/blas/detail/blas_ct_backends.hxx index cce984419..62143da20 100644 --- a/include/oneapi/mkl/blas/detail/blas_ct_backends.hxx +++ b/include/oneapi/mkl/blas/detail/blas_ct_backends.hxx @@ -413,19 +413,17 @@ static inline void gemm(backend_selector selector, transpose t std::int64_t lda, cl::sycl::buffer, 1> &b, std::int64_t ldb, std::complex beta, cl::sycl::buffer, 1> &c, std::int64_t ldc); -#ifdef ENABLE_HALF_ROUTINES static inline void gemm(backend_selector selector, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, - half alpha, cl::sycl::buffer &a, std::int64_t lda, - cl::sycl::buffer &b, std::int64_t ldb, half beta, - cl::sycl::buffer &c, std::int64_t ldc); + cl::sycl::half alpha, cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &b, std::int64_t ldb, cl::sycl::half beta, + cl::sycl::buffer &c, std::int64_t ldc); static inline void gemm(backend_selector selector, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, - float alpha, cl::sycl::buffer &a, std::int64_t lda, - cl::sycl::buffer &b, std::int64_t ldb, float beta, + float alpha, cl::sycl::buffer &a, std::int64_t lda, + 
cl::sycl::buffer &b, std::int64_t ldb, float beta, cl::sycl::buffer &c, std::int64_t ldc); -#endif static inline void herk(backend_selector selector, uplo upper_lower, transpose trans, std::int64_t n, std::int64_t k, float alpha, cl::sycl::buffer, 1> &a, std::int64_t lda, float beta, diff --git a/include/oneapi/mkl/blas/detail/blas_loader.hxx b/include/oneapi/mkl/blas/detail/blas_loader.hxx index f8288ad01..2ef6f7069 100644 --- a/include/oneapi/mkl/blas/detail/blas_loader.hxx +++ b/include/oneapi/mkl/blas/detail/blas_loader.hxx @@ -396,18 +396,16 @@ ONEMKL_EXPORT void gemm(oneapi::mkl::device libkey, cl::sycl::queue &queue, tran std::int64_t lda, cl::sycl::buffer, 1> &b, std::int64_t ldb, std::complex beta, cl::sycl::buffer, 1> &c, std::int64_t ldc); -#ifdef ENABLE_HALF_ROUTINES ONEMKL_EXPORT void gemm(oneapi::mkl::device libkey, cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, - half alpha, cl::sycl::buffer &a, std::int64_t lda, - cl::sycl::buffer &b, std::int64_t ldb, half beta, - cl::sycl::buffer &c, std::int64_t ldc); + cl::sycl::half alpha, cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &b, std::int64_t ldb, cl::sycl::half beta, + cl::sycl::buffer &c, std::int64_t ldc); ONEMKL_EXPORT void gemm(oneapi::mkl::device libkey, cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, - float alpha, cl::sycl::buffer &a, std::int64_t lda, - cl::sycl::buffer &b, std::int64_t ldb, float beta, + float alpha, cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &b, std::int64_t ldb, float beta, cl::sycl::buffer &c, std::int64_t ldc); -#endif ONEMKL_EXPORT void syr2(oneapi::mkl::device libkey, cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, float alpha, cl::sycl::buffer &x, diff --git a/include/oneapi/mkl/blas/detail/cublas/blas_ct.hxx b/include/oneapi/mkl/blas/detail/cublas/blas_ct.hxx index 81964ab05..8a4f7ff00 100644 --- 
a/include/oneapi/mkl/blas/detail/cublas/blas_ct.hxx +++ b/include/oneapi/mkl/blas/detail/cublas/blas_ct.hxx @@ -743,11 +743,10 @@ void gemm(backend_selector selector, transpose transa, transpos gemm_postcondition(selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } -#ifdef ENABLE_HALF_ROUTINES void gemm(backend_selector selector, transpose transa, transpose transb, - std::int64_t m, std::int64_t n, std::int64_t k, half alpha, cl::sycl::buffer &a, - std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, half beta, - cl::sycl::buffer &c, std::int64_t ldc) { + std::int64_t m, std::int64_t n, std::int64_t k, cl::sycl::half alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, cl::sycl::half beta, + cl::sycl::buffer &c, std::int64_t ldc) { gemm_precondition(selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); oneapi::mkl::blas::cublas::MAJOR::gemm(selector.get_queue(), transa, transb, m, n, k, alpha, a, @@ -757,8 +756,8 @@ void gemm(backend_selector selector, transpose transa, transpos } void gemm(backend_selector selector, transpose transa, transpose transb, - std::int64_t m, std::int64_t n, std::int64_t k, float alpha, cl::sycl::buffer &a, - std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, float beta, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, float beta, cl::sycl::buffer &c, std::int64_t ldc) { gemm_precondition(selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); @@ -767,7 +766,6 @@ void gemm(backend_selector selector, transpose transa, transpos gemm_postcondition(selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } -#endif void syr2(backend_selector selector, uplo upper_lower, std::int64_t n, float alpha, cl::sycl::buffer &x, std::int64_t incx, cl::sycl::buffer &y, std::int64_t incy, 
cl::sycl::buffer &a, std::int64_t lda) { diff --git a/include/oneapi/mkl/blas/detail/cublas/onemkl_blas_cublas.hxx b/include/oneapi/mkl/blas/detail/cublas/onemkl_blas_cublas.hxx index 18b6cb9b8..f81325f2d 100644 --- a/include/oneapi/mkl/blas/detail/cublas/onemkl_blas_cublas.hxx +++ b/include/oneapi/mkl/blas/detail/cublas/onemkl_blas_cublas.hxx @@ -496,17 +496,15 @@ void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64 cl::sycl::buffer, 1> &a, std::int64_t lda, cl::sycl::buffer, 1> &b, std::int64_t ldb, std::complex beta, cl::sycl::buffer, 1> &c, std::int64_t ldc); -#ifdef ENABLE_HALF_ROUTINES void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, - std::int64_t n, std::int64_t k, half alpha, cl::sycl::buffer &a, - std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, half beta, - cl::sycl::buffer &c, std::int64_t ldc); + std::int64_t n, std::int64_t k, cl::sycl::half alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, cl::sycl::half beta, + cl::sycl::buffer &c, std::int64_t ldc); void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, - std::int64_t n, std::int64_t k, float alpha, cl::sycl::buffer &a, - std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, float beta, + std::int64_t n, std::int64_t k, float alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, float beta, cl::sycl::buffer &c, std::int64_t ldc); -#endif void hemm(cl::sycl::queue &queue, side left_right, uplo upper_lower, std::int64_t m, std::int64_t n, std::complex alpha, cl::sycl::buffer, 1> &a, std::int64_t lda, cl::sycl::buffer, 1> &b, std::int64_t ldb, std::complex beta, diff --git a/include/oneapi/mkl/blas/detail/mklcpu/blas_ct.hxx b/include/oneapi/mkl/blas/detail/mklcpu/blas_ct.hxx index 24c41a7c9..33217b8ff 100644 --- a/include/oneapi/mkl/blas/detail/mklcpu/blas_ct.hxx +++ b/include/oneapi/mkl/blas/detail/mklcpu/blas_ct.hxx @@ 
-743,11 +743,10 @@ void gemm(backend_selector selector, transpose transa, transpos gemm_postcondition(selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } -#ifdef ENABLE_HALF_ROUTINES void gemm(backend_selector selector, transpose transa, transpose transb, - std::int64_t m, std::int64_t n, std::int64_t k, half alpha, cl::sycl::buffer &a, - std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, half beta, - cl::sycl::buffer &c, std::int64_t ldc) { + std::int64_t m, std::int64_t n, std::int64_t k, cl::sycl::half alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, cl::sycl::half beta, + cl::sycl::buffer &c, std::int64_t ldc) { gemm_precondition(selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); oneapi::mkl::blas::mklcpu::MAJOR::gemm(selector.get_queue(), transa, transb, m, n, k, alpha, a, @@ -757,8 +756,8 @@ void gemm(backend_selector selector, transpose transa, transpos } void gemm(backend_selector selector, transpose transa, transpose transb, - std::int64_t m, std::int64_t n, std::int64_t k, float alpha, cl::sycl::buffer &a, - std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, float beta, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, float beta, cl::sycl::buffer &c, std::int64_t ldc) { gemm_precondition(selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); @@ -767,7 +766,6 @@ void gemm(backend_selector selector, transpose transa, transpos gemm_postcondition(selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } -#endif void syr2(backend_selector selector, uplo upper_lower, std::int64_t n, float alpha, cl::sycl::buffer &x, std::int64_t incx, cl::sycl::buffer &y, std::int64_t incy, cl::sycl::buffer &a, std::int64_t lda) { diff --git a/include/oneapi/mkl/blas/detail/mklgpu/blas_ct.hxx 
b/include/oneapi/mkl/blas/detail/mklgpu/blas_ct.hxx index 92db5a9b1..f6abd97aa 100644 --- a/include/oneapi/mkl/blas/detail/mklgpu/blas_ct.hxx +++ b/include/oneapi/mkl/blas/detail/mklgpu/blas_ct.hxx @@ -743,11 +743,10 @@ void gemm(backend_selector selector, transpose transa, transpos gemm_postcondition(selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } -#ifdef ENABLE_HALF_ROUTINES void gemm(backend_selector selector, transpose transa, transpose transb, - std::int64_t m, std::int64_t n, std::int64_t k, half alpha, cl::sycl::buffer &a, - std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, half beta, - cl::sycl::buffer &c, std::int64_t ldc) { + std::int64_t m, std::int64_t n, std::int64_t k, cl::sycl::half alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, cl::sycl::half beta, + cl::sycl::buffer &c, std::int64_t ldc) { gemm_precondition(selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); oneapi::mkl::blas::mklgpu::MAJOR::gemm(selector.get_queue(), transa, transb, m, n, k, alpha, a, @@ -757,8 +756,8 @@ void gemm(backend_selector selector, transpose transa, transpos } void gemm(backend_selector selector, transpose transa, transpose transb, - std::int64_t m, std::int64_t n, std::int64_t k, float alpha, cl::sycl::buffer &a, - std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, float beta, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, float beta, cl::sycl::buffer &c, std::int64_t ldc) { gemm_precondition(selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); @@ -767,7 +766,6 @@ void gemm(backend_selector selector, transpose transa, transpos gemm_postcondition(selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } -#endif void syr2(backend_selector selector, uplo upper_lower, std::int64_t n, float alpha, 
cl::sycl::buffer &x, std::int64_t incx, cl::sycl::buffer &y, std::int64_t incy, cl::sycl::buffer &a, std::int64_t lda) { diff --git a/include/oneapi/mkl/blas/detail/netlib/blas_ct.hxx b/include/oneapi/mkl/blas/detail/netlib/blas_ct.hxx index 5f03c7656..9d58cbfea 100644 --- a/include/oneapi/mkl/blas/detail/netlib/blas_ct.hxx +++ b/include/oneapi/mkl/blas/detail/netlib/blas_ct.hxx @@ -743,11 +743,10 @@ void gemm(backend_selector selector, transpose transa, transpos gemm_postcondition(selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } -#ifdef ENABLE_HALF_ROUTINES void gemm(backend_selector selector, transpose transa, transpose transb, - std::int64_t m, std::int64_t n, std::int64_t k, half alpha, cl::sycl::buffer &a, - std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, half beta, - cl::sycl::buffer &c, std::int64_t ldc) { + std::int64_t m, std::int64_t n, std::int64_t k, cl::sycl::half alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, cl::sycl::half beta, + cl::sycl::buffer &c, std::int64_t ldc) { gemm_precondition(selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); oneapi::mkl::blas::netlib::MAJOR::gemm(selector.get_queue(), transa, transb, m, n, k, alpha, a, @@ -757,8 +756,8 @@ void gemm(backend_selector selector, transpose transa, transpos } void gemm(backend_selector selector, transpose transa, transpose transb, - std::int64_t m, std::int64_t n, std::int64_t k, float alpha, cl::sycl::buffer &a, - std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, float beta, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, float beta, cl::sycl::buffer &c, std::int64_t ldc) { gemm_precondition(selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); @@ -767,7 +766,6 @@ void gemm(backend_selector selector, transpose transa, transpos 
gemm_postcondition(selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } -#endif void syr2(backend_selector selector, uplo upper_lower, std::int64_t n, float alpha, cl::sycl::buffer &x, std::int64_t incx, cl::sycl::buffer &y, std::int64_t incy, cl::sycl::buffer &a, std::int64_t lda) { diff --git a/include/oneapi/mkl/blas/detail/onemkl_blas_backends.hxx b/include/oneapi/mkl/blas/detail/onemkl_blas_backends.hxx index f996dd890..cd58aa9f9 100644 --- a/include/oneapi/mkl/blas/detail/onemkl_blas_backends.hxx +++ b/include/oneapi/mkl/blas/detail/onemkl_blas_backends.hxx @@ -46,19 +46,17 @@ ONEMKL_EXPORT void gemm(cl::sycl::queue &queue, oneapi::mkl::transpose transa, cl::sycl::buffer, 1> &b, std::int64_t ldb, std::complex beta, cl::sycl::buffer, 1> &c, std::int64_t ldc); -#ifdef ENABLE_HALF_ROUTINES ONEMKL_EXPORT void gemm(cl::sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, - std::int64_t k, half alpha, cl::sycl::buffer &a, std::int64_t lda, - cl::sycl::buffer &b, std::int64_t ldb, half beta, - cl::sycl::buffer &c, std::int64_t ldc); + std::int64_t k, cl::sycl::half alpha, cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &b, std::int64_t ldb, cl::sycl::half beta, + cl::sycl::buffer &c, std::int64_t ldc); ONEMKL_EXPORT void gemm(cl::sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, - std::int64_t k, float alpha, cl::sycl::buffer &a, std::int64_t lda, - cl::sycl::buffer &b, std::int64_t ldb, float beta, + std::int64_t k, float alpha, cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &b, std::int64_t ldb, float beta, cl::sycl::buffer &c, std::int64_t ldc); -#endif ONEMKL_EXPORT void symm(cl::sycl::queue &queue, oneapi::mkl::side left_right, oneapi::mkl::uplo upper_lower, std::int64_t m, std::int64_t n, float alpha, cl::sycl::buffer &a, std::int64_t lda, diff --git 
a/include/oneapi/mkl/blas/predicates.hxx b/include/oneapi/mkl/blas/predicates.hxx index a5a95fe31..88cbb63ab 100644 --- a/include/oneapi/mkl/blas/predicates.hxx +++ b/include/oneapi/mkl/blas/predicates.hxx @@ -1518,22 +1518,21 @@ inline void gemm_postcondition(cl::sycl::queue &queue, transpose transa, transpo #endif } -#ifdef ENABLE_HALF_ROUTINES inline void gemm_precondition(cl::sycl::queue &queue, transpose transa, transpose transb, - std::int64_t m, std::int64_t n, std::int64_t k, half alpha, - cl::sycl::buffer &a, std::int64_t lda, - cl::sycl::buffer &b, std::int64_t ldb, half beta, - cl::sycl::buffer &c, std::int64_t ldc) { + std::int64_t m, std::int64_t n, std::int64_t k, cl::sycl::half alpha, + cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &b, std::int64_t ldb, cl::sycl::half beta, + cl::sycl::buffer &c, std::int64_t ldc) { #ifndef ONEMKL_DISABLE_PREDICATES /* add prechecks to queue here for input args. */ #endif } inline void gemm_postcondition(cl::sycl::queue &queue, transpose transa, transpose transb, - std::int64_t m, std::int64_t n, std::int64_t k, half alpha, - cl::sycl::buffer &a, std::int64_t lda, - cl::sycl::buffer &b, std::int64_t ldb, half beta, - cl::sycl::buffer &c, std::int64_t ldc) { + std::int64_t m, std::int64_t n, std::int64_t k, cl::sycl::half alpha, + cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &b, std::int64_t ldb, cl::sycl::half beta, + cl::sycl::buffer &c, std::int64_t ldc) { #ifndef ONEMKL_DISABLE_PREDICATES /* add postchecks to queue here for input args. 
*/ #endif @@ -1541,8 +1540,8 @@ inline void gemm_postcondition(cl::sycl::queue &queue, transpose transa, transpo inline void gemm_precondition(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, - cl::sycl::buffer &a, std::int64_t lda, - cl::sycl::buffer &b, std::int64_t ldb, float beta, + cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &b, std::int64_t ldb, float beta, cl::sycl::buffer &c, std::int64_t ldc) { #ifndef ONEMKL_DISABLE_PREDICATES /* add prechecks to queue here for input args. */ @@ -1551,14 +1550,13 @@ inline void gemm_precondition(cl::sycl::queue &queue, transpose transa, transpos inline void gemm_postcondition(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, - cl::sycl::buffer &a, std::int64_t lda, - cl::sycl::buffer &b, std::int64_t ldb, float beta, + cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &b, std::int64_t ldb, float beta, cl::sycl::buffer &c, std::int64_t ldc) { #ifndef ONEMKL_DISABLE_PREDICATES /* add postchecks to queue here for input args. */ #endif } -#endif inline void syr2_precondition(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, float alpha, cl::sycl::buffer &x, std::int64_t incx, cl::sycl::buffer &y, std::int64_t incy, @@ -4750,11 +4748,10 @@ inline void gemm_postcondition(cl::sycl::queue &queue, transpose transa, transpo /* add postchecks to queue here for input args. 
*/ #endif } -#ifdef ENABLE_HALF_ROUTINES inline void gemm_precondition(cl::sycl::queue &queue, transpose transa, transpose transb, - std::int64_t m, std::int64_t n, std::int64_t k, half alpha, - const half *a, std::int64_t lda, const half *b, std::int64_t ldb, - half beta, half *c, std::int64_t ldc, + std::int64_t m, std::int64_t n, std::int64_t k, cl::sycl::half alpha, + const cl::sycl::half *a, std::int64_t lda, const cl::sycl::half *b, std::int64_t ldb, + cl::sycl::half beta, cl::sycl::half *c, std::int64_t ldc, const cl::sycl::vector_class &dependencies) { #ifndef ONEMKL_DISABLE_PREDICATES /* add prechecks to queue here for input args. */ @@ -4762,15 +4759,14 @@ inline void gemm_precondition(cl::sycl::queue &queue, transpose transa, transpos } inline void gemm_postcondition(cl::sycl::queue &queue, transpose transa, transpose transb, - std::int64_t m, std::int64_t n, std::int64_t k, half alpha, - const half *a, std::int64_t lda, const half *b, std::int64_t ldb, - half beta, half *c, std::int64_t ldc, + std::int64_t m, std::int64_t n, std::int64_t k, cl::sycl::half alpha, + const cl::sycl::half *a, std::int64_t lda, const cl::sycl::half *b, std::int64_t ldb, + cl::sycl::half beta, cl::sycl::half *c, std::int64_t ldc, const cl::sycl::vector_class &dependencies) { #ifndef ONEMKL_DISABLE_PREDICATES /* add postchecks to queue here for input args. */ #endif } -#endif inline void syr2_precondition(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, float alpha, const float *x, std::int64_t incx, const float *y, std::int64_t incy, float *a, std::int64_t lda, diff --git a/include/oneapi/mkl/rng/detail/curand/onemkl_rng_curand.hpp b/include/oneapi/mkl/rng/detail/curand/onemkl_rng_curand.hpp index e63c6ab56..e3d856958 100644 --- a/include/oneapi/mkl/rng/detail/curand/onemkl_rng_curand.hpp +++ b/include/oneapi/mkl/rng/detail/curand/onemkl_rng_curand.hpp @@ -50,7 +50,7 @@ * NOTICE. This Software was developed under funding from the U.S. 
Department * of Energy and the U.S. Government consequently retains certain rights. As * such, the U.S. Government has been granted for itself and others acting on - * its behalf a paid-up, nonexclusive, irrevocable, worldwide license in the + * its behalf a paid-up, nonexclusive, irrevocable, worldwide license in the * Software to reproduce, distribute copies to the public, prepare derivative * works, and perform publicly and display publicly, and to permit others to do * so. diff --git a/src/blas/backends/cublas/cublas_level3.cpp b/src/blas/backends/cublas/cublas_level3.cpp index c8a701f94..341f35dee 100644 --- a/src/blas/backends/cublas/cublas_level3.cpp +++ b/src/blas/backends/cublas/cublas_level3.cpp @@ -99,6 +99,7 @@ inline void gemm(Func func, DATATYPE_A DT_A, DATATYPE_B DT_B, DATATYPE_C DT_C, }); }); } +#ifdef ENABLE_HALF_ROUTINES #define GEMM_EX_LAUNCHER(TYPE_A, TYPE_B, TYPE_C, CUBLAS_ROUTINE, CUDADATATYPE_A, CUDADATATYPE_B, \ CUDADATATYPE_C) \ @@ -109,10 +110,19 @@ inline void gemm(Func func, DATATYPE_A DT_A, DATATYPE_B DT_B, DATATYPE_C DT_C, gemm(CUBLAS_ROUTINE, CUDADATATYPE_A, CUDADATATYPE_B, CUDADATATYPE_C, queue, transa, \ transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); \ } -#ifdef ENABLE_HALF_ROUTINES -GEMM_EX_LAUNCHER(half, half, float, cublasGemmEx, CUDA_R_16F, CUDA_R_16F, CUDA_R_32F) -GEMM_EX_LAUNCHER(half, half, half, cublasGemmEx, CUDA_R_16F, CUDA_R_16F, CUDA_R_16F) +#else +#define GEMM_EX_LAUNCHER(TYPE_A, TYPE_B, TYPE_C, CUBLAS_ROUTINE, CUDADATATYPE_A, CUDADATATYPE_B, \ + CUDADATATYPE_C) \ + void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, \ + int64_t k, TYPE_C alpha, cl::sycl::buffer &a, int64_t lda, \ + cl::sycl::buffer &b, int64_t ldb, TYPE_C beta, \ + cl::sycl::buffer &c, int64_t ldc) { \ + throw unimplemented("blas", "gemm", "half is disabled"); \ + } #endif +GEMM_EX_LAUNCHER(cl::sycl::half, cl::sycl::half, float, cublasGemmEx, CUDA_R_16F, CUDA_R_16F, CUDA_R_32F) 
+GEMM_EX_LAUNCHER(cl::sycl::half, cl::sycl::half, cl::sycl::half, cublasGemmEx, CUDA_R_16F, CUDA_R_16F, CUDA_R_16F) + #undef GEMM_EX_LAUNCHER template @@ -465,14 +475,12 @@ GEMM_LAUNCHER_USM(std::complex, cublasCgemm) GEMM_LAUNCHER_USM(std::complex, cublasZgemm) #undef GEMM_LAUNCHER_USM -#ifdef ENABLE_HALF_ROUTINES cl::sycl::event gemm(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, - std::int64_t n, std::int64_t k, half alpha, const half *a, std::int64_t lda, - const half *b, std::int64_t ldb, half beta, half *c, std::int64_t ldc, + std::int64_t n, std::int64_t k, cl::sycl::half alpha, const cl::sycl::half *a, std::int64_t lda, + const cl::sycl::half *b, std::int64_t ldb, cl::sycl::half beta, cl::sycl::half *c, std::int64_t ldc, const cl::sycl::vector_class &dependencies) { throw unimplemented("blas", "gemm", "for column_major layout"); } -#endif template inline cl::sycl::event symm(Func func, cl::sycl::queue &queue, side left_right, uplo upper_lower, int64_t m, int64_t n, T alpha, const T *a, int64_t lda, const T *b, @@ -860,10 +868,8 @@ inline void gemm(Func func, DATATYPE_A DT_A, DATATYPE_B DT_B, DATATYPE_C DT_C, gemm(CUBLAS_ROUTINE, CUDADATATYPE_A, CUDADATATYPE_B, CUDADATATYPE_C, queue, transa, \ transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); \ } -#ifdef ENABLE_HALF_ROUTINES -GEMM_EX_LAUNCHER(half, half, float, cublasGemmEx, CUDA_R_16F, CUDA_R_16F, CUDA_R_32F) -GEMM_EX_LAUNCHER(half, half, half, cublasGemmEx, CUDA_R_16F, CUDA_R_16F, CUDA_R_16F) -#endif +GEMM_EX_LAUNCHER(cl::sycl::half, cl::sycl::half, float, cublasGemmEx, CUDA_R_16F, CUDA_R_16F, CUDA_R_32F) +GEMM_EX_LAUNCHER(cl::sycl::half, cl::sycl::half, cl::sycl::half, cublasGemmEx, CUDA_R_16F, CUDA_R_16F, CUDA_R_16F) #undef GEMM_EX_LAUNCHER template @@ -1065,14 +1071,12 @@ GEMM_LAUNCHER_USM(std::complex, cublasCgemm) GEMM_LAUNCHER_USM(std::complex, cublasZgemm) #undef GEMM_LAUNCHER_USM -#ifdef ENABLE_HALF_ROUTINES cl::sycl::event gemm(cl::sycl::queue &queue, transpose 
transa, transpose transb, std::int64_t m, - std::int64_t n, std::int64_t k, half alpha, const half *a, std::int64_t lda, - const half *b, std::int64_t ldb, half beta, half *c, std::int64_t ldc, + std::int64_t n, std::int64_t k, cl::sycl::half alpha, const cl::sycl::half *a, std::int64_t lda, + const cl::sycl::half *b, std::int64_t ldb, cl::sycl::half beta, cl::sycl::half *c, std::int64_t ldc, const cl::sycl::vector_class &dependencies) { throw unimplemented("blas", "gemm", "for row_major layout"); } -#endif template inline cl::sycl::event symm(Func func, cl::sycl::queue &queue, side left_right, uplo upper_lower, int64_t m, int64_t n, T alpha, const T *a, int64_t lda, const T *b, diff --git a/src/blas/backends/cublas/mkl_blas_cublas_wrappers.cpp b/src/blas/backends/cublas/mkl_blas_cublas_wrappers.cpp index f8cc31ffa..23554ea2a 100644 --- a/src/blas/backends/cublas/mkl_blas_cublas_wrappers.cpp +++ b/src/blas/backends/cublas/mkl_blas_cublas_wrappers.cpp @@ -147,10 +147,8 @@ extern "C" blas_function_table_t mkl_blas_table = { oneapi::mkl::blas::cublas::column_major::gemm, oneapi::mkl::blas::cublas::column_major::gemm, oneapi::mkl::blas::cublas::column_major::gemm, -#ifdef ENABLE_HALF_ROUTINES oneapi::mkl::blas::cublas::column_major::gemm, oneapi::mkl::blas::cublas::column_major::gemm, -#endif oneapi::mkl::blas::cublas::column_major::hemm, oneapi::mkl::blas::cublas::column_major::hemm, oneapi::mkl::blas::cublas::column_major::herk, @@ -480,10 +478,8 @@ extern "C" blas_function_table_t mkl_blas_table = { oneapi::mkl::blas::cublas::row_major::gemm, oneapi::mkl::blas::cublas::row_major::gemm, oneapi::mkl::blas::cublas::row_major::gemm, -#ifdef ENABLE_HALF_ROUTINES oneapi::mkl::blas::cublas::row_major::gemm, oneapi::mkl::blas::cublas::row_major::gemm, -#endif oneapi::mkl::blas::cublas::row_major::hemm, oneapi::mkl::blas::cublas::row_major::hemm, oneapi::mkl::blas::cublas::row_major::herk, diff --git a/src/blas/backends/mklcpu/mklcpu_level3.cpp 
b/src/blas/backends/mklcpu/mklcpu_level3.cpp index 694a3eb60..128465e98 100644 --- a/src/blas/backends/mklcpu/mklcpu_level3.cpp +++ b/src/blas/backends/mklcpu/mklcpu_level3.cpp @@ -19,6 +19,7 @@ #include +#include "oneapi/mkl/exceptions.hpp" #include "mklcpu_common.hpp" #include "fp16.hpp" #include "oneapi/mkl/blas/detail/mklcpu/onemkl_blas_mklcpu.hpp" diff --git a/src/blas/backends/mklcpu/mklcpu_level3.cxx b/src/blas/backends/mklcpu/mklcpu_level3.cxx index 9e6dcbd7d..46a966cbc 100644 --- a/src/blas/backends/mklcpu/mklcpu_level3.cxx +++ b/src/blas/backends/mklcpu/mklcpu_level3.cxx @@ -96,11 +96,12 @@ void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, int64_t m, }); }); } -#ifdef ENABLE_HALF_ROUTINES + void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - int64_t k, half alpha, cl::sycl::buffer &a, int64_t lda, - cl::sycl::buffer &b, int64_t ldb, half beta, cl::sycl::buffer &c, + int64_t k, cl::sycl::half alpha, cl::sycl::buffer &a, int64_t lda, + cl::sycl::buffer &b, int64_t ldb, cl::sycl::half beta, cl::sycl::buffer &c, int64_t ldc) { +#ifdef ENABLE_HALF_ROUTINES auto a_fp16 = a.reinterpret(a.get_range()); auto b_fp16 = b.reinterpret(b.get_range()); auto c_fp16 = c.reinterpret(c.get_range()); @@ -131,7 +132,7 @@ void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, int64_t m, copy_mat(accessor_c, MKLMAJOR, transpose::N, m, n, ldc, 0.0f, f32_c); ::cblas_sgemm(CBLASMAJOR, transa_, transb_, m, n, k, f32_alpha, f32_a, lda, f32_b, ldb, f32_beta, f32_c, ldc); - // copy C back to half + // copy C back to cl::sycl::half fp16 co = 0.0f; copy_mat(f32_c, MKLMAJOR, m, n, ldc, offset::F, &co, accessor_c); ::free(f32_a); @@ -139,11 +140,15 @@ void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, int64_t m, ::free(f32_c); }); }); +#else + throw oneapi::mkl::unimplemented("blas", "gemm", "half is disabled"); +#endif } void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, 
int64_t m, int64_t n, - int64_t k, float alpha, cl::sycl::buffer &a, int64_t lda, - cl::sycl::buffer &b, int64_t ldb, float beta, cl::sycl::buffer &c, + int64_t k, float alpha, cl::sycl::buffer &a, int64_t lda, + cl::sycl::buffer &b, int64_t ldb, float beta, cl::sycl::buffer &c, int64_t ldc) { +#ifdef ENABLE_HALF_ROUTINES auto a_fp16 = a.reinterpret(a.get_range()); auto b_fp16 = b.reinterpret(b.get_range()); queue.submit([&](cl::sycl::handler &cgh) { @@ -171,8 +176,10 @@ void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, int64_t m, ::free(f32_b); }); }); -} +#else + throw oneapi::mkl::unimplemented("blas", "cl::sycl::half", "when using hipSYCL"); #endif +} void hemm(cl::sycl::queue &queue, side left_right, uplo upper_lower, int64_t m, int64_t n, std::complex alpha, cl::sycl::buffer, 1> &a, int64_t lda, cl::sycl::buffer, 1> &b, int64_t ldb, std::complex beta, diff --git a/src/blas/backends/mklcpu/mklcpu_wrappers.cpp b/src/blas/backends/mklcpu/mklcpu_wrappers.cpp index d65c8bac2..6bb07f11f 100644 --- a/src/blas/backends/mklcpu/mklcpu_wrappers.cpp +++ b/src/blas/backends/mklcpu/mklcpu_wrappers.cpp @@ -148,10 +148,8 @@ extern "C" ONEMKL_EXPORT blas_function_table_t mkl_blas_table = { oneapi::mkl::blas::mklcpu::column_major::gemm, oneapi::mkl::blas::mklcpu::column_major::gemm, oneapi::mkl::blas::mklcpu::column_major::gemm, - #ifdef ENABLE_HALF_ROUTINES oneapi::mkl::blas::mklcpu::column_major::gemm, oneapi::mkl::blas::mklcpu::column_major::gemm, - #endif oneapi::mkl::blas::mklcpu::column_major::hemm, oneapi::mkl::blas::mklcpu::column_major::hemm, oneapi::mkl::blas::mklcpu::column_major::herk, @@ -481,10 +479,8 @@ extern "C" ONEMKL_EXPORT blas_function_table_t mkl_blas_table = { oneapi::mkl::blas::mklcpu::row_major::gemm, oneapi::mkl::blas::mklcpu::row_major::gemm, oneapi::mkl::blas::mklcpu::row_major::gemm, - #ifdef ENABLE_HALF_ROUTINES oneapi::mkl::blas::mklcpu::row_major::gemm, oneapi::mkl::blas::mklcpu::row_major::gemm, - #endif 
oneapi::mkl::blas::mklcpu::row_major::hemm, oneapi::mkl::blas::mklcpu::row_major::hemm, oneapi::mkl::blas::mklcpu::row_major::herk, diff --git a/src/blas/backends/mklgpu/mklgpu_common.hpp b/src/blas/backends/mklgpu/mklgpu_common.hpp index 79b4e80f7..f23e432a1 100644 --- a/src/blas/backends/mklgpu/mklgpu_common.hpp +++ b/src/blas/backends/mklgpu/mklgpu_common.hpp @@ -779,17 +779,16 @@ void cgemmt(cl::sycl::queue &queue, MKL_LAYOUT layout, MKL_UPLO upper_lower, MKL cl::sycl::buffer, 1> &a, int64_t lda, cl::sycl::buffer, 1> &b, int64_t ldb, std::complex beta, cl::sycl::buffer, 1> &c, int64_t ldc); -#ifdef ENABLE_HALF_ROUTINES void hgemm(cl::sycl::queue &queue, MKL_LAYOUT layout, MKL_TRANSPOSE transa, MKL_TRANSPOSE transb, - int64_t m, int64_t n, int64_t k, half alpha, cl::sycl::buffer &a, int64_t lda, - cl::sycl::buffer &b, int64_t ldb, half beta, cl::sycl::buffer &c, + int64_t m, int64_t n, int64_t k, cl::sycl::half alpha, cl::sycl::buffer &a, int64_t lda, + cl::sycl::buffer &b, int64_t ldb, cl::sycl::half beta, cl::sycl::buffer &c, int64_t ldc); void gemm_f16f16f32(cl::sycl::queue &queue, MKL_LAYOUT layout, MKL_TRANSPOSE transa, MKL_TRANSPOSE transb, int64_t m, int64_t n, int64_t k, float alpha, - cl::sycl::buffer &a, int64_t lda, cl::sycl::buffer &b, + cl::sycl::buffer &a, int64_t lda, cl::sycl::buffer &b, int64_t ldb, float beta, cl::sycl::buffer &c, int64_t ldc); -#endif + cl::sycl::event gemm_s8u8s32_sycl(cl::sycl::queue *queue, MKL_LAYOUT layout, MKL_TRANSPOSE transa, MKL_TRANSPOSE transb, CBLAS_OFFSET offsetc, int64_t m, int64_t n, int64_t k, float alpha, cl::sycl::buffer *a, diff --git a/src/blas/backends/mklgpu/mklgpu_level3.cxx b/src/blas/backends/mklgpu/mklgpu_level3.cxx index 20314374d..6d7a04b90 100644 --- a/src/blas/backends/mklgpu/mklgpu_level3.cxx +++ b/src/blas/backends/mklgpu/mklgpu_level3.cxx @@ -56,25 +56,23 @@ void gemm(cl::sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::tr ::mkl::cblas_convert(transb), m, n, k, alpha, a, lda, b, 
ldb, beta, c, ldc); } -#ifdef ENABLE_HALF_ROUTINES void gemm(cl::sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, - std::int64_t m, std::int64_t n, std::int64_t k, half alpha, cl::sycl::buffer &a, - std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, half beta, - cl::sycl::buffer &c, std::int64_t ldc) { + std::int64_t m, std::int64_t n, std::int64_t k, cl::sycl::half alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, cl::sycl::half beta, + cl::sycl::buffer &c, std::int64_t ldc) { ::oneapi::mkl::gpu::hgemm(queue, MAJOR, ::mkl::cblas_convert(transa), ::mkl::cblas_convert(transb), m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void gemm(cl::sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, - std::int64_t m, std::int64_t n, std::int64_t k, float alpha, cl::sycl::buffer &a, - std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, float beta, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, float beta, cl::sycl::buffer &c, std::int64_t ldc) { ::oneapi::mkl::gpu::gemm_f16f16f32(queue, MAJOR, ::mkl::cblas_convert(transa), ::mkl::cblas_convert(transb), m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } -#endif void symm(cl::sycl::queue &queue, oneapi::mkl::side left_right, oneapi::mkl::uplo upper_lower, std::int64_t m, std::int64_t n, float alpha, cl::sycl::buffer &a, std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, float beta, diff --git a/src/blas/backends/mklgpu/mklgpu_wrappers.cpp b/src/blas/backends/mklgpu/mklgpu_wrappers.cpp index 068efa886..b32c32ab9 100644 --- a/src/blas/backends/mklgpu/mklgpu_wrappers.cpp +++ b/src/blas/backends/mklgpu/mklgpu_wrappers.cpp @@ -148,10 +148,8 @@ extern "C" ONEMKL_EXPORT blas_function_table_t mkl_blas_table = { oneapi::mkl::blas::mklgpu::column_major::gemm, oneapi::mkl::blas::mklgpu::column_major::gemm, 
oneapi::mkl::blas::mklgpu::column_major::gemm, -#ifdef ENABLE_HALF_ROUTINES oneapi::mkl::blas::mklgpu::column_major::gemm, oneapi::mkl::blas::mklgpu::column_major::gemm, -#endif oneapi::mkl::blas::mklgpu::column_major::hemm, oneapi::mkl::blas::mklgpu::column_major::hemm, oneapi::mkl::blas::mklgpu::column_major::herk, @@ -481,10 +479,8 @@ extern "C" ONEMKL_EXPORT blas_function_table_t mkl_blas_table = { oneapi::mkl::blas::mklgpu::row_major::gemm, oneapi::mkl::blas::mklgpu::row_major::gemm, oneapi::mkl::blas::mklgpu::row_major::gemm, -#ifdef ENABLE_HALF_ROUTINES oneapi::mkl::blas::mklgpu::row_major::gemm, oneapi::mkl::blas::mklgpu::row_major::gemm, -#endif oneapi::mkl::blas::mklgpu::row_major::hemm, oneapi::mkl::blas::mklgpu::row_major::hemm, oneapi::mkl::blas::mklgpu::row_major::herk, diff --git a/src/blas/backends/netlib/netlib_level3.cxx b/src/blas/backends/netlib/netlib_level3.cxx index 54ce343da..693905fc3 100644 --- a/src/blas/backends/netlib/netlib_level3.cxx +++ b/src/blas/backends/netlib/netlib_level3.cxx @@ -92,9 +92,9 @@ void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, int64_t m, } void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, - std::int64_t n, std::int64_t k, half alpha, cl::sycl::buffer &a, - std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, half beta, - cl::sycl::buffer &c, std::int64_t ldc) { + std::int64_t n, std::int64_t k, cl::sycl::half alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, cl::sycl::half beta, + cl::sycl::buffer &c, std::int64_t ldc) { #ifdef COLUMN_MAJOR throw unimplemented("blas", "gemm", "for column_major layout"); #endif @@ -104,8 +104,8 @@ void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64 } void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - int64_t k, float alpha, cl::sycl::buffer &a, int64_t lda, - cl::sycl::buffer &b, int64_t ldb, float beta, 
cl::sycl::buffer &c, + int64_t k, float alpha, cl::sycl::buffer &a, int64_t lda, + cl::sycl::buffer &b, int64_t ldb, float beta, cl::sycl::buffer &c, int64_t ldc) { #ifdef COLUMN_MAJOR throw unimplemented("blas", "gemm", "for column_major layout"); diff --git a/src/blas/blas_loader.cpp b/src/blas/blas_loader.cpp index 316767d6d..b864bafe9 100644 --- a/src/blas/blas_loader.cpp +++ b/src/blas/blas_loader.cpp @@ -877,23 +877,21 @@ void gemm(oneapi::mkl::device libkey, cl::sycl::queue &queue, transpose transa, function_tables[libkey].column_major_zgemm_sycl(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } -#ifdef ENABLE_HALF_ROUTINES void gemm(oneapi::mkl::device libkey, cl::sycl::queue &queue, transpose transa, transpose transb, - std::int64_t m, std::int64_t n, std::int64_t k, half alpha, cl::sycl::buffer &a, - std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, half beta, - cl::sycl::buffer &c, std::int64_t ldc) { + std::int64_t m, std::int64_t n, std::int64_t k, cl::sycl::half alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, cl::sycl::half beta, + cl::sycl::buffer &c, std::int64_t ldc) { function_tables[libkey].column_major_hgemm_sycl(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void gemm(oneapi::mkl::device libkey, cl::sycl::queue &queue, transpose transa, transpose transb, - std::int64_t m, std::int64_t n, std::int64_t k, float alpha, cl::sycl::buffer &a, - std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, float beta, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, float beta, cl::sycl::buffer &c, std::int64_t ldc) { function_tables[libkey].column_major_gemm_f16f16f32_sycl(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } -#endif void hemm(oneapi::mkl::device libkey, cl::sycl::queue &queue, side left_right, uplo upper_lower, std::int64_t m, 
std::int64_t n, std::complex alpha, cl::sycl::buffer, 1> &a, std::int64_t lda, @@ -3495,23 +3493,21 @@ void gemm(oneapi::mkl::device libkey, cl::sycl::queue &queue, transpose transa, function_tables[libkey].row_major_zgemm_sycl(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } -#ifdef ENABLE_HALF_ROUTINES void gemm(oneapi::mkl::device libkey, cl::sycl::queue &queue, transpose transa, transpose transb, - std::int64_t m, std::int64_t n, std::int64_t k, half alpha, cl::sycl::buffer &a, - std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, half beta, - cl::sycl::buffer &c, std::int64_t ldc) { + std::int64_t m, std::int64_t n, std::int64_t k, cl::sycl::half alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, cl::sycl::half beta, + cl::sycl::buffer &c, std::int64_t ldc) { function_tables[libkey].row_major_hgemm_sycl(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void gemm(oneapi::mkl::device libkey, cl::sycl::queue &queue, transpose transa, transpose transb, - std::int64_t m, std::int64_t n, std::int64_t k, float alpha, cl::sycl::buffer &a, - std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, float beta, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, float beta, cl::sycl::buffer &c, std::int64_t ldc) { function_tables[libkey].row_major_gemm_f16f16f32_sycl(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } -#endif void hemm(oneapi::mkl::device libkey, cl::sycl::queue &queue, side left_right, uplo upper_lower, std::int64_t m, std::int64_t n, std::complex alpha, cl::sycl::buffer, 1> &a, std::int64_t lda, diff --git a/src/blas/function_table.hpp b/src/blas/function_table.hpp index afb70a938..6f05e57ab 100644 --- a/src/blas/function_table.hpp +++ b/src/blas/function_table.hpp @@ -568,22 +568,20 @@ typedef struct { cl::sycl::buffer, 1> &b, std::int64_t ldb, std::complex 
beta, cl::sycl::buffer, 1> &c, std::int64_t ldc); -#ifdef ENABLE_HALF_ROUTINES void (*column_major_hgemm_sycl)(cl::sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, - std::int64_t k, half alpha, cl::sycl::buffer &a, - std::int64_t lda, cl::sycl::buffer &b, - std::int64_t ldb, half beta, cl::sycl::buffer &c, + std::int64_t k, cl::sycl::half alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, + std::int64_t ldb, cl::sycl::half beta, cl::sycl::buffer &c, std::int64_t ldc); void (*column_major_gemm_f16f16f32_sycl)(cl::sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, - cl::sycl::buffer &a, std::int64_t lda, - cl::sycl::buffer &b, std::int64_t ldb, + cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &b, std::int64_t ldb, float beta, cl::sycl::buffer &c, std::int64_t ldc); -#endif void (*column_major_chemm_sycl)(cl::sycl::queue &queue, oneapi::mkl::side left_right, oneapi::mkl::uplo upper_lower, std::int64_t m, std::int64_t n, std::complex alpha, @@ -2090,20 +2088,18 @@ typedef struct { cl::sycl::buffer, 1> &b, std::int64_t ldb, std::complex beta, cl::sycl::buffer, 1> &c, std::int64_t ldc); -#ifdef ENABLE_HALF_ROUTINES void (*row_major_hgemm_sycl)(cl::sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, - std::int64_t k, half alpha, cl::sycl::buffer &a, - std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, - half beta, cl::sycl::buffer &c, std::int64_t ldc); + std::int64_t k, cl::sycl::half alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, + cl::sycl::half beta, cl::sycl::buffer &c, std::int64_t ldc); void (*row_major_gemm_f16f16f32_sycl)(cl::sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, - 
cl::sycl::buffer &a, std::int64_t lda, - cl::sycl::buffer &b, std::int64_t ldb, + cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &b, std::int64_t ldb, float beta, cl::sycl::buffer &c, std::int64_t ldc); -#endif void (*row_major_chemm_sycl)(cl::sycl::queue &queue, oneapi::mkl::side left_right, oneapi::mkl::uplo upper_lower, std::int64_t m, std::int64_t n, std::complex alpha,