@@ -82,14 +82,12 @@ inline std::string get_device_backend_and_type(const sycl::device &device) {
8282 return device_type.str ();
8383}
8484
85- template <typename Ts>
86- struct matrix_info_t
87- {
85+ template <typename Ts> struct matrix_info_t {
8886 oneapi::mkl::transpose transpose_info[2 ];
89- Ts value_info[2 ];
90- std::int64_t size_info[3 ];
91- std::int64_t ld_info[3 ];
92- std::int64_t groupsize_info;
87+ Ts value_info[2 ];
88+ std::int64_t size_info[3 ];
89+ std::int64_t ld_info[3 ];
90+ std::int64_t groupsize_info;
9391};
9492
9593namespace dpct
@@ -1737,13 +1735,10 @@ namespace dpct
17371735 };
17381736
17391737 template <class Ta , class Tb , class Tc , class Ts >
1740- inline void gemm_batch_impl (sycl::queue &q, oneapi::mkl::transpose a_trans,
1741- oneapi::mkl::transpose b_trans, int m, int n, int k,
1742- const void *alpha, const void **a, int lda,
1743- const void **b, int ldb, const void *beta, void **c,
1744- int ldc, int batch_size, matrix_info_t <float >* matrix_info)
1745- {
1746-
1738+ inline void gemm_batch_impl (sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans,
1739+ int m, int n, int k, const void * alpha, const void ** a, int lda, const void ** b,
1740+ int ldb, const void * beta, void ** c, int ldc, int batch_size,
1741+ matrix_info_t <float > * matrix_info) {
17471742 Ts alpha_value = dpct::get_value (reinterpret_cast <const Ts *>(alpha), q);
17481743 Ts beta_value = dpct::get_value (reinterpret_cast <const Ts *>(beta), q);
17491744
@@ -1763,19 +1758,18 @@ namespace dpct
17631758 sycl::event e = oneapi::mkl::blas::column_major::gemm_batch (
17641759 oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, matrix_info->transpose_info ,
17651760 matrix_info->transpose_info + 1 , matrix_info->size_info , matrix_info->size_info + 1 ,
1766- matrix_info->size_info + 2 , reinterpret_cast <Ts*>(matrix_info->value_info ), reinterpret_cast < const Ta **>(a ),
1767- matrix_info-> ld_info , reinterpret_cast <const Tb **>(b ), matrix_info->ld_info + 1 ,
1768- reinterpret_cast <Ts*>( matrix_info->value_info + 1 ) , reinterpret_cast <Tc **>(c), matrix_info->ld_info + 2 , 1 ,
1769- &(matrix_info->groupsize_info ));
1761+ matrix_info->size_info + 2 , reinterpret_cast <Ts *>(matrix_info->value_info ),
1762+ reinterpret_cast <const Ta **>(a ), matrix_info->ld_info , reinterpret_cast < const Tb **>(b) ,
1763+ matrix_info->ld_info + 1 , reinterpret_cast <Ts *>( matrix_info->value_info + 1 ) ,
1764+ reinterpret_cast <Tc **>(c), matrix_info-> ld_info + 2 , 1 , &(matrix_info->groupsize_info ));
17701765#else
17711766 sycl::event e = oneapi::mkl::blas::column_major::gemm_batch (
17721767 q, matrix_info->transpose_info , matrix_info->transpose_info + 1 , matrix_info->size_info ,
1773- matrix_info->size_info + 1 , matrix_info->size_info + 2 , reinterpret_cast <Ts*>(matrix_info->value_info ),
1768+ matrix_info->size_info + 1 , matrix_info->size_info + 2 , reinterpret_cast <Ts *>(matrix_info->value_info ),
17741769 reinterpret_cast <const Ta **>(a), matrix_info->ld_info , reinterpret_cast <const Tb **>(b),
1775- matrix_info->ld_info + 1 , reinterpret_cast <Ts*>(matrix_info->value_info + 1 ), reinterpret_cast <Tc **>(c ),
1776- matrix_info->ld_info + 2 , 1 , &(matrix_info->groupsize_info ));
1770+ matrix_info->ld_info + 1 , reinterpret_cast <Ts *>(matrix_info->value_info + 1 ),
1771+ reinterpret_cast <Tc **>(c), matrix_info->ld_info + 2 , 1 , &(matrix_info->groupsize_info ));
17771772#endif
1778-
17791773 }
17801774
17811775 template <class Ta , class Tb , class Tc , class Ts >
@@ -2418,15 +2412,11 @@ namespace dpct
24182412 // / \param [in] ldc Leading dimension of C.
24192413 // / \param [in] batch_size Specifies the number of matrix multiply operations to perform.
24202414 // / \param [in] scaling_type Data type of the scaling factors.
2421- inline void gemm_batch (sycl::queue &q, oneapi::mkl::transpose a_trans,
2422- oneapi::mkl::transpose b_trans, int m, int n, int k,
2423- const void *alpha, const void *a[],
2424- library_data_t a_type, int lda, const void *b[],
2425- library_data_t b_type, int ldb, const void *beta,
2426- void *c[], library_data_t c_type, int ldc,
2427- int batch_size, library_data_t scaling_type,
2428- matrix_info_t <float >* matrix_info)
2429- {
2415+ inline void gemm_batch (sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m,
2416+ int n, int k, const void * alpha, const void * a[], library_data_t a_type, int lda,
2417+ const void * b[], library_data_t b_type, int ldb, const void * beta, void * c[],
2418+ library_data_t c_type, int ldc, int batch_size, library_data_t scaling_type,
2419+ matrix_info_t <float > * matrix_info) {
24302420 std::uint64_t key =
24312421 detail::get_type_combination_id (a_type, b_type, c_type, scaling_type);
24322422 switch (key)
@@ -2435,48 +2425,41 @@ namespace dpct
24352425 library_data_t ::real_float, library_data_t ::real_float,
24362426 library_data_t ::real_float, library_data_t ::real_float):
24372427 {
2438- detail::gemm_batch_impl<float , float , float , float >(
2439- q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2440- batch_size, matrix_info);
2428+ detail::gemm_batch_impl<float , float , float , float >(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb,
2429+ beta, c, ldc, batch_size, matrix_info);
24412430 break ;
24422431 }
24432432 case detail::get_type_combination_id (
24442433 library_data_t ::real_double, library_data_t ::real_double,
24452434 library_data_t ::real_double, library_data_t ::real_double):
24462435 {
2447- detail::gemm_batch_impl<double , double , double , double >(
2448- q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2449- batch_size, matrix_info);
2436+ detail::gemm_batch_impl<double , double , double , double >(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb,
2437+ beta, c, ldc, batch_size, matrix_info);
24502438 break ;
24512439 }
24522440 case detail::get_type_combination_id (
24532441 library_data_t ::real_half, library_data_t ::real_half,
24542442 library_data_t ::real_half, library_data_t ::real_half):
24552443 {
2456- detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half,
2457- sycl::half>(q, a_trans, b_trans, m, n, k, alpha,
2458- a, lda, b, ldb, beta, c, ldc,
2459- batch_size, matrix_info);
2444+ detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
2445+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
24602446 break ;
24612447 }
24622448#ifdef __INTEL_MKL__
24632449 case detail::get_type_combination_id (
24642450 library_data_t ::real_bfloat16, library_data_t ::real_bfloat16,
24652451 library_data_t ::real_bfloat16, library_data_t ::real_float):
24662452 {
2467- detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16,
2468- oneapi::mkl::bfloat16, float >(
2469- q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2470- batch_size, matrix_info);
2453+ detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float >(
2454+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
24712455 break ;
24722456 }
24732457 case detail::get_type_combination_id (
24742458 library_data_t ::real_bfloat16, library_data_t ::real_bfloat16,
24752459 library_data_t ::real_float, library_data_t ::real_float):
24762460 {
2477- detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float ,
2478- float >(q, a_trans, b_trans, m, n, k, alpha, a, lda,
2479- b, ldb, beta, c, ldc, batch_size, matrix_info);
2461+ detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float , float >(
2462+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
24802463 break ;
24812464 }
24822465#endif
@@ -2488,28 +2471,25 @@ namespace dpct
24882471 dpct::get_value (reinterpret_cast <const std::int32_t *>(alpha), q);
24892472 float beta_float =
24902473 dpct::get_value (reinterpret_cast <const std::int32_t *>(beta), q);
2491- detail::gemm_batch_impl<std::int8_t , std::int8_t , std::int32_t ,
2492- float >(q, a_trans, b_trans, m, n, k, &alpha_float,
2493- a, lda, b, ldb, &beta_float, c, ldc,
2494- batch_size, matrix_info);
2474+ detail::gemm_batch_impl<std::int8_t , std::int8_t , std::int32_t , float >(
2475+ q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc, batch_size,
2476+ matrix_info);
24952477 break ;
24962478 }
24972479 case detail::get_type_combination_id (
24982480 library_data_t ::real_int8, library_data_t ::real_int8,
24992481 library_data_t ::real_float, library_data_t ::real_float):
25002482 {
25012483 detail::gemm_batch_impl<std::int8_t , std::int8_t , float , float >(
2502- q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2503- batch_size, matrix_info);
2484+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
25042485 break ;
25052486 }
25062487 case detail::get_type_combination_id (
25072488 library_data_t ::real_half, library_data_t ::real_half,
25082489 library_data_t ::real_float, library_data_t ::real_float):
25092490 {
25102491 detail::gemm_batch_impl<sycl::half, sycl::half, float , float >(
2511- q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2512- batch_size, matrix_info);
2492+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
25132493 break ;
25142494 }
25152495 case detail::get_type_combination_id (
@@ -2523,8 +2503,7 @@ namespace dpct
25232503 sycl::half alpha_half (alpha_value);
25242504 sycl::half beta_half (beta_value);
25252505 detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
2526- q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc,
2527- batch_size, matrix_info);
2506+ q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc, batch_size, matrix_info);
25282507 break ;
25292508 }
25302509 default :
0 commit comments