Skip to content

Commit 4482516

Browse files
committed
clang-format
1 parent 3577bc0 commit 4482516

File tree

3 files changed

+54
-77
lines changed

3 files changed

+54
-77
lines changed

ggml/src/ggml-sycl/dpct/helper.hpp

Lines changed: 44 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1709,20 +1709,18 @@ namespace dpct
17091709

17101710
namespace detail
17111711
{
1712-
template <class Ta, class Tb, class Tc, class Ts>
1713-
inline void gemm_impl(sycl::queue &q, oneapi::math::transpose a_trans,
1714-
oneapi::math::transpose b_trans, int m, int n, int k,
1715-
const void *alpha, const void *a, int lda, const void *b,
1716-
int ldb, const void *beta, void *c, int ldc)
1717-
{
1718-
Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
1719-
Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
1720-
auto data_a = get_memory<const Ta>(a);
1721-
auto data_b = get_memory<const Tb>(b);
1722-
auto data_c = get_memory<Tc>(c);
1723-
oneapi::math::blas::column_major::gemm(get_onemath_backend(q), a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb,
1724-
beta_value, data_c, ldc);
1725-
}
1712+
template <class Ta, class Tb, class Tc, class Ts>
1713+
inline void gemm_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m,
1714+
int n, int k, const void * alpha, const void * a, int lda, const void * b, int ldb,
1715+
const void * beta, void * c, int ldc) {
1716+
Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
1717+
Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
1718+
auto data_a = get_memory<const Ta>(a);
1719+
auto data_b = get_memory<const Tb>(b);
1720+
auto data_c = get_memory<Tc>(c);
1721+
oneapi::math::blas::column_major::gemm(get_onemath_backend(q), a_trans, b_trans, m, n, k, alpha_value, data_a,
1722+
lda, data_b, ldb, beta_value, data_c, ldc);
1723+
}
17261724

17271725
template <typename VecT, class BinaryOperation, class = void>
17281726
class vectorized_binary
@@ -1772,30 +1770,27 @@ namespace dpct
17721770
matrix_info->groupsize_info = batch_size;
17731771

17741772
sycl::event e = oneapi::math::blas::column_major::gemm_batch(
1775-
get_onemath_backend(q), matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info,
1776-
matrix_info->size_info + 1, matrix_info->size_info + 2, reinterpret_cast<Ts *>(matrix_info->value_info),
1777-
reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
1778-
matrix_info->ld_info + 1, reinterpret_cast<Ts *>(matrix_info->value_info + 1),
1779-
reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
1773+
get_onemath_backend(q), matrix_info->transpose_info, matrix_info->transpose_info + 1,
1774+
matrix_info->size_info, matrix_info->size_info + 1, matrix_info->size_info + 2,
1775+
reinterpret_cast<Ts *>(matrix_info->value_info), reinterpret_cast<const Ta **>(a), matrix_info->ld_info,
1776+
reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
1777+
reinterpret_cast<Ts *>(matrix_info->value_info + 1), reinterpret_cast<Tc **>(c),
1778+
matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
17801779
}
17811780

17821781
template <class Ta, class Tb, class Tc, class Ts>
1783-
inline void
1784-
gemm_batch_impl(sycl::queue &q, oneapi::math::transpose a_trans,
1785-
oneapi::math::transpose b_trans, int m, int n,
1786-
int k, const void *alpha, const void *a, int lda,
1787-
long long int stride_a, const void *b, int ldb,
1788-
long long int stride_b, const void *beta, void *c,
1789-
int ldc, long long int stride_c, int batch_size)
1790-
{
1782+
inline void gemm_batch_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans,
1783+
int m, int n, int k, const void * alpha, const void * a, int lda,
1784+
long long int stride_a, const void * b, int ldb, long long int stride_b,
1785+
const void * beta, void * c, int ldc, long long int stride_c, int batch_size) {
17911786
Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
17921787
Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
17931788
auto data_a = get_memory<const Ta>(a);
17941789
auto data_b = get_memory<const Tb>(b);
17951790
auto data_c = get_memory<Tc>(c);
1796-
oneapi::math::blas::column_major::gemm_batch(get_onemath_backend(q), a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
1797-
stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc,
1798-
stride_c, batch_size);
1791+
oneapi::math::blas::column_major::gemm_batch(get_onemath_backend(q), a_trans, b_trans, m, n, k, alpha_value,
1792+
data_a, lda, stride_a, data_b, ldb, stride_b, beta_value,
1793+
data_c, ldc, stride_c, batch_size);
17991794
}
18001795

18011796
} // namespace detail
@@ -2259,13 +2254,10 @@ namespace dpct
22592254
sycl::range<3>(x, y, 1), direction);
22602255
}
22612256

2262-
inline void gemm(sycl::queue &q, oneapi::math::transpose a_trans,
2263-
oneapi::math::transpose b_trans, int m, int n, int k,
2264-
const void *alpha, const void *a, library_data_t a_type,
2265-
int lda, const void *b, library_data_t b_type, int ldb,
2266-
const void *beta, void *c, library_data_t c_type, int ldc,
2267-
library_data_t scaling_type)
2268-
{
2257+
inline void gemm(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m, int n,
2258+
int k, const void * alpha, const void * a, library_data_t a_type, int lda, const void * b,
2259+
library_data_t b_type, int ldb, const void * beta, void * c, library_data_t c_type, int ldc,
2260+
library_data_t scaling_type) {
22692261
if (scaling_type == library_data_t::real_float &&
22702262
c_type == library_data_t::complex_float)
22712263
{
@@ -2329,9 +2321,8 @@ namespace dpct
23292321
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
23302322
library_data_t::real_float, library_data_t::real_float):
23312323
{
2332-
detail::gemm_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, float,
2333-
float>(q, a_trans, b_trans, m, n, k, alpha, a, lda, b,
2334-
ldb, beta, c, ldc);
2324+
detail::gemm_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, float, float>(
2325+
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
23352326
break;
23362327
}
23372328
case detail::get_type_combination_id(
@@ -2369,8 +2360,7 @@ namespace dpct
23692360
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
23702361
library_data_t::real_bfloat16, library_data_t::real_float):
23712362
{
2372-
detail::gemm_impl<oneapi::math::bfloat16, oneapi::math::bfloat16,
2373-
oneapi::math::bfloat16, float>(
2363+
detail::gemm_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, oneapi::math::bfloat16, float>(
23742364
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
23752365
break;
23762366
}
@@ -2390,7 +2380,7 @@ namespace dpct
23902380
default:
23912381
throw std::runtime_error("the combination of data type is unsupported");
23922382
}
2393-
} // gemm()
2383+
} // gemm()
23942384

23952385
/// Computes a batch of matrix-matrix product with general matrices.
23962386
/// \param [in] q The queue where the routine should be executed.
@@ -2534,15 +2524,11 @@ namespace dpct
25342524
/// \param [in] stride_c Stride between the different C matrices.
25352525
/// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
25362526
/// \param [in] scaling_type Data type of the scaling factors.
2537-
inline void gemm_batch(sycl::queue &q, oneapi::math::transpose a_trans,
2538-
oneapi::math::transpose b_trans, int m, int n, int k,
2539-
const void *alpha, const void *a, library_data_t a_type,
2540-
int lda, long long int stride_a, const void *b,
2541-
library_data_t b_type, int ldb, long long int stride_b,
2542-
const void *beta, void *c, library_data_t c_type,
2543-
int ldc, long long int stride_c, int batch_size,
2544-
library_data_t scaling_type)
2545-
{
2527+
inline void gemm_batch(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m,
2528+
int n, int k, const void * alpha, const void * a, library_data_t a_type, int lda,
2529+
long long int stride_a, const void * b, library_data_t b_type, int ldb,
2530+
long long int stride_b, const void * beta, void * c, library_data_t c_type, int ldc,
2531+
long long int stride_c, int batch_size, library_data_t scaling_type) {
25462532
if (scaling_type == library_data_t::real_float &&
25472533
c_type == library_data_t::complex_float)
25482534
{
@@ -2611,20 +2597,18 @@ namespace dpct
26112597
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
26122598
library_data_t::real_bfloat16, library_data_t::real_float):
26132599
{
2614-
detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16,
2615-
oneapi::math::bfloat16, float>(
2616-
q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b,
2617-
beta, c, ldc, stride_c, batch_size);
2600+
detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, oneapi::math::bfloat16, float>(
2601+
q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c,
2602+
batch_size);
26182603
break;
26192604
}
26202605
case detail::get_type_combination_id(
26212606
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
26222607
library_data_t::real_float, library_data_t::real_float):
26232608
{
2624-
detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, float,
2625-
float>(q, a_trans, b_trans, m, n, k, alpha, a, lda,
2626-
stride_a, b, ldb, stride_b, beta, c, ldc,
2627-
stride_c, batch_size);
2609+
detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, float, float>(
2610+
q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c,
2611+
batch_size);
26282612
break;
26292613
}
26302614
#endif

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2481,9 +2481,9 @@ inline void ggml_sycl_op_mul_mat_sycl(
24812481
const float alpha = 1.0f;
24822482
const float beta = 0.0f;
24832483
SYCL_CHECK(CHECK_TRY_ERROR(oneapi::math::blas::column_major::gemm(
2484-
get_onemath_backend(*stream), oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, row_diff, src1_ncols, ne10,
2485-
dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10, dpct::get_value(&beta, *stream),
2486-
dst_dd_i, ldc)));
2484+
get_onemath_backend(*stream), oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, row_diff,
2485+
src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10,
2486+
dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
24872487
#else
24882488
auto dnnl_stream = ctx.stream_dnnl(stream);
24892489
DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(),
@@ -3243,14 +3243,10 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
32433243
if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
32443244
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
32453245
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
3246-
*main_stream, oneapi::math::transpose::trans,
3247-
oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
3248-
(const char *)src0_as_f16, dpct::library_data_t::real_half,
3249-
nb01 / nb00, nb02 / nb00,
3250-
(const char *)src1_f16, dpct::library_data_t::real_half,
3251-
nb11 / nb10, nb12 / nb10, beta,
3252-
(char *)dst_t, cu_data_type, ne01, nb2 / nb0,
3253-
ne12 * ne13, cu_compute_type)));
3246+
*main_stream, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
3247+
(const char *) src0_as_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
3248+
(const char *) src1_f16, dpct::library_data_t::real_half, nb11 / nb10, nb12 / nb10, beta, (char *) dst_t,
3249+
cu_data_type, ne01, nb2 / nb0, ne12 * ne13, cu_compute_type)));
32543250
} else {
32553251
const int ne23 = ne12*ne13;
32563252

ggml/src/ggml-sycl/outprod.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
1-
#include <sycl/sycl.hpp>
21
#include "outprod.hpp"
32

4-
53
void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
64
const ggml_tensor *src0 = dst->src[0];
75
const ggml_tensor *src1 = dst->src[1];
@@ -33,14 +31,13 @@ void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
3331

3432
// Handle transposition of src1
3533
const bool src1_T = ggml_is_transposed(src1);
36-
const oneapi::math::transpose src1_op =
37-
src1_T ? oneapi::math::transpose::nontrans : oneapi::math::transpose::trans;
34+
const oneapi::math::transpose src1_op = src1_T ? oneapi::math::transpose::nontrans : oneapi::math::transpose::trans;
3835
const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float);
3936

4037
try {
4138
// Perform matrix multiplication using oneMath GEMM
42-
oneapi::math::blas::column_major::gemm(get_onemath_backend(*stream), oneapi::math::transpose::nontrans, src1_op, ne0, ne1, ne01, alpha,
43-
src0_d, ne00, src1_d, ldb, beta, dst_d, ne0);
39+
oneapi::math::blas::column_major::gemm(get_onemath_backend(*stream), oneapi::math::transpose::nontrans, src1_op,
40+
ne0, ne1, ne01, alpha, src0_d, ne00, src1_d, ldb, beta, dst_d, ne0);
4441
}
4542
catch (sycl::exception const& exc) {
4643
std::cerr << exc.what() << std::endl;

0 commit comments

Comments
 (0)