@@ -60,7 +60,9 @@ typedef sycl::event (*gemm_batch_impl_fn_ptr_t)(
6060 const char *,
6161 const char *,
6262 char *,
63+ #if !defined(USE_ONEMKL_CUBLAS)
6364 const bool ,
65+ #endif // !USE_ONEMKL_CUBLAS
6466 const std::vector<sycl::event> &);
6567
6668static gemm_batch_impl_fn_ptr_t
@@ -83,7 +85,9 @@ static sycl::event gemm_batch_impl(sycl::queue &exec_q,
8385 const char *matrixA,
8486 const char *matrixB,
8587 char *resultC,
88+ #if !defined(USE_ONEMKL_CUBLAS)
8689 const bool is_row_major,
90+ #endif // !USE_ONEMKL_CUBLAS
8791 const std::vector<sycl::event> &depends)
8892{
8993 type_utils::validate_type_for_device<Tab>(exec_q);
@@ -311,6 +315,7 @@ std::tuple<sycl::event, sycl::event, bool>
311315 std::int64_t lda;
312316 std::int64_t ldb;
313317
318+ // cuBLAS supports only column-major storage
314319#if defined(USE_ONEMKL_CUBLAS)
315320 const bool is_row_major = false ;
316321
@@ -391,10 +396,17 @@ std::tuple<sycl::event, sycl::event, bool>
391396 const char *b_typeless_ptr = matrixB.get_data ();
392397 char *r_typeless_ptr = resultC.get_data ();
393398
399+ #if defined(USE_ONEMKL_CUBLAS)
400+ sycl::event gemm_batch_ev =
401+ gemm_batch_fn (exec_q, m, n, k, batch_size, lda, ldb, ldc, stridea,
402+ strideb, stridec, transA, transB, a_typeless_ptr,
403+ b_typeless_ptr, r_typeless_ptr, depends);
404+ #else
394405 sycl::event gemm_batch_ev =
395406 gemm_batch_fn (exec_q, m, n, k, batch_size, lda, ldb, ldc, stridea,
396407 strideb, stridec, transA, transB, a_typeless_ptr,
397408 b_typeless_ptr, r_typeless_ptr, is_row_major, depends);
409+ #endif // USE_ONEMKL_CUBLAS
398410
399411 sycl::event args_ev = dpctl::utils::keep_args_alive (
400412 exec_q, {matrixA, matrixB, resultC}, {gemm_batch_ev});
0 commit comments