@@ -1709,20 +1709,18 @@ namespace dpct
17091709
17101710 namespace detail
17111711 {
/// Diff hunk for detail::gemm_impl. The `-` lines show the previous line wrapping and the
/// `+` lines the re-wrapped version; both sides contain the same statements in the same order.
///
/// Type-erased GEMM helper: receives the matrices and scaling factors as `void` pointers,
/// reinterprets them with the template types, and issues a single column-major GEMM call.
///
/// \tparam Ta Element type used when wrapping \p a with get_memory.
/// \tparam Tb Element type used when wrapping \p b with get_memory.
/// \tparam Tc Element type used when wrapping \p c with get_memory.
/// \tparam Ts Type the \p alpha and \p beta pointers are reinterpreted as.
/// \param [in] q Queue passed to dpct::get_value and to get_onemath_backend.
/// \param [in] a_trans, b_trans Transpose settings, forwarded unchanged.
/// \param [in] m, n, k Matrix dimensions, forwarded unchanged.
/// \param [in] alpha, beta Pointers to the scaling factors; read once through dpct::get_value.
/// \param [in] a, b Input matrices. \param [in] lda, ldb Their leading dimensions, forwarded unchanged.
/// \param [in, out] c Output matrix. \param [in] ldc Its leading dimension, forwarded unchanged.
/// NOTE(review): dpct::get_value and get_memory are defined outside this hunk; how they treat
/// host versus device pointers cannot be confirmed from here.
1712- template <class Ta , class Tb , class Tc , class Ts >
1713- inline void gemm_impl (sycl::queue &q, oneapi::math::transpose a_trans,
1714- oneapi::math::transpose b_trans, int m, int n, int k,
1715- const void *alpha, const void *a, int lda, const void *b,
1716- int ldb, const void *beta, void *c, int ldc)
1717- {
1718- Ts alpha_value = dpct::get_value (reinterpret_cast <const Ts *>(alpha), q);
1719- Ts beta_value = dpct::get_value (reinterpret_cast <const Ts *>(beta), q);
1720- auto data_a = get_memory<const Ta>(a);
1721- auto data_b = get_memory<const Tb>(b);
1722- auto data_c = get_memory<Tc>(c);
1723- oneapi::math::blas::column_major::gemm (get_onemath_backend (q), a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb,
1724- beta_value, data_c, ldc);
1725- }
1712+ template <class Ta , class Tb , class Tc , class Ts >
1713+ inline void gemm_impl (sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m,
1714+ int n, int k, const void * alpha, const void * a, int lda, const void * b, int ldb,
1715+ const void * beta, void * c, int ldc) {
// The scaling factors are converted to values of type Ts before the call; the matrices are
// passed as whatever get_memory returns for the requested element type.
1716+ Ts alpha_value = dpct::get_value (reinterpret_cast <const Ts *>(alpha), q);
1717+ Ts beta_value = dpct::get_value (reinterpret_cast <const Ts *>(beta), q);
1718+ auto data_a = get_memory<const Ta>(a);
1719+ auto data_b = get_memory<const Tb>(b);
1720+ auto data_c = get_memory<Tc>(c);
// The result of the gemm call is not captured; this helper returns void.
1721+ oneapi::math::blas::column_major::gemm (get_onemath_backend (q), a_trans, b_trans, m, n, k, alpha_value, data_a,
1722+ lda, data_b, ldb, beta_value, data_c, ldc);
1723+ }
17261724
17271725 template <typename VecT, class BinaryOperation , class = void >
17281726 class vectorized_binary
@@ -1772,30 +1770,27 @@ namespace dpct
17721770 matrix_info->groupsize_info = batch_size;
17731771
17741772 sycl::event e = oneapi::math::blas::column_major::gemm_batch (
1775- get_onemath_backend (q), matrix_info->transpose_info , matrix_info->transpose_info + 1 , matrix_info->size_info ,
1776- matrix_info->size_info + 1 , matrix_info->size_info + 2 , reinterpret_cast <Ts *>(matrix_info->value_info ),
1777- reinterpret_cast <const Ta **>(a), matrix_info->ld_info , reinterpret_cast <const Tb **>(b),
1778- matrix_info->ld_info + 1 , reinterpret_cast <Ts *>(matrix_info->value_info + 1 ),
1779- reinterpret_cast <Tc **>(c), matrix_info->ld_info + 2 , 1 , &(matrix_info->groupsize_info ));
1773+ get_onemath_backend (q), matrix_info->transpose_info , matrix_info->transpose_info + 1 ,
1774+ matrix_info->size_info , matrix_info->size_info + 1 , matrix_info->size_info + 2 ,
1775+ reinterpret_cast <Ts *>(matrix_info->value_info ), reinterpret_cast <const Ta **>(a), matrix_info->ld_info ,
1776+ reinterpret_cast <const Tb **>(b), matrix_info->ld_info + 1 ,
1777+ reinterpret_cast <Ts *>(matrix_info->value_info + 1 ), reinterpret_cast <Tc **>(c),
1778+ matrix_info->ld_info + 2 , 1 , &(matrix_info->groupsize_info ));
17801779 }
17811780
/// Diff hunk for detail::gemm_batch_impl (strided overload). Lines marked `-` are the previous
/// wrapping, lines marked `+` the new wrapping, and unmarked lines are unchanged context.
///
/// Type-erased strided batched GEMM helper: reinterprets the `void` pointers with the template
/// types and issues one column-major gemm_batch call covering the whole batch.
///
/// \tparam Ta, Tb, Tc Element types used when wrapping \p a, \p b and \p c with get_memory.
/// \tparam Ts Type the \p alpha and \p beta pointers are reinterpreted as.
/// \param [in] q Queue passed to dpct::get_value and to get_onemath_backend.
/// \param [in] a_trans, b_trans, m, n, k, lda, ldb, ldc Forwarded unchanged to gemm_batch.
/// \param [in] alpha, beta Pointers to the scaling factors; read once through dpct::get_value.
/// \param [in] stride_a, stride_b, stride_c Forwarded unchanged; presumably the distance between
/// consecutive matrices of the batch, as the public gemm_batch documentation describes — confirm.
/// \param [in] batch_size Forwarded unchanged as the last argument of gemm_batch.
17821781 template <class Ta , class Tb , class Tc , class Ts >
1783- inline void
1784- gemm_batch_impl (sycl::queue &q, oneapi::math::transpose a_trans,
1785- oneapi::math::transpose b_trans, int m, int n,
1786- int k, const void *alpha, const void *a, int lda,
1787- long long int stride_a, const void *b, int ldb,
1788- long long int stride_b, const void *beta, void *c,
1789- int ldc, long long int stride_c, int batch_size)
1790- {
1782+ inline void gemm_batch_impl (sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans,
1783+ int m, int n, int k, const void * alpha, const void * a, int lda,
1784+ long long int stride_a, const void * b, int ldb, long long int stride_b,
1785+ const void * beta, void * c, int ldc, long long int stride_c, int batch_size) {
// Scaling factors are converted to Ts values; the matrices are wrapped via get_memory.
17911786 Ts alpha_value = dpct::get_value (reinterpret_cast <const Ts *>(alpha), q);
17921787 Ts beta_value = dpct::get_value (reinterpret_cast <const Ts *>(beta), q);
17931788 auto data_a = get_memory<const Ta>(a);
17941789 auto data_b = get_memory<const Tb>(b);
17951790 auto data_c = get_memory<Tc>(c);
// The result of the gemm_batch call is not captured; this helper returns void.
1796- oneapi::math::blas::column_major::gemm_batch (get_onemath_backend (q), a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
1797- stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc ,
1798- stride_c, batch_size);
1791+ oneapi::math::blas::column_major::gemm_batch (get_onemath_backend (q), a_trans, b_trans, m, n, k, alpha_value,
1792+ data_a, lda, stride_a, data_b, ldb, stride_b, beta_value,
1793+ data_c, ldc, stride_c, batch_size);
17991794 }
18001795
18011796 } // namespace detail
@@ -2259,13 +2254,10 @@ namespace dpct
22592254 sycl::range<3 >(x, y, 1 ), direction);
22602255 }
22612256
2262- inline void gemm (sycl::queue &q, oneapi::math::transpose a_trans,
2263- oneapi::math::transpose b_trans, int m, int n, int k,
2264- const void *alpha, const void *a, library_data_t a_type,
2265- int lda, const void *b, library_data_t b_type, int ldb,
2266- const void *beta, void *c, library_data_t c_type, int ldc,
2267- library_data_t scaling_type)
2268- {
2257+ inline void gemm (sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m, int n,
2258+ int k, const void * alpha, const void * a, library_data_t a_type, int lda, const void * b,
2259+ library_data_t b_type, int ldb, const void * beta, void * c, library_data_t c_type, int ldc,
2260+ library_data_t scaling_type) {
22692261 if (scaling_type == library_data_t ::real_float &&
22702262 c_type == library_data_t ::complex_float)
22712263 {
@@ -2329,9 +2321,8 @@ namespace dpct
23292321 library_data_t ::real_bfloat16, library_data_t ::real_bfloat16,
23302322 library_data_t ::real_float, library_data_t ::real_float):
23312323 {
2332- detail::gemm_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, float ,
2333- float >(q, a_trans, b_trans, m, n, k, alpha, a, lda, b,
2334- ldb, beta, c, ldc);
2324+ detail::gemm_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, float , float >(
2325+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
23352326 break ;
23362327 }
23372328 case detail::get_type_combination_id (
@@ -2369,8 +2360,7 @@ namespace dpct
23692360 library_data_t ::real_bfloat16, library_data_t ::real_bfloat16,
23702361 library_data_t ::real_bfloat16, library_data_t ::real_float):
23712362 {
2372- detail::gemm_impl<oneapi::math::bfloat16, oneapi::math::bfloat16,
2373- oneapi::math::bfloat16, float >(
2363+ detail::gemm_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, oneapi::math::bfloat16, float >(
23742364 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
23752365 break ;
23762366 }
@@ -2390,7 +2380,7 @@ namespace dpct
23902380 default :
23912381 throw std::runtime_error (" the combination of data type is unsupported" );
23922382 }
2393- } // gemm()
2383+ } // gemm()
23942384
23952385 // / Computes a batch of matrix-matrix products with general matrices.
23962386 // / \param [in] q The queue where the routine should be executed.
@@ -2534,15 +2524,11 @@ namespace dpct
25342524 // / \param [in] stride_c Stride between the different C matrices.
25352525 // / \param [in] batch_size Specifies the number of matrix multiply operations to perform.
25362526 // / \param [in] scaling_type Data type of the scaling factors.
2537- inline void gemm_batch (sycl::queue &q, oneapi::math::transpose a_trans,
2538- oneapi::math::transpose b_trans, int m, int n, int k,
2539- const void *alpha, const void *a, library_data_t a_type,
2540- int lda, long long int stride_a, const void *b,
2541- library_data_t b_type, int ldb, long long int stride_b,
2542- const void *beta, void *c, library_data_t c_type,
2543- int ldc, long long int stride_c, int batch_size,
2544- library_data_t scaling_type)
2545- {
2527+ inline void gemm_batch (sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m,
2528+ int n, int k, const void * alpha, const void * a, library_data_t a_type, int lda,
2529+ long long int stride_a, const void * b, library_data_t b_type, int ldb,
2530+ long long int stride_b, const void * beta, void * c, library_data_t c_type, int ldc,
2531+ long long int stride_c, int batch_size, library_data_t scaling_type) {
25462532 if (scaling_type == library_data_t ::real_float &&
25472533 c_type == library_data_t ::complex_float)
25482534 {
@@ -2611,20 +2597,18 @@ namespace dpct
26112597 library_data_t ::real_bfloat16, library_data_t ::real_bfloat16,
26122598 library_data_t ::real_bfloat16, library_data_t ::real_float):
26132599 {
2614- detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16,
2615- oneapi::math::bfloat16, float >(
2616- q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b,
2617- beta, c, ldc, stride_c, batch_size);
2600+ detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, oneapi::math::bfloat16, float >(
2601+ q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c,
2602+ batch_size);
26182603 break ;
26192604 }
26202605 case detail::get_type_combination_id (
26212606 library_data_t ::real_bfloat16, library_data_t ::real_bfloat16,
26222607 library_data_t ::real_float, library_data_t ::real_float):
26232608 {
2624- detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, float ,
2625- float >(q, a_trans, b_trans, m, n, k, alpha, a, lda,
2626- stride_a, b, ldb, stride_b, beta, c, ldc,
2627- stride_c, batch_size);
2609+ detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, float , float >(
2610+ q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c,
2611+ batch_size);
26282612 break ;
26292613 }
26302614#endif
0 commit comments