@@ -82,6 +82,14 @@ inline std::string get_device_backend_and_type(const sycl::device &device) {
8282    return  device_type.str ();
8383}
8484
/// Parameter bundle for a grouped oneMKL gemm_batch call. detail::gemm_batch_impl
/// fills the fields and passes pointers to them to
/// oneapi::mkl::blas::column_major::gemm_batch with a group count of 1, so each
/// array is addressed by element pointer (field, field + 1, field + 2).
/// \tparam Ts Element type of the stored scaling values.
/// NOTE(review): gemm_batch_impl receives a matrix_info_t<float> * for every Ts,
/// stores alpha via a converting assignment into value_info and then reads it
/// back through reinterpret_cast<Ts *>. It is also instantiated with Ts = double
/// and Ts = sycl::half; confirm those instantiations see the intended values.
85+ template  <typename  Ts> struct  matrix_info_t  {
86+     oneapi::mkl::transpose transpose_info[2 ];  // [0] = a_trans, [1] = b_trans
87+     Ts                     value_info[2 ];  // [0] = alpha; [1] is passed in the beta position
88+     std::int64_t            size_info[3 ];  // passed in the m, n, k positions -- presumably set from m, n, k (assignment not visible here)
89+     std::int64_t            ld_info[3 ];  // passed after a, b and c respectively -- presumably lda, ldb, ldc (assignment not visible here)
90+     std::int64_t            groupsize_info;  // address passed as the final (group size) argument -- presumably batch_size; TODO confirm
91+ };
92+ 
8593namespace  dpct 
8694{
8795    typedef  sycl::queue *queue_ptr;
@@ -1727,26 +1735,13 @@ namespace dpct
17271735        };
17281736
17291737        template  <class  Ta , class  Tb , class  Tc , class  Ts >
1730-         inline  void  gemm_batch_impl (sycl::queue &q, oneapi::mkl::transpose a_trans,
1731-                                     oneapi::mkl::transpose b_trans, int  m, int  n, int  k,
1732-                                     const  void  *alpha, const  void  **a, int  lda,
1733-                                     const  void  **b, int  ldb, const  void  *beta, void  **c,
1734-                                     int  ldc, int  batch_size)
1735-         {
1736-             struct  matrix_info_t 
1737-             {
1738-                 oneapi::mkl::transpose transpose_info[2 ];
1739-                 Ts value_info[2 ];
1740-                 std::int64_t  size_info[3 ];
1741-                 std::int64_t  ld_info[3 ];
1742-                 std::int64_t  groupsize_info;
1743-             };
1744- 
1738+         inline  void  gemm_batch_impl (sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans,
1739+                                     int  m, int  n, int  k, const  void  * alpha, const  void  ** a, int  lda, const  void  ** b,
1740+                                     int  ldb, const  void  * beta, void  ** c, int  ldc, int  batch_size,
1741+                                     matrix_info_t <float > * matrix_info) {
17451742            Ts alpha_value = dpct::get_value (reinterpret_cast <const  Ts *>(alpha), q);
17461743            Ts beta_value = dpct::get_value (reinterpret_cast <const  Ts *>(beta), q);
17471744
1748-             matrix_info_t  *matrix_info =
1749-                 (matrix_info_t  *)std::malloc (sizeof (matrix_info_t ));
17501745            matrix_info->transpose_info [0 ] = a_trans;
17511746            matrix_info->transpose_info [1 ] = b_trans;
17521747            matrix_info->value_info [0 ] = alpha_value;
@@ -1763,23 +1758,18 @@ namespace dpct
17631758            sycl::event e = oneapi::mkl::blas::column_major::gemm_batch (
17641759                oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, matrix_info->transpose_info ,
17651760                matrix_info->transpose_info  + 1 , matrix_info->size_info , matrix_info->size_info  + 1 ,
1766-                 matrix_info->size_info  + 2 , matrix_info-> value_info ,  reinterpret_cast <const  Ta **>(a ),
1767-                 matrix_info-> ld_info ,  reinterpret_cast <const  Tb  **>(b ), matrix_info->ld_info  +  1 ,
1768-                 matrix_info->value_info  + 1 , reinterpret_cast <Tc **>(c),  matrix_info->ld_info  + 2 ,  1 ,
1769-                 &(matrix_info->groupsize_info ));
1761+                 matrix_info->size_info  + 2 , reinterpret_cast <Ts *>(matrix_info-> value_info ),
1762+                 reinterpret_cast <const  Ta  **>(a ), matrix_info->ld_info ,  reinterpret_cast < const  Tb **>(b) ,
1763+                 matrix_info->ld_info  + 1 , reinterpret_cast <Ts *>( matrix_info->value_info  + 1 ) ,
1764+                 reinterpret_cast <Tc **>(c), matrix_info-> ld_info  +  2 ,  1 ,  &(matrix_info->groupsize_info ));
17701765#else 
17711766            sycl::event e = oneapi::mkl::blas::column_major::gemm_batch (
17721767                q, matrix_info->transpose_info , matrix_info->transpose_info  + 1 , matrix_info->size_info ,
1773-                 matrix_info->size_info  + 1 , matrix_info->size_info  + 2 , matrix_info->value_info ,
1768+                 matrix_info->size_info  + 1 , matrix_info->size_info  + 2 , reinterpret_cast <Ts *>( matrix_info->value_info ) ,
17741769                reinterpret_cast <const  Ta **>(a), matrix_info->ld_info , reinterpret_cast <const  Tb **>(b),
1775-                 matrix_info->ld_info  + 1 , matrix_info->value_info  + 1 ,  reinterpret_cast <Tc **>(c ),
1776-                 matrix_info->ld_info  + 2 , 1 , &(matrix_info->groupsize_info ));
1770+                 matrix_info->ld_info  + 1 , reinterpret_cast <Ts *>( matrix_info->value_info  + 1 ),
1771+                 reinterpret_cast <Tc **>(c),  matrix_info->ld_info  + 2 , 1 , &(matrix_info->groupsize_info ));
17771772#endif 
1778- 
1779-             q.submit ([&](sycl::handler &cgh)
1780-                      {
1781-     cgh.depends_on (e);
1782-     cgh.host_task ([=] { std::free (matrix_info); }); });
17831773        }
17841774
17851775        template  <class  Ta , class  Tb , class  Tc , class  Ts >
@@ -2422,25 +2412,11 @@ namespace dpct
24222412    // / \param [in] ldc Leading dimension of C.
24232413    // / \param [in] batch_size Specifies the number of matrix multiply operations to perform.
24242414    // / \param [in] scaling_type Data type of the scaling factors.
2425-     inline  void  gemm_batch (sycl::queue &q, oneapi::mkl::transpose a_trans,
2426-                            oneapi::mkl::transpose b_trans, int  m, int  n, int  k,
2427-                            const  void  *alpha, const  void  *a[],
2428-                            library_data_t  a_type, int  lda, const  void  *b[],
2429-                            library_data_t  b_type, int  ldb, const  void  *beta,
2430-                            void  *c[], library_data_t  c_type, int  ldc,
2431-                            int  batch_size, library_data_t  scaling_type)
2432-     {
2433-         if  (scaling_type == library_data_t ::real_float &&
2434-             c_type == library_data_t ::complex_float)
2435-         {
2436-             scaling_type = library_data_t ::complex_float;
2437-         }
2438-         else  if  (scaling_type == library_data_t ::real_double &&
2439-                  c_type == library_data_t ::complex_double)
2440-         {
2441-             scaling_type = library_data_t ::complex_double;
2442-         }
2443- 
2415+     inline  void  gemm_batch (sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int  m,
2416+                            int  n, int  k, const  void  * alpha, const  void  * a[], library_data_t  a_type, int  lda,
2417+                            const  void  * b[], library_data_t  b_type, int  ldb, const  void  * beta, void  * c[],
2418+                            library_data_t  c_type, int  ldc, int  batch_size, library_data_t  scaling_type,
2419+                            matrix_info_t <float > * matrix_info) {
24442420        std::uint64_t  key =
24452421            detail::get_type_combination_id (a_type, b_type, c_type, scaling_type);
24462422        switch  (key)
@@ -2449,68 +2425,41 @@ namespace dpct
24492425            library_data_t ::real_float, library_data_t ::real_float,
24502426            library_data_t ::real_float, library_data_t ::real_float):
24512427        {
2452-             detail::gemm_batch_impl<float , float , float , float >(
2453-                 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2454-                 batch_size);
2428+             detail::gemm_batch_impl<float , float , float , float >(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb,
2429+                                                                 beta, c, ldc, batch_size, matrix_info);
24552430            break ;
24562431        }
24572432        case  detail::get_type_combination_id (
24582433            library_data_t ::real_double, library_data_t ::real_double,
24592434            library_data_t ::real_double, library_data_t ::real_double):
24602435        {
2461-             detail::gemm_batch_impl<double , double , double , double >(
2462-                 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2463-                 batch_size);
2464-             break ;
2465-         }
2466-         case  detail::get_type_combination_id (
2467-             library_data_t ::complex_float, library_data_t ::complex_float,
2468-             library_data_t ::complex_float, library_data_t ::complex_float):
2469-         {
2470-             detail::gemm_batch_impl<std::complex <float >, std::complex <float >,
2471-                                     std::complex <float >, std::complex <float >>(
2472-                 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2473-                 batch_size);
2474-             break ;
2475-         }
2476-         case  detail::get_type_combination_id (
2477-             library_data_t ::complex_double, library_data_t ::complex_double,
2478-             library_data_t ::complex_double, library_data_t ::complex_double):
2479-         {
2480-             detail::gemm_batch_impl<std::complex <double >, std::complex <double >,
2481-                                     std::complex <double >, std::complex <double >>(
2482-                 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2483-                 batch_size);
2436+             detail::gemm_batch_impl<double , double , double , double >(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb,
2437+                                                                     beta, c, ldc, batch_size, matrix_info);
24842438            break ;
24852439        }
24862440        case  detail::get_type_combination_id (
24872441            library_data_t ::real_half, library_data_t ::real_half,
24882442            library_data_t ::real_half, library_data_t ::real_half):
24892443        {
2490-             detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half,
2491-                                     sycl::half>(q, a_trans, b_trans, m, n, k, alpha,
2492-                                                 a, lda, b, ldb, beta, c, ldc,
2493-                                                 batch_size);
2444+             detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
2445+                 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
24942446            break ;
24952447        }
24962448#ifdef  __INTEL_MKL__
24972449        case  detail::get_type_combination_id (
24982450            library_data_t ::real_bfloat16, library_data_t ::real_bfloat16,
24992451            library_data_t ::real_bfloat16, library_data_t ::real_float):
25002452        {
2501-             detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16,
2502-                                     oneapi::mkl::bfloat16, float >(
2503-                 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2504-                 batch_size);
2453+             detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float >(
2454+                 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
25052455            break ;
25062456        }
25072457        case  detail::get_type_combination_id (
25082458            library_data_t ::real_bfloat16, library_data_t ::real_bfloat16,
25092459            library_data_t ::real_float, library_data_t ::real_float):
25102460        {
2511-             detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float ,
2512-                                     float >(q, a_trans, b_trans, m, n, k, alpha, a, lda,
2513-                                            b, ldb, beta, c, ldc, batch_size);
2461+             detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float , float >(
2462+                 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
25142463            break ;
25152464        }
25162465#endif 
@@ -2522,28 +2471,25 @@ namespace dpct
25222471                dpct::get_value (reinterpret_cast <const  std::int32_t  *>(alpha), q);
25232472            float  beta_float =
25242473                dpct::get_value (reinterpret_cast <const  std::int32_t  *>(beta), q);
2525-             detail::gemm_batch_impl<std::int8_t , std::int8_t , std::int32_t ,
2526-                                     float >(q, a_trans, b_trans, m, n, k, &alpha_float,
2527-                                            a, lda, b, ldb, &beta_float, c, ldc,
2528-                                            batch_size);
2474+             detail::gemm_batch_impl<std::int8_t , std::int8_t , std::int32_t , float >(
2475+                 q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc, batch_size,
2476+                 matrix_info);
25292477            break ;
25302478        }
25312479        case  detail::get_type_combination_id (
25322480            library_data_t ::real_int8, library_data_t ::real_int8,
25332481            library_data_t ::real_float, library_data_t ::real_float):
25342482        {
25352483            detail::gemm_batch_impl<std::int8_t , std::int8_t , float , float >(
2536-                 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2537-                 batch_size);
2484+                 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
25382485            break ;
25392486        }
25402487        case  detail::get_type_combination_id (
25412488            library_data_t ::real_half, library_data_t ::real_half,
25422489            library_data_t ::real_float, library_data_t ::real_float):
25432490        {
25442491            detail::gemm_batch_impl<sycl::half, sycl::half, float , float >(
2545-                 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2546-                 batch_size);
2492+                 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
25472493            break ;
25482494        }
25492495        case  detail::get_type_combination_id (
@@ -2557,8 +2503,7 @@ namespace dpct
25572503            sycl::half alpha_half (alpha_value);
25582504            sycl::half beta_half (beta_value);
25592505            detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
2560-                 q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc,
2561-                 batch_size);
2506+                 q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc, batch_size, matrix_info);
25622507            break ;
25632508        }
25642509        default :
0 commit comments