@@ -1689,13 +1689,14 @@ namespace dpct
16891689 auto data_a = get_memory<const Ta>(a);
16901690 auto data_b = get_memory<const Tb>(b);
16911691 auto data_c = get_memory<Tc>(c);
1692- oneapi::mkl::blas::column_major::gemm (
16931692#ifdef GGML_SYCL_NVIDIA
1694- oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q },
1693+ oneapi::mkl::blas::column_major::gemm (oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q },
1694+ a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb,
1695+ beta_value, data_c, ldc);
16951696#else
1696- q,
1697+ oneapi::mkl::blas::column_major::gemm (q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb,
1698+ beta_value, data_c, ldc);
16971699#endif
1698- a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb, beta_value, data_c, ldc);
16991700 }
17001701
17011702 template <typename VecT, class BinaryOperation , class = void >
@@ -1758,17 +1759,22 @@ namespace dpct
17581759 matrix_info->ld_info [2 ] = ldc;
17591760 matrix_info->groupsize_info = batch_size;
17601761
1761- sycl::event e = oneapi::mkl::blas::column_major::gemm_batch (
17621762#ifdef GGML_SYCL_NVIDIA
1763- oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q },
1763+ sycl::event e = oneapi::mkl::blas::column_major::gemm_batch (
1764+ oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, matrix_info->transpose_info ,
1765+ matrix_info->transpose_info + 1 , matrix_info->size_info , matrix_info->size_info + 1 ,
1766+ matrix_info->size_info + 2 , matrix_info->value_info , reinterpret_cast <const Ta **>(a),
1767+ matrix_info->ld_info , reinterpret_cast <const Tb **>(b), matrix_info->ld_info + 1 ,
1768+ matrix_info->value_info + 1 , reinterpret_cast <Tc **>(c), matrix_info->ld_info + 2 , 1 ,
1769+ &(matrix_info->groupsize_info ));
17641770#else
1765- q,
1766- #endif
1767- matrix_info->transpose_info , matrix_info->transpose_info + 1 , matrix_info->size_info ,
1771+ sycl::event e = oneapi::mkl::blas::column_major::gemm_batch (
1772+ q, matrix_info->transpose_info , matrix_info->transpose_info + 1 , matrix_info->size_info ,
17681773 matrix_info->size_info + 1 , matrix_info->size_info + 2 , matrix_info->value_info ,
17691774 reinterpret_cast <const Ta **>(a), matrix_info->ld_info , reinterpret_cast <const Tb **>(b),
17701775 matrix_info->ld_info + 1 , matrix_info->value_info + 1 , reinterpret_cast <Tc **>(c),
17711776 matrix_info->ld_info + 2 , 1 , &(matrix_info->groupsize_info ));
1777+ #endif
17721778
17731779 q.submit ([&](sycl::handler &cgh)
17741780 {
@@ -1790,14 +1796,16 @@ namespace dpct
17901796 auto data_a = get_memory<const Ta>(a);
17911797 auto data_b = get_memory<const Tb>(b);
17921798 auto data_c = get_memory<Tc>(c);
1793- oneapi::mkl::blas::column_major::gemm_batch (
17941799#ifdef GGML_SYCL_NVIDIA
1795- oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q },
1800+ oneapi::mkl::blas::column_major::gemm_batch (
1801+ oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, a_trans, b_trans, m, n, k,
1802+ alpha_value, data_a, lda, stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc, stride_c,
1803+ batch_size);
17961804#else
1797- q,
1805+ oneapi::mkl::blas::column_major::gemm_batch (q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
1806+ stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc,
1807+ stride_c, batch_size);
17981808#endif
1799- a_trans, b_trans, m, n, k, alpha_value, data_a, lda, stride_a, data_b, ldb, stride_b, beta_value,
1800- data_c, ldc, stride_c, batch_size);
18011809 }
18021810
18031811 } // namespace detail
0 commit comments