@@ -54,7 +54,8 @@ void sgemm_direct(BLASLONG M, BLASLONG N, BLASLONG K,
5454
5555int sgemm_direct_performant (BLASLONG M , BLASLONG N , BLASLONG K );
5656
57-
57+ int shgemm_beta (BLASLONG , BLASLONG , BLASLONG , float ,
58+ hfloat16 * , BLASLONG , hfloat16 * , BLASLONG , float * , BLASLONG );
5859int sbgemm_beta (BLASLONG , BLASLONG , BLASLONG , float ,
5960 bfloat16 * , BLASLONG , bfloat16 * , BLASLONG , float * , BLASLONG );
6061int sgemm_beta (BLASLONG , BLASLONG , BLASLONG , float ,
@@ -78,6 +79,10 @@ int xgemm_beta(BLASLONG, BLASLONG, BLASLONG, xdouble *,
7879 xdouble * , BLASLONG , xdouble * , BLASLONG , xdouble * , BLASLONG );
7980#endif
8081
82+ int shgemm_incopy (BLASLONG m , BLASLONG n , hfloat16 * a , BLASLONG lda , hfloat16 * b );
83+ int shgemm_itcopy (BLASLONG m , BLASLONG n , hfloat16 * a , BLASLONG lda , hfloat16 * b );
84+ int shgemm_oncopy (BLASLONG m , BLASLONG n , hfloat16 * a , BLASLONG lda , hfloat16 * b );
85+ int shgemm_otcopy (BLASLONG m , BLASLONG n , hfloat16 * a , BLASLONG lda , hfloat16 * b );
8186int sbgemm_incopy (BLASLONG m , BLASLONG n , bfloat16 * a , BLASLONG lda , bfloat16 * b );
8287int sbgemm_itcopy (BLASLONG m , BLASLONG n , bfloat16 * a , BLASLONG lda , bfloat16 * b );
8388int sbgemm_oncopy (BLASLONG m , BLASLONG n , bfloat16 * a , BLASLONG lda , bfloat16 * b );
@@ -505,6 +510,7 @@ int xher2k_kernel_UC(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdoubl
505510int xher2k_kernel_LN (BLASLONG m , BLASLONG n , BLASLONG k , xdouble alpha_r , xdouble alpha_i , xdouble * a , xdouble * b , xdouble * c , BLASLONG ldc , BLASLONG offset , int flag );
506511int xher2k_kernel_LC (BLASLONG m , BLASLONG n , BLASLONG k , xdouble alpha_r , xdouble alpha_i , xdouble * a , xdouble * b , xdouble * c , BLASLONG ldc , BLASLONG offset , int flag );
507512
513+ int shgemm_kernel (BLASLONG , BLASLONG , BLASLONG , float , hfloat16 * , hfloat16 * , float * , BLASLONG );
508514int sbgemm_kernel (BLASLONG , BLASLONG , BLASLONG , float , bfloat16 * , bfloat16 * , float * , BLASLONG );
509515int sgemm_kernel (BLASLONG , BLASLONG , BLASLONG , float , float * , float * , float * , BLASLONG );
510516int dgemm_kernel (BLASLONG , BLASLONG , BLASLONG , double , double * , double * , double * , BLASLONG );
@@ -657,6 +663,11 @@ int cgemm3m_kernel(BLASLONG, BLASLONG, BLASLONG, float, float, float *, float
657663int zgemm3m_kernel (BLASLONG , BLASLONG , BLASLONG , double , double , double * , double * , double * , BLASLONG );
658664int xgemm3m_kernel (BLASLONG , BLASLONG , BLASLONG , xdouble , xdouble , xdouble * , xdouble * , xdouble * , BLASLONG );
659665
666+ int shgemm_nn (blas_arg_t * , BLASLONG * , BLASLONG * , hfloat16 * , hfloat16 * , BLASLONG );
667+ int shgemm_nt (blas_arg_t * , BLASLONG * , BLASLONG * , hfloat16 * , hfloat16 * , BLASLONG );
668+ int shgemm_tn (blas_arg_t * , BLASLONG * , BLASLONG * , hfloat16 * , hfloat16 * , BLASLONG );
669+ int shgemm_tt (blas_arg_t * , BLASLONG * , BLASLONG * , hfloat16 * , hfloat16 * , BLASLONG );
670+
660671int sbgemm_nn (blas_arg_t * , BLASLONG * , BLASLONG * , bfloat16 * , bfloat16 * , BLASLONG );
661672int sbgemm_nt (blas_arg_t * , BLASLONG * , BLASLONG * , bfloat16 * , bfloat16 * , BLASLONG );
662673int sbgemm_tn (blas_arg_t * , BLASLONG * , BLASLONG * , bfloat16 * , bfloat16 * , BLASLONG );
@@ -754,6 +765,11 @@ int xgemm_cr(blas_arg_t *, BLASLONG *, BLASLONG *, xdouble *, xdouble *, BLASLON
754765int xgemm_cc (blas_arg_t * , BLASLONG * , BLASLONG * , xdouble * , xdouble * , BLASLONG );
755766#endif
756767
768+ int shgemm_thread_nn (blas_arg_t * , BLASLONG * , BLASLONG * , hfloat16 * , hfloat16 * , BLASLONG );
769+ int shgemm_thread_nt (blas_arg_t * , BLASLONG * , BLASLONG * , hfloat16 * , hfloat16 * , BLASLONG );
770+ int shgemm_thread_tn (blas_arg_t * , BLASLONG * , BLASLONG * , hfloat16 * , hfloat16 * , BLASLONG );
771+ int shgemm_thread_tt (blas_arg_t * , BLASLONG * , BLASLONG * , hfloat16 * , hfloat16 * , BLASLONG );
772+
757773int sbgemm_thread_nn (blas_arg_t * , BLASLONG * , BLASLONG * , bfloat16 * , bfloat16 * , BLASLONG );
758774int sbgemm_thread_nt (blas_arg_t * , BLASLONG * , BLASLONG * , bfloat16 * , bfloat16 * , BLASLONG );
759775int sbgemm_thread_tn (blas_arg_t * , BLASLONG * , BLASLONG * , bfloat16 * , bfloat16 * , BLASLONG );
@@ -1944,6 +1960,7 @@ int dgemm_batch_thread(blas_arg_t * queue, BLASLONG nums);
19441960int cgemm_batch_thread (blas_arg_t * queue , BLASLONG nums );
19451961int zgemm_batch_thread (blas_arg_t * queue , BLASLONG nums );
19461962int sbgemm_batch_thread (blas_arg_t * queue , BLASLONG nums );
1963+ // int shgemm_batch_thread(blas_arg_t * queue, BLASLONG nums);
19471964
19481965#ifdef __CUDACC__
19491966}
0 commit comments