@@ -1689,9 +1689,14 @@ namespace dpct
16891689 auto data_a = get_memory<const Ta>(a);
16901690 auto data_b = get_memory<const Tb>(b);
16911691 auto data_c = get_memory<Tc>(c);
1692- oneapi::mkl::blas::column_major::gemm (
1693- q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
1694- data_b, ldb, beta_value, data_c, ldc);
1692+ #ifdef GGML_SYCL_NVIDIA
1693+ oneapi::mkl::blas::column_major::gemm (oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q },
1694+ a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb,
1695+ beta_value, data_c, ldc);
1696+ #else
1697+ oneapi::mkl::blas::column_major::gemm (q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb,
1698+ beta_value, data_c, ldc);
1699+ #endif
16951700 }
16961701
16971702 template <typename VecT, class BinaryOperation , class = void >
@@ -1754,14 +1759,22 @@ namespace dpct
17541759 matrix_info->ld_info [2 ] = ldc;
17551760 matrix_info->groupsize_info = batch_size;
17561761
1762+ #ifdef GGML_SYCL_NVIDIA
1763+ sycl::event e = oneapi::mkl::blas::column_major::gemm_batch (
1764+ oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, matrix_info->transpose_info ,
1765+ matrix_info->transpose_info + 1 , matrix_info->size_info , matrix_info->size_info + 1 ,
1766+ matrix_info->size_info + 2 , matrix_info->value_info , reinterpret_cast <const Ta **>(a),
1767+ matrix_info->ld_info , reinterpret_cast <const Tb **>(b), matrix_info->ld_info + 1 ,
1768+ matrix_info->value_info + 1 , reinterpret_cast <Tc **>(c), matrix_info->ld_info + 2 , 1 ,
1769+ &(matrix_info->groupsize_info ));
1770+ #else
17571771 sycl::event e = oneapi::mkl::blas::column_major::gemm_batch (
1758- q, matrix_info->transpose_info , matrix_info->transpose_info + 1 ,
1759- matrix_info->size_info , matrix_info->size_info + 1 ,
1760- matrix_info->size_info + 2 , matrix_info->value_info ,
1761- reinterpret_cast <const Ta **>(a), matrix_info->ld_info ,
1762- reinterpret_cast <const Tb **>(b), matrix_info->ld_info + 1 ,
1763- matrix_info->value_info + 1 , reinterpret_cast <Tc **>(c),
1772+ q, matrix_info->transpose_info , matrix_info->transpose_info + 1 , matrix_info->size_info ,
1773+ matrix_info->size_info + 1 , matrix_info->size_info + 2 , matrix_info->value_info ,
1774+ reinterpret_cast <const Ta **>(a), matrix_info->ld_info , reinterpret_cast <const Tb **>(b),
1775+ matrix_info->ld_info + 1 , matrix_info->value_info + 1 , reinterpret_cast <Tc **>(c),
17641776 matrix_info->ld_info + 2 , 1 , &(matrix_info->groupsize_info ));
1777+ #endif
17651778
17661779 q.submit ([&](sycl::handler &cgh)
17671780 {
@@ -1783,10 +1796,16 @@ namespace dpct
17831796 auto data_a = get_memory<const Ta>(a);
17841797 auto data_b = get_memory<const Tb>(b);
17851798 auto data_c = get_memory<Tc>(c);
1799+ #ifdef GGML_SYCL_NVIDIA
17861800 oneapi::mkl::blas::column_major::gemm_batch (
1787- q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
1788- stride_a, data_b, ldb, stride_b, beta_value,
1789- data_c, ldc, stride_c, batch_size);
1801+ oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, a_trans, b_trans, m, n, k,
1802+ alpha_value, data_a, lda, stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc, stride_c,
1803+ batch_size);
1804+ #else
1805+ oneapi::mkl::blas::column_major::gemm_batch (q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
1806+ stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc,
1807+ stride_c, batch_size);
1808+ #endif
17901809 }
17911810
17921811 } // namespace detail
0 commit comments