Skip to content

Commit 3577bc0

Browse files
committed
Rename occurrences to mkl
1 parent dccc9f8 commit 3577bc0

File tree

4 files changed

+59
-72
lines changed

4 files changed

+59
-72
lines changed

ggml/src/ggml-sycl/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ if (NOT oneMath_FOUND)
100100
add_library(ONEMATH::${target} ALIAS ${target})
101101
endif()
102102
endfunction()
103+
onemath_alias(onemath)
103104
onemath_alias(onemath_blas_mklcpu)
104105
onemath_alias(onemath_blas_mklgpu)
105106
onemath_alias(onemath_blas_cublas)

ggml/src/ggml-sycl/dpct/helper.hpp

Lines changed: 47 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
#include <sycl/sycl.hpp>
1717
#include <sycl/half_type.hpp>
1818
#include <syclcompat/math.hpp>
19-
#include <oneapi/mkl.hpp>
19+
#include <oneapi/math.hpp>
2020
#include <map>
2121

2222
#include "ggml.h"
@@ -83,13 +83,36 @@ inline std::string get_device_backend_and_type(const sycl::device &device) {
8383
}
8484

8585
template <typename Ts> struct matrix_info_t {
86-
oneapi::mkl::transpose transpose_info[2];
86+
oneapi::math::transpose transpose_info[2];
8787
Ts value_info[2];
8888
std::int64_t size_info[3];
8989
std::int64_t ld_info[3];
9090
std::int64_t groupsize_info;
9191
};
9292

