1616#include < sycl/sycl.hpp>
1717#include < sycl/half_type.hpp>
1818#include < syclcompat/math.hpp>
19- #include < oneapi/mkl .hpp>
19+ #include < oneapi/math .hpp>
2020#include < map>
2121
2222#include " ggml.h"
@@ -83,13 +83,36 @@ inline std::string get_device_backend_and_type(const sycl::device &device) {
8383}
8484
8585template <typename Ts> struct matrix_info_t {
86- oneapi::mkl ::transpose transpose_info[2 ];
86+ oneapi::math ::transpose transpose_info[2 ];
8787 Ts value_info[2 ];
8888 std::int64_t size_info[3 ];
8989 std::int64_t ld_info[3 ];
9090 std::int64_t groupsize_info;
9191};
9292
93+ inline auto get_onemath_backend (sycl::queue& queue)
94+ #ifdef GGML_SYCL_GENERIC
95+ -> sycl::queue&
96+ #endif
97+ {
98+ // If the backend is known at compile-time, use oneMath backend_selector to use
99+ // compile-time dispatching and avoid the need to dlopen libraries. Otherwise
100+ // fallback to runtime dispatching.
101+ #if defined(GGML_SYCL_INTEL_CPU)
102+ return oneapi::math::backend_selector<oneapi::math::backend::mklcpu>{ queue };
103+ #elif defined(GGML_SYCL_INTEL_GPU)
104+ return oneapi::math::backend_selector<oneapi::math::backend::mklgpu>{ queue };
105+ #elif defined(GGML_SYCL_NVIDIA)
106+ return oneapi::math::backend_selector<oneapi::math::backend::cublas>{ queue };
107+ #elif defined(GGML_SYCL_AMD)
108+ return oneapi::math::backend_selector<oneapi::math::backend::rocblas>{ queue };
109+ #elif defined(GGML_SYCL_GENERIC)
110+ return queue;
111+ #else
112+ static_assert (false , " Unsupported backend" );
113+ #endif
114+ }
115+
93116namespace dpct
94117{
95118 typedef sycl::queue *queue_ptr;
@@ -1687,8 +1710,8 @@ namespace dpct
16871710 namespace detail
16881711 {
16891712 template <class Ta , class Tb , class Tc , class Ts >
1690- inline void gemm_impl (sycl::queue &q, oneapi::mkl ::transpose a_trans,
1691- oneapi::mkl ::transpose b_trans, int m, int n, int k,
1713+ inline void gemm_impl (sycl::queue &q, oneapi::math ::transpose a_trans,
1714+ oneapi::math ::transpose b_trans, int m, int n, int k,
16921715 const void *alpha, const void *a, int lda, const void *b,
16931716 int ldb, const void *beta, void *c, int ldc)
16941717 {
@@ -1697,14 +1720,8 @@ namespace dpct
16971720 auto data_a = get_memory<const Ta>(a);
16981721 auto data_b = get_memory<const Tb>(b);
16991722 auto data_c = get_memory<Tc>(c);
1700- #ifdef GGML_SYCL_NVIDIA
1701- oneapi::mkl::blas::column_major::gemm (oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q },
1702- a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb,
1703- beta_value, data_c, ldc);
1704- #else
1705- oneapi::mkl::blas::column_major::gemm (q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb,
1723+ oneapi::math::blas::column_major::gemm (get_onemath_backend (q), a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb,
17061724 beta_value, data_c, ldc);
1707- #endif
17081725 }
17091726
17101727 template <typename VecT, class BinaryOperation , class = void >
@@ -1735,7 +1752,7 @@ namespace dpct
17351752 };
17361753
17371754 template <class Ta , class Tb , class Tc , class Ts >
1738- inline void gemm_batch_impl (sycl::queue & q, oneapi::mkl ::transpose a_trans, oneapi::mkl ::transpose b_trans,
1755+ inline void gemm_batch_impl (sycl::queue & q, oneapi::math ::transpose a_trans, oneapi::math ::transpose b_trans,
17391756 int m, int n, int k, const void * alpha, const void ** a, int lda, const void ** b,
17401757 int ldb, const void * beta, void ** c, int ldc, int batch_size,
17411758 matrix_info_t <float > * matrix_info) {
@@ -1754,28 +1771,18 @@ namespace dpct
17541771 matrix_info->ld_info [2 ] = ldc;
17551772 matrix_info->groupsize_info = batch_size;
17561773
1757- #ifdef GGML_SYCL_NVIDIA
1758- sycl::event e = oneapi::mkl::blas::column_major::gemm_batch (
1759- oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, matrix_info->transpose_info ,
1760- matrix_info->transpose_info + 1 , matrix_info->size_info , matrix_info->size_info + 1 ,
1761- matrix_info->size_info + 2 , reinterpret_cast <Ts *>(matrix_info->value_info ),
1762- reinterpret_cast <const Ta **>(a), matrix_info->ld_info , reinterpret_cast <const Tb **>(b),
1763- matrix_info->ld_info + 1 , reinterpret_cast <Ts *>(matrix_info->value_info + 1 ),
1764- reinterpret_cast <Tc **>(c), matrix_info->ld_info + 2 , 1 , &(matrix_info->groupsize_info ));
1765- #else
1766- sycl::event e = oneapi::mkl::blas::column_major::gemm_batch (
1767- q, matrix_info->transpose_info , matrix_info->transpose_info + 1 , matrix_info->size_info ,
1774+ sycl::event e = oneapi::math::blas::column_major::gemm_batch (
1775+ get_onemath_backend (q), matrix_info->transpose_info , matrix_info->transpose_info + 1 , matrix_info->size_info ,
17681776 matrix_info->size_info + 1 , matrix_info->size_info + 2 , reinterpret_cast <Ts *>(matrix_info->value_info ),
17691777 reinterpret_cast <const Ta **>(a), matrix_info->ld_info , reinterpret_cast <const Tb **>(b),
17701778 matrix_info->ld_info + 1 , reinterpret_cast <Ts *>(matrix_info->value_info + 1 ),
17711779 reinterpret_cast <Tc **>(c), matrix_info->ld_info + 2 , 1 , &(matrix_info->groupsize_info ));
1772- #endif
17731780 }
17741781
17751782 template <class Ta , class Tb , class Tc , class Ts >
17761783 inline void
1777- gemm_batch_impl (sycl::queue &q, oneapi::mkl ::transpose a_trans,
1778- oneapi::mkl ::transpose b_trans, int m, int n,
1784+ gemm_batch_impl (sycl::queue &q, oneapi::math ::transpose a_trans,
1785+ oneapi::math ::transpose b_trans, int m, int n,
17791786 int k, const void *alpha, const void *a, int lda,
17801787 long long int stride_a, const void *b, int ldb,
17811788 long long int stride_b, const void *beta, void *c,
@@ -1786,16 +1793,9 @@ namespace dpct
17861793 auto data_a = get_memory<const Ta>(a);
17871794 auto data_b = get_memory<const Tb>(b);
17881795 auto data_c = get_memory<Tc>(c);
1789- #ifdef GGML_SYCL_NVIDIA
1790- oneapi::mkl::blas::column_major::gemm_batch (
1791- oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, a_trans, b_trans, m, n, k,
1792- alpha_value, data_a, lda, stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc, stride_c,
1793- batch_size);
1794- #else
1795- oneapi::mkl::blas::column_major::gemm_batch (q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
1796+ oneapi::math::blas::column_major::gemm_batch (get_onemath_backend (q), a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
17961797 stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc,
17971798 stride_c, batch_size);
1798- #endif
17991799 }
18001800
18011801 } // namespace detail
@@ -2259,8 +2259,8 @@ namespace dpct
22592259 sycl::range<3 >(x, y, 1 ), direction);
22602260 }
22612261
2262- inline void gemm (sycl::queue &q, oneapi::mkl ::transpose a_trans,
2263- oneapi::mkl ::transpose b_trans, int m, int n, int k,
2262+ inline void gemm (sycl::queue &q, oneapi::math ::transpose a_trans,
2263+ oneapi::math ::transpose b_trans, int m, int n, int k,
22642264 const void *alpha, const void *a, library_data_t a_type,
22652265 int lda, const void *b, library_data_t b_type, int ldb,
22662266 const void *beta, void *c, library_data_t c_type, int ldc,
@@ -2329,7 +2329,7 @@ namespace dpct
23292329 library_data_t ::real_bfloat16, library_data_t ::real_bfloat16,
23302330 library_data_t ::real_float, library_data_t ::real_float):
23312331 {
2332- detail::gemm_impl<oneapi::mkl ::bfloat16, oneapi::mkl ::bfloat16, float ,
2332+ detail::gemm_impl<oneapi::math ::bfloat16, oneapi::math ::bfloat16, float ,
23332333 float >(q, a_trans, b_trans, m, n, k, alpha, a, lda, b,
23342334 ldb, beta, c, ldc);
23352335 break ;
@@ -2369,8 +2369,8 @@ namespace dpct
23692369 library_data_t ::real_bfloat16, library_data_t ::real_bfloat16,
23702370 library_data_t ::real_bfloat16, library_data_t ::real_float):
23712371 {
2372- detail::gemm_impl<oneapi::mkl ::bfloat16, oneapi::mkl ::bfloat16,
2373- oneapi::mkl ::bfloat16, float >(
2372+ detail::gemm_impl<oneapi::math ::bfloat16, oneapi::math ::bfloat16,
2373+ oneapi::math ::bfloat16, float >(
23742374 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
23752375 break ;
23762376 }
@@ -2412,7 +2412,7 @@ namespace dpct
24122412 // / \param [in] ldc Leading dimension of C.
24132413 // / \param [in] batch_size Specifies the number of matrix multiply operations to perform.
24142414 // / \param [in] scaling_type Data type of the scaling factors.
2415- inline void gemm_batch (sycl::queue & q, oneapi::mkl ::transpose a_trans, oneapi::mkl ::transpose b_trans, int m,
2415+ inline void gemm_batch (sycl::queue & q, oneapi::math ::transpose a_trans, oneapi::math ::transpose b_trans, int m,
24162416 int n, int k, const void * alpha, const void * a[], library_data_t a_type, int lda,
24172417 const void * b[], library_data_t b_type, int ldb, const void * beta, void * c[],
24182418 library_data_t c_type, int ldc, int batch_size, library_data_t scaling_type,
@@ -2450,15 +2450,15 @@ namespace dpct
24502450 library_data_t ::real_bfloat16, library_data_t ::real_bfloat16,
24512451 library_data_t ::real_bfloat16, library_data_t ::real_float):
24522452 {
2453- detail::gemm_batch_impl<oneapi::mkl ::bfloat16, oneapi::mkl ::bfloat16, oneapi::mkl ::bfloat16, float >(
2453+ detail::gemm_batch_impl<oneapi::math ::bfloat16, oneapi::math ::bfloat16, oneapi::math ::bfloat16, float >(
24542454 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
24552455 break ;
24562456 }
24572457 case detail::get_type_combination_id (
24582458 library_data_t ::real_bfloat16, library_data_t ::real_bfloat16,
24592459 library_data_t ::real_float, library_data_t ::real_float):
24602460 {
2461- detail::gemm_batch_impl<oneapi::mkl ::bfloat16, oneapi::mkl ::bfloat16, float , float >(
2461+ detail::gemm_batch_impl<oneapi::math ::bfloat16, oneapi::math ::bfloat16, float , float >(
24622462 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
24632463 break ;
24642464 }
@@ -2534,8 +2534,8 @@ namespace dpct
25342534 // / \param [in] stride_c Stride between the different C matrices.
25352535 // / \param [in] batch_size Specifies the number of matrix multiply operations to perform.
25362536 // / \param [in] scaling_type Data type of the scaling factors.
2537- inline void gemm_batch (sycl::queue &q, oneapi::mkl ::transpose a_trans,
2538- oneapi::mkl ::transpose b_trans, int m, int n, int k,
2537+ inline void gemm_batch (sycl::queue &q, oneapi::math ::transpose a_trans,
2538+ oneapi::math ::transpose b_trans, int m, int n, int k,
25392539 const void *alpha, const void *a, library_data_t a_type,
25402540 int lda, long long int stride_a, const void *b,
25412541 library_data_t b_type, int ldb, long long int stride_b,
@@ -2611,8 +2611,8 @@ namespace dpct
26112611 library_data_t ::real_bfloat16, library_data_t ::real_bfloat16,
26122612 library_data_t ::real_bfloat16, library_data_t ::real_float):
26132613 {
2614- detail::gemm_batch_impl<oneapi::mkl ::bfloat16, oneapi::mkl ::bfloat16,
2615- oneapi::mkl ::bfloat16, float >(
2614+ detail::gemm_batch_impl<oneapi::math ::bfloat16, oneapi::math ::bfloat16,
2615+ oneapi::math ::bfloat16, float >(
26162616 q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b,
26172617 beta, c, ldc, stride_c, batch_size);
26182618 break ;
@@ -2621,7 +2621,7 @@ namespace dpct
26212621 library_data_t ::real_bfloat16, library_data_t ::real_bfloat16,
26222622 library_data_t ::real_float, library_data_t ::real_float):
26232623 {
2624- detail::gemm_batch_impl<oneapi::mkl ::bfloat16, oneapi::mkl ::bfloat16, float ,
2624+ detail::gemm_batch_impl<oneapi::math ::bfloat16, oneapi::math ::bfloat16, float ,
26252625 float >(q, a_trans, b_trans, m, n, k, alpha, a, lda,
26262626 stride_a, b, ldb, stride_b, beta, c, ldc,
26272627 stride_c, batch_size);
0 commit comments