Skip to content

Commit f6e6fc4

Browse files
committed
Address PR comments to increase readability
1 parent ffd0a99 commit f6e6fc4

File tree

3 files changed

+34
-24
lines changed

3 files changed

+34
-24
lines changed

ggml/src/ggml-sycl/dpct/helper.hpp

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1689,13 +1689,14 @@ namespace dpct
16891689
auto data_a = get_memory<const Ta>(a);
16901690
auto data_b = get_memory<const Tb>(b);
16911691
auto data_c = get_memory<Tc>(c);
1692-
oneapi::mkl::blas::column_major::gemm(
16931692
#ifdef GGML_SYCL_NVIDIA
1694-
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q },
1693+
oneapi::mkl::blas::column_major::gemm(oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q },
1694+
a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb,
1695+
beta_value, data_c, ldc);
16951696
#else
1696-
q,
1697+
oneapi::mkl::blas::column_major::gemm(q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb,
1698+
beta_value, data_c, ldc);
16971699
#endif
1698-
a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb, beta_value, data_c, ldc);
16991700
}
17001701

17011702
template <typename VecT, class BinaryOperation, class = void>
@@ -1758,17 +1759,22 @@ namespace dpct
17581759
matrix_info->ld_info[2] = ldc;
17591760
matrix_info->groupsize_info = batch_size;
17601761

1761-
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
17621762
#ifdef GGML_SYCL_NVIDIA
1763-
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q },
1763+
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
1764+
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, matrix_info->transpose_info,
1765+
matrix_info->transpose_info + 1, matrix_info->size_info, matrix_info->size_info + 1,
1766+
matrix_info->size_info + 2, matrix_info->value_info, reinterpret_cast<const Ta **>(a),
1767+
matrix_info->ld_info, reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
1768+
matrix_info->value_info + 1, reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1,
1769+
&(matrix_info->groupsize_info));
17641770
#else
1765-
q,
1766-
#endif
1767-
matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info,
1771+
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
1772+
q, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info,
17681773
matrix_info->size_info + 1, matrix_info->size_info + 2, matrix_info->value_info,
17691774
reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
17701775
matrix_info->ld_info + 1, matrix_info->value_info + 1, reinterpret_cast<Tc **>(c),
17711776
matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
1777+
#endif
17721778

17731779
q.submit([&](sycl::handler &cgh)
17741780
{
@@ -1790,14 +1796,16 @@ namespace dpct
17901796
auto data_a = get_memory<const Ta>(a);
17911797
auto data_b = get_memory<const Tb>(b);
17921798
auto data_c = get_memory<Tc>(c);
1793-
oneapi::mkl::blas::column_major::gemm_batch(
17941799
#ifdef GGML_SYCL_NVIDIA
1795-
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q },
1800+
oneapi::mkl::blas::column_major::gemm_batch(
1801+
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, a_trans, b_trans, m, n, k,
1802+
alpha_value, data_a, lda, stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc, stride_c,
1803+
batch_size);
17961804
#else
1797-
q,
1805+
oneapi::mkl::blas::column_major::gemm_batch(q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
1806+
stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc,
1807+
stride_c, batch_size);
17981808
#endif
1799-
a_trans, b_trans, m, n, k, alpha_value, data_a, lda, stride_a, data_b, ldb, stride_b, beta_value,
1800-
data_c, ldc, stride_c, batch_size);
18011809
}
18021810

18031811
} // namespace detail

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2561,15 +2561,17 @@ inline void ggml_sycl_op_mul_mat_sycl(
25612561
const float alpha = 1.0f;
25622562
const float beta = 0.0f;
25632563
#if !GGML_SYCL_DNNL
2564-
SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
25652564
# ifdef GGML_SYCL_NVIDIA
2566-
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ *stream },
2565+
SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
2566+
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ *stream }, oneapi::mkl::transpose::trans,
2567+
oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i,
2568+
ne00, src1_ddf1_i, ne10, dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
25672569
# else
2568-
*stream,
2569-
# endif
2570-
oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
2570+
SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
2571+
*stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
25712572
dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10, dpct::get_value(&beta, *stream),
25722573
dst_dd_i, ldc)));
2574+
# endif
25732575
#else
25742576
auto dnnl_stream = ctx.stream_dnnl(stream);
25752577
DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(),

ggml/src/ggml-sycl/outprod.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,14 @@ void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, const ggml_tensor* sr
4040

4141
try {
4242
// Perform matrix multiplication using oneMKL GEMM
43-
oneapi::mkl::blas::column_major::gemm(
4443
#ifdef GGML_SYCL_NVIDIA
45-
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ *stream },
44+
oneapi::mkl::blas::column_major::gemm(oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ *stream },
45+
oneapi::mkl::transpose::nontrans, src1_op, ne0, ne1, ne01, alpha, src0_d,
46+
ne00, src1_d, ldb, beta, dst_d, ne0);
4647
#else
47-
*stream,
48+
oneapi::mkl::blas::column_major::gemm(*stream, oneapi::mkl::transpose::nontrans, src1_op, ne0, ne1, ne01, alpha,
49+
src0_d, ne00, src1_d, ldb, beta, dst_d, ne0);
4850
#endif
49-
oneapi::mkl::transpose::nontrans, src1_op, ne0, ne1, ne01, alpha, src0_d, ne00, src1_d, ldb, beta, dst_d,
50-
ne0);
5151
}
5252
catch (sycl::exception const& exc) {
5353
std::cerr << exc.what() << std::endl;

0 commit comments

Comments
 (0)