93+
inline auto get_onemath_backend(sycl::queue& queue)
94+
#ifdef GGML_SYCL_GENERIC
95+
-> sycl::queue&
96+
#endif
97+
{
98+
// If the backend is known at compile-time, use oneMath backend_selector to use
99+
// compile-time dispatching and avoid the need to dlopen libraries. Otherwise
100+
// fall back to runtime dispatching.
101+
#if defined(GGML_SYCL_INTEL_CPU)
102+
return oneapi::math::backend_selector<oneapi::math::backend::mklcpu>{ queue };
103+
#elif defined(GGML_SYCL_INTEL_GPU)
104+
return oneapi::math::backend_selector<oneapi::math::backend::mklgpu>{ queue };
105+
#elif defined(GGML_SYCL_NVIDIA)
106+
return oneapi::math::backend_selector<oneapi::math::backend::cublas>{ queue };
107+
#elif defined(GGML_SYCL_AMD)
108+
return oneapi::math::backend_selector<oneapi::math::backend::rocblas>{ queue };
109+
#elif defined(GGML_SYCL_GENERIC)
110+
return queue;
111+
#else
112+
static_assert(false, "Unsupported backend");
113+
#endif
114+
}
115+
93116
namespace dpct
94117
{
95118
typedef sycl::queue *queue_ptr;
@@ -1687,8 +1710,8 @@ namespace dpct
16871710
namespace detail
16881711
{
16891712
template <class Ta, class Tb, class Tc, class Ts>
1690-
inline void gemm_impl(sycl::queue &q, oneapi::mkl::transpose a_trans,
1691-
oneapi::mkl::transpose b_trans, int m, int n, int k,
1713+
inline void gemm_impl(sycl::queue &q, oneapi::math::transpose a_trans,
1714+
oneapi::math::transpose b_trans, int m, int n, int k,
16921715
const void *alpha, const void *a, int lda, const void *b,
16931716
int ldb, const void *beta, void *c, int ldc)
16941717
{
@@ -1697,14 +1720,8 @@ namespace dpct
16971720
auto data_a = get_memory<const Ta>(a);
16981721
auto data_b = get_memory<const Tb>(b);
16991722
auto data_c = get_memory<Tc>(c);
1700-
#ifdef GGML_SYCL_NVIDIA
1701-
oneapi::mkl::blas::column_major::gemm(oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q },
1702-
a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb,
1703-
beta_value, data_c, ldc);
1704-
#else
1705-
oneapi::mkl::blas::column_major::gemm(q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb,
1723+
oneapi::math::blas::column_major::gemm(get_onemath_backend(q), a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb,
17061724
beta_value, data_c, ldc);
1707-
#endif
17081725
}
17091726

17101727
template <typename VecT, class BinaryOperation, class = void>
@@ -1735,7 +1752,7 @@ namespace dpct
17351752
};
17361753

17371754
template <class Ta, class Tb, class Tc, class Ts>
1738-
inline void gemm_batch_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans,
1755+
inline void gemm_batch_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans,
17391756
int m, int n, int k, const void * alpha, const void ** a, int lda, const void ** b,
17401757
int ldb, const void * beta, void ** c, int ldc, int batch_size,
17411758
matrix_info_t<float> * matrix_info) {
@@ -1754,28 +1771,18 @@ namespace dpct
17541771
matrix_info->ld_info[2] = ldc;
17551772
matrix_info->groupsize_info = batch_size;
17561773

1757-
#ifdef GGML_SYCL_NVIDIA
1758-
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
1759-
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, matrix_info->transpose_info,
1760-
matrix_info->transpose_info + 1, matrix_info->size_info, matrix_info->size_info + 1,
1761-
matrix_info->size_info + 2, reinterpret_cast<Ts *>(matrix_info->value_info),
1762-
reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
1763-
matrix_info->ld_info + 1, reinterpret_cast<Ts *>(matrix_info->value_info + 1),
1764-
reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
1765-
#else
1766-
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
1767-
q, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info,
1774+
sycl::event e = oneapi::math::blas::column_major::gemm_batch(
1775+
get_onemath_backend(q), matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info,
17681776
matrix_info->size_info + 1, matrix_info->size_info + 2, reinterpret_cast<Ts *>(matrix_info->value_info),
17691777
reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
17701778
matrix_info->ld_info + 1, reinterpret_cast<Ts *>(matrix_info->value_info + 1),
17711779
reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
1772-
#endif
17731780
}
17741781

17751782
template <class Ta, class Tb, class Tc, class Ts>
17761783
inline void
1777-
gemm_batch_impl(sycl::queue &q, oneapi::mkl::transpose a_trans,
1778-
oneapi::mkl::transpose b_trans, int m, int n,
1784+
gemm_batch_impl(sycl::queue &q, oneapi::math::transpose a_trans,
1785+
oneapi::math::transpose b_trans, int m, int n,
17791786
int k, const void *alpha, const void *a, int lda,
17801787
long long int stride_a, const void *b, int ldb,
17811788
long long int stride_b, const void *beta, void *c,
@@ -1786,16 +1793,9 @@ namespace dpct
17861793
auto data_a = get_memory<const Ta>(a);
17871794
auto data_b = get_memory<const Tb>(b);
17881795
auto data_c = get_memory<Tc>(c);
1789-
#ifdef GGML_SYCL_NVIDIA
1790-
oneapi::mkl::blas::column_major::gemm_batch(
1791-
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, a_trans, b_trans, m, n, k,
1792-
alpha_value, data_a, lda, stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc, stride_c,
1793-
batch_size);
1794-
#else
1795-
oneapi::mkl::blas::column_major::gemm_batch(q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
1796+
oneapi::math::blas::column_major::gemm_batch(get_onemath_backend(q), a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
17961797
stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc,
17971798
stride_c, batch_size);
1798-
#endif
17991799
}
18001800

18011801
} // namespace detail
@@ -2259,8 +2259,8 @@ namespace dpct
22592259
sycl::range<3>(x, y, 1), direction);
22602260
}
22612261

2262-
inline void gemm(sycl::queue &q, oneapi::mkl::transpose a_trans,
2263-
oneapi::mkl::transpose b_trans, int m, int n, int k,
2262+
inline void gemm(sycl::queue &q, oneapi::math::transpose a_trans,
2263+
oneapi::math::transpose b_trans, int m, int n, int k,
22642264
const void *alpha, const void *a, library_data_t a_type,
22652265
int lda, const void *b, library_data_t b_type, int ldb,
22662266
const void *beta, void *c, library_data_t c_type, int ldc,
@@ -2329,7 +2329,7 @@ namespace dpct
23292329
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
23302330
library_data_t::real_float, library_data_t::real_float):
23312331
{
2332-
detail::gemm_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float,
2332+
detail::gemm_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, float,
23332333
float>(q, a_trans, b_trans, m, n, k, alpha, a, lda, b,
23342334
ldb, beta, c, ldc);
23352335
break;
@@ -2369,8 +2369,8 @@ namespace dpct
23692369
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
23702370
library_data_t::real_bfloat16, library_data_t::real_float):
23712371
{
2372-
detail::gemm_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16,
2373-
oneapi::mkl::bfloat16, float>(
2372+
detail::gemm_impl<oneapi::math::bfloat16, oneapi::math::bfloat16,
2373+
oneapi::math::bfloat16, float>(
23742374
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
23752375
break;
23762376
}
@@ -2412,7 +2412,7 @@ namespace dpct
24122412
/// \param [in] ldc Leading dimension of C.
24132413
/// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
24142414
/// \param [in] scaling_type Data type of the scaling factors.
2415-
inline void gemm_batch(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m,
2415+
inline void gemm_batch(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m,
24162416
int n, int k, const void * alpha, const void * a[], library_data_t a_type, int lda,
24172417
const void * b[], library_data_t b_type, int ldb, const void * beta, void * c[],
24182418
library_data_t c_type, int ldc, int batch_size, library_data_t scaling_type,
@@ -2450,15 +2450,15 @@ namespace dpct
24502450
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
24512451
library_data_t::real_bfloat16, library_data_t::real_float):
24522452
{
2453-
detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float>(
2453+
detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, oneapi::math::bfloat16, float>(
24542454
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
24552455
break;
24562456
}
24572457
case detail::get_type_combination_id(
24582458
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
24592459
library_data_t::real_float, library_data_t::real_float):
24602460
{
2461-
detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float, float>(
2461+
detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, float, float>(
24622462
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
24632463
break;
24642464
}
@@ -2534,8 +2534,8 @@ namespace dpct
25342534
/// \param [in] stride_c Stride between the different C matrices.
25352535
/// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
25362536
/// \param [in] scaling_type Data type of the scaling factors.
2537-
inline void gemm_batch(sycl::queue &q, oneapi::mkl::transpose a_trans,
2538-
oneapi::mkl::transpose b_trans, int m, int n, int k,
2537+
inline void gemm_batch(sycl::queue &q, oneapi::math::transpose a_trans,
2538+
oneapi::math::transpose b_trans, int m, int n, int k,
25392539
const void *alpha, const void *a, library_data_t a_type,
25402540
int lda, long long int stride_a, const void *b,
25412541
library_data_t b_type, int ldb, long long int stride_b,
@@ -2611,8 +2611,8 @@ namespace dpct
26112611
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
26122612
library_data_t::real_bfloat16, library_data_t::real_float):
26132613
{
2614-
detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16,
2615-
oneapi::mkl::bfloat16, float>(
2614+
detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16,
2615+
oneapi::math::bfloat16, float>(
26162616
q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b,
26172617
beta, c, ldc, stride_c, batch_size);
26182618
break;
@@ -2621,7 +2621,7 @@ namespace dpct
26212621
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
26222622
library_data_t::real_float, library_data_t::real_float):
26232623
{
2624-
detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float,
2624+
detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, float,
26252625
float>(q, a_trans, b_trans, m, n, k, alpha, a, lda,
26262626
stride_a, b, ldb, stride_b, beta, c, ldc,
26272627
stride_c, batch_size);

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2442,8 +2442,8 @@ inline void ggml_sycl_op_mul_mat_sycl(
24422442
const sycl::half alpha_f16 = 1.0f;
24432443
const sycl::half beta_f16 = 0.0f;
24442444
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
2445-
*stream, oneapi::mkl::transpose::trans,
2446-
oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
2445+
*stream, oneapi::math::transpose::trans,
2446+
oneapi::math::transpose::nontrans, row_diff, src1_ncols, ne10,
24472447
&alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00,
24482448
src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16,
24492449
dst_f16.get(), dpct::library_data_t::real_half, ldc,
@@ -2480,17 +2480,10 @@ inline void ggml_sycl_op_mul_mat_sycl(
24802480
#if !GGML_SYCL_DNNL
24812481
const float alpha = 1.0f;
24822482
const float beta = 0.0f;
2483-
# ifdef GGML_SYCL_NVIDIA
2484-
SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
2485-
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ *stream }, oneapi::mkl::transpose::trans,
2486-
oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i,
2487-
ne00, src1_ddf1_i, ne10, dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
2488-
# else
2489-
SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
2490-
*stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
2483+
SYCL_CHECK(CHECK_TRY_ERROR(oneapi::math::blas::column_major::gemm(
2484+
get_onemath_backend(*stream), oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, row_diff, src1_ncols, ne10,
24912485
dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10, dpct::get_value(&beta, *stream),
24922486
dst_dd_i, ldc)));
2493-
# endif
24942487
#else
24952488
auto dnnl_stream = ctx.stream_dnnl(stream);
24962489
DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(),
@@ -3250,8 +3243,8 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
32503243
if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
32513244
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
32523245
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
3253-
*main_stream, oneapi::mkl::transpose::trans,
3254-
oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
3246+
*main_stream, oneapi::math::transpose::trans,
3247+
oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
32553248
(const char *)src0_as_f16, dpct::library_data_t::real_half,
32563249
nb01 / nb00, nb02 / nb00,
32573250
(const char *)src1_f16, dpct::library_data_t::real_half,
@@ -3292,7 +3285,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
32923285
});
32933286
}
32943287
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
3295-
*main_stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
3288+
*main_stream, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
32963289
(const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
32973290
(const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, nb11 / nb10, beta,
32983291
(void **) (ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23, cu_compute_type, matrix_info.get())));

ggml/src/ggml-sycl/outprod.cpp

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#include <sycl/sycl.hpp>
2-
#include <oneapi/mkl.hpp>
32
#include "outprod.hpp"
43

54

@@ -34,20 +33,14 @@ void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
3433

3534
// Handle transposition of src1
3635
const bool src1_T = ggml_is_transposed(src1);
37-
const oneapi::mkl::transpose src1_op =
38-
src1_T ? oneapi::mkl::transpose::nontrans : oneapi::mkl::transpose::trans;
36+
const oneapi::math::transpose src1_op =
37+
src1_T ? oneapi::math::transpose::nontrans : oneapi::math::transpose::trans;
3938
const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float);
4039

4140
try {
42-
// Perform matrix multiplication using oneMKL GEMM
43-
#ifdef GGML_SYCL_NVIDIA
44-
oneapi::mkl::blas::column_major::gemm(oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ *stream },
45-
oneapi::mkl::transpose::nontrans, src1_op, ne0, ne1, ne01, alpha, src0_d,
46-
ne00, src1_d, ldb, beta, dst_d, ne0);
47-
#else
48-
oneapi::mkl::blas::column_major::gemm(*stream, oneapi::mkl::transpose::nontrans, src1_op, ne0, ne1, ne01, alpha,
41+
// Perform matrix multiplication using oneMath GEMM
42+
oneapi::math::blas::column_major::gemm(get_onemath_backend(*stream), oneapi::math::transpose::nontrans, src1_op, ne0, ne1, ne01, alpha,
4943
src0_d, ne00, src1_d, ldb, beta, dst_d, ne0);
50-
#endif
5144
}
5245
catch (sycl::exception const& exc) {
5346
std::cerr << exc.what() << std::endl;

0 commit comments

Comments
 (0)