@@ -82,6 +82,14 @@ inline std::string get_device_backend_and_type(const sycl::device &device) {
8282 return device_type.str ();
8383}
8484
85+ template <typename Ts> struct matrix_info_t {
86+ oneapi::mkl::transpose transpose_info[2 ];
87+ Ts value_info[2 ];
88+ std::int64_t size_info[3 ];
89+ std::int64_t ld_info[3 ];
90+ std::int64_t groupsize_info;
91+ };
92+
8593namespace dpct
8694{
8795 typedef sycl::queue *queue_ptr;
@@ -1737,26 +1745,13 @@ namespace dpct
17371745 };
17381746
17391747 template <class Ta , class Tb , class Tc , class Ts >
1740- inline void gemm_batch_impl (sycl::queue &q, oneapi::mkl::transpose a_trans,
1741- oneapi::mkl::transpose b_trans, int m, int n, int k,
1742- const void *alpha, const void **a, int lda,
1743- const void **b, int ldb, const void *beta, void **c,
1744- int ldc, int batch_size)
1745- {
1746- struct matrix_info_t
1747- {
1748- oneapi::mkl::transpose transpose_info[2 ];
1749- Ts value_info[2 ];
1750- std::int64_t size_info[3 ];
1751- std::int64_t ld_info[3 ];
1752- std::int64_t groupsize_info;
1753- };
1754-
1748+ inline void gemm_batch_impl (sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans,
1749+ int m, int n, int k, const void * alpha, const void ** a, int lda, const void ** b,
1750+ int ldb, const void * beta, void ** c, int ldc, int batch_size,
1751+ matrix_info_t <float > * matrix_info) {
17551752 Ts alpha_value = dpct::get_value (reinterpret_cast <const Ts *>(alpha), q);
17561753 Ts beta_value = dpct::get_value (reinterpret_cast <const Ts *>(beta), q);
17571754
1758- matrix_info_t *matrix_info =
1759- (matrix_info_t *)std::malloc (sizeof (matrix_info_t ));
17601755 matrix_info->transpose_info [0 ] = a_trans;
17611756 matrix_info->transpose_info [1 ] = b_trans;
17621757 matrix_info->value_info [0 ] = alpha_value;
@@ -1773,23 +1768,18 @@ namespace dpct
17731768 sycl::event e = oneapi::mkl::blas::column_major::gemm_batch (
17741769 oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, matrix_info->transpose_info ,
17751770 matrix_info->transpose_info + 1 , matrix_info->size_info , matrix_info->size_info + 1 ,
1776- matrix_info->size_info + 2 , matrix_info-> value_info , reinterpret_cast <const Ta **>(a ),
1777- matrix_info-> ld_info , reinterpret_cast <const Tb **>(b ), matrix_info->ld_info + 1 ,
1778- matrix_info->value_info + 1 , reinterpret_cast <Tc **>(c), matrix_info->ld_info + 2 , 1 ,
1779- &(matrix_info->groupsize_info ));
1771+ matrix_info->size_info + 2 , reinterpret_cast <Ts *>(matrix_info-> value_info ),
1772+ reinterpret_cast <const Ta **>(a ), matrix_info->ld_info , reinterpret_cast < const Tb **>(b) ,
1773+ matrix_info->ld_info + 1 , reinterpret_cast <Ts *>( matrix_info->value_info + 1 ) ,
1774+ reinterpret_cast <Tc **>(c), matrix_info-> ld_info + 2 , 1 , &(matrix_info->groupsize_info ));
17801775#else
17811776 sycl::event e = oneapi::mkl::blas::column_major::gemm_batch (
17821777 q, matrix_info->transpose_info , matrix_info->transpose_info + 1 , matrix_info->size_info ,
1783- matrix_info->size_info + 1 , matrix_info->size_info + 2 , matrix_info->value_info ,
1778+ matrix_info->size_info + 1 , matrix_info->size_info + 2 , reinterpret_cast <Ts *>( matrix_info->value_info ) ,
17841779 reinterpret_cast <const Ta **>(a), matrix_info->ld_info , reinterpret_cast <const Tb **>(b),
1785- matrix_info->ld_info + 1 , matrix_info->value_info + 1 , reinterpret_cast <Tc **>(c ),
1786- matrix_info->ld_info + 2 , 1 , &(matrix_info->groupsize_info ));
1780+ matrix_info->ld_info + 1 , reinterpret_cast <Ts *>( matrix_info->value_info + 1 ),
1781+ reinterpret_cast <Tc **>(c), matrix_info->ld_info + 2 , 1 , &(matrix_info->groupsize_info ));
17871782#endif
1788-
1789- q.submit ([&](sycl::handler &cgh)
1790- {
1791- cgh.depends_on (e);
1792- cgh.host_task ([=] { std::free (matrix_info); }); });
17931783 }
17941784
17951785 template <class Ta , class Tb , class Tc , class Ts >
@@ -2427,25 +2417,11 @@ namespace dpct
24272417 // / \param [in] ldc Leading dimension of C.
24282418 // / \param [in] batch_size Specifies the number of matrix multiply operations to perform.
24292419 // / \param [in] scaling_type Data type of the scaling factors.
2430- inline void gemm_batch (sycl::queue &q, oneapi::mkl::transpose a_trans,
2431- oneapi::mkl::transpose b_trans, int m, int n, int k,
2432- const void *alpha, const void *a[],
2433- library_data_t a_type, int lda, const void *b[],
2434- library_data_t b_type, int ldb, const void *beta,
2435- void *c[], library_data_t c_type, int ldc,
2436- int batch_size, library_data_t scaling_type)
2437- {
2438- if (scaling_type == library_data_t ::real_float &&
2439- c_type == library_data_t ::complex_float)
2440- {
2441- scaling_type = library_data_t ::complex_float;
2442- }
2443- else if (scaling_type == library_data_t ::real_double &&
2444- c_type == library_data_t ::complex_double)
2445- {
2446- scaling_type = library_data_t ::complex_double;
2447- }
2448-
2420+ inline void gemm_batch (sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m,
2421+ int n, int k, const void * alpha, const void * a[], library_data_t a_type, int lda,
2422+ const void * b[], library_data_t b_type, int ldb, const void * beta, void * c[],
2423+ library_data_t c_type, int ldc, int batch_size, library_data_t scaling_type,
2424+ matrix_info_t <float > * matrix_info) {
24492425 std::uint64_t key =
24502426 detail::get_type_combination_id (a_type, b_type, c_type, scaling_type);
24512427 switch (key)
@@ -2454,68 +2430,41 @@ namespace dpct
24542430 library_data_t ::real_float, library_data_t ::real_float,
24552431 library_data_t ::real_float, library_data_t ::real_float):
24562432 {
2457- detail::gemm_batch_impl<float , float , float , float >(
2458- q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2459- batch_size);
2433+ detail::gemm_batch_impl<float , float , float , float >(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb,
2434+ beta, c, ldc, batch_size, matrix_info);
24602435 break ;
24612436 }
24622437 case detail::get_type_combination_id (
24632438 library_data_t ::real_double, library_data_t ::real_double,
24642439 library_data_t ::real_double, library_data_t ::real_double):
24652440 {
2466- detail::gemm_batch_impl<double , double , double , double >(
2467- q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2468- batch_size);
2469- break ;
2470- }
2471- case detail::get_type_combination_id (
2472- library_data_t ::complex_float, library_data_t ::complex_float,
2473- library_data_t ::complex_float, library_data_t ::complex_float):
2474- {
2475- detail::gemm_batch_impl<std::complex <float >, std::complex <float >,
2476- std::complex <float >, std::complex <float >>(
2477- q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2478- batch_size);
2479- break ;
2480- }
2481- case detail::get_type_combination_id (
2482- library_data_t ::complex_double, library_data_t ::complex_double,
2483- library_data_t ::complex_double, library_data_t ::complex_double):
2484- {
2485- detail::gemm_batch_impl<std::complex <double >, std::complex <double >,
2486- std::complex <double >, std::complex <double >>(
2487- q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2488- batch_size);
2441+ detail::gemm_batch_impl<double , double , double , double >(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb,
2442+ beta, c, ldc, batch_size, matrix_info);
24892443 break ;
24902444 }
24912445 case detail::get_type_combination_id (
24922446 library_data_t ::real_half, library_data_t ::real_half,
24932447 library_data_t ::real_half, library_data_t ::real_half):
24942448 {
2495- detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half,
2496- sycl::half>(q, a_trans, b_trans, m, n, k, alpha,
2497- a, lda, b, ldb, beta, c, ldc,
2498- batch_size);
2449+ detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
2450+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
24992451 break ;
25002452 }
25012453#ifdef __INTEL_MKL__
25022454 case detail::get_type_combination_id (
25032455 library_data_t ::real_bfloat16, library_data_t ::real_bfloat16,
25042456 library_data_t ::real_bfloat16, library_data_t ::real_float):
25052457 {
2506- detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16,
2507- oneapi::mkl::bfloat16, float >(
2508- q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2509- batch_size);
2458+ detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float >(
2459+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
25102460 break ;
25112461 }
25122462 case detail::get_type_combination_id (
25132463 library_data_t ::real_bfloat16, library_data_t ::real_bfloat16,
25142464 library_data_t ::real_float, library_data_t ::real_float):
25152465 {
2516- detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float ,
2517- float >(q, a_trans, b_trans, m, n, k, alpha, a, lda,
2518- b, ldb, beta, c, ldc, batch_size);
2466+ detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float , float >(
2467+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
25192468 break ;
25202469 }
25212470 case detail::get_type_combination_id (
@@ -2526,28 +2475,25 @@ namespace dpct
25262475 dpct::get_value (reinterpret_cast <const std::int32_t *>(alpha), q);
25272476 float beta_float =
25282477 dpct::get_value (reinterpret_cast <const std::int32_t *>(beta), q);
2529- detail::gemm_batch_impl<std::int8_t , std::int8_t , std::int32_t ,
2530- float >(q, a_trans, b_trans, m, n, k, &alpha_float,
2531- a, lda, b, ldb, &beta_float, c, ldc,
2532- batch_size);
2478+ detail::gemm_batch_impl<std::int8_t , std::int8_t , std::int32_t , float >(
2479+ q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc, batch_size,
2480+ matrix_info);
25332481 break ;
25342482 }
25352483 case detail::get_type_combination_id (
25362484 library_data_t ::real_int8, library_data_t ::real_int8,
25372485 library_data_t ::real_float, library_data_t ::real_float):
25382486 {
25392487 detail::gemm_batch_impl<std::int8_t , std::int8_t , float , float >(
2540- q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2541- batch_size);
2488+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
25422489 break ;
25432490 }
25442491 case detail::get_type_combination_id (
25452492 library_data_t ::real_half, library_data_t ::real_half,
25462493 library_data_t ::real_float, library_data_t ::real_float):
25472494 {
25482495 detail::gemm_batch_impl<sycl::half, sycl::half, float , float >(
2549- q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2550- batch_size);
2496+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
25512497 break ;
25522498 }
25532499#endif
@@ -2562,8 +2508,7 @@ namespace dpct
25622508 sycl::half alpha_half (alpha_value);
25632509 sycl::half beta_half (beta_value);
25642510 detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
2565- q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc,
2566- batch_size);
2511+ q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc, batch_size, matrix_info);
25672512 break ;
25682513 }
25692514 default :
0 commit comments