@@ -20,23 +20,27 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
 bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const int src1_ncols, bool mul_mat_id);
 
 template <typename T, int rows_per_block, int cols_per_block, int nwarps, bool has_ids>
-static __device__ void mul_mat_f_impl(
+__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
+static __global__ void mul_mat_f(
         const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
         const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst,
         const int stride_col_id, const int stride_row_id,
         const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
         const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
 #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
-#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
-    typedef tile<32, 8, T>     tile_A;
-    typedef tile< 8, 8, T>     tile_B;
-    typedef tile<32, 8, float> tile_C;
-#else
-    // In principle also possible to use tiles with I == 32, the performance difference is ~1%.
-    typedef tile<16, 8, T>     tile_A;
-    typedef tile< 8, 8, T>     tile_B;
-    typedef tile<16, 8, float> tile_C;
-#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+    constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported();
+    constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported();
+
+    if (!I_16_supported && !I_32_supported) {
+        NO_DEVICE_CODE;
+        return;
+    }
+
+    constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster.
+
+    typedef tile<I_preferred, 8, T>     tile_A;
+    typedef tile<8,           8, T>     tile_B;
+    typedef tile<I_preferred, 8, float> tile_C;
 
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     constexpr int tile_k_padded = warp_size + 4;
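
The pattern introduced in this hunk is: query tile support at compile time, bail out via NO_DEVICE_CODE if neither shape is usable, otherwise prefer I == 16. Below is a host-compilable sketch of just that selection logic; `tile_supported<I>()` is a hypothetical stand-in for `tile<I, 8, T>::supported()` and is not part of the ggml API.

```cpp
#include <cstdio>

// Hypothetical stand-in for tile<I, 8, T>::supported(); for illustration we
// pretend that only the 16-row tile shape has an MMA mapping on this arch.
template <int I>
constexpr bool tile_supported() {
    return I == 16;
}

template <typename T>
void kernel_body_sketch() {
    constexpr bool I_16_supported = tile_supported<16>();
    constexpr bool I_32_supported = tile_supported<32>();

    if (!I_16_supported && !I_32_supported) {
        // The real kernel uses NO_DEVICE_CODE here, so architectures that
        // support neither tile shape fail loudly instead of returning silently.
        return;
    }

    // Prefer I == 16 when both shapes work (per the comment in the hunk,
    // both work for Turing MMA but 16 is ~1% faster).
    constexpr int I_preferred = I_16_supported ? 16 : 32;
    std::printf("I_preferred = %d\n", I_preferred);
}

int main() {
    kernel_body_sketch<float>();
    return 0;
}
```

The point of the change is that the tile choice now follows the tile types' own capability reporting instead of a hard-coded `__CUDA_ARCH__` comparison against Volta.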
@@ -238,43 +242,10 @@ static __device__ void mul_mat_f_impl(
 #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
 }
 
-template <typename T, int rows_per_block, int cols_per_block, int nwarps, bool has_ids>
-__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
-static __global__ void mul_mat_f(
-        const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
-        const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst,
-        const int stride_col_id, const int stride_row_id,
-        const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
-        const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
-    if constexpr (std::is_same_v<T, half2>) {
-#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
-        mul_mat_f_impl<T, rows_per_block, cols_per_block, nwarps, has_ids>(
-            x, y, ids, dst, ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y,
-            stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y,
-            stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
-#else
-        NO_DEVICE_CODE;
-#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
-    } else if constexpr (std::is_same_v<T, float> || std::is_same_v<T, nv_bfloat162>) {
-#ifdef AMPERE_MMA_AVAILABLE
-        mul_mat_f_impl<T, rows_per_block, cols_per_block, nwarps, has_ids>(
-            x, y, ids, dst, ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y,
-            stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y,
-            stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
-#else
-        NO_DEVICE_CODE;
-#endif // AMPERE_MMA_AVAILABLE
-    } else {
-        static_assert(std::is_same_v<T, void>, "bad type");
-    }
-    GGML_UNUSED_VARS(x, y, ids, dst, ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y,
-        stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y,
-        stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
-}
-
 // This kernel is for larger batch sizes of mul_mat_id
 template <typename T, int rows_per_block, int cols_per_block, int nwarps>
-static __device__ void mul_mat_f_ids_impl(
+__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
+static __global__ void mul_mat_f_ids(
         const T * __restrict__ x, const float * __restrict__ y,
         const int32_t * __restrict__ ids_src_compact, const int32_t * __restrict__ ids_dst_compact,
         const int32_t * __restrict__ expert_bounds, float * __restrict__ dst,
@@ -283,16 +254,19 @@ static __device__ void mul_mat_f_ids_impl(
         const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
         const uint3 sis1_fd, const uint3 nch_fd) {
 #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
-#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
-    typedef tile<32, 8, T>     tile_A;
-    typedef tile< 8, 8, T>     tile_B;
-    typedef tile<32, 8, float> tile_C;
-#else
-    // In principle also possible to use tiles with I == 32, the performance difference is ~1%.
-    typedef tile<16, 8, T>     tile_A;
-    typedef tile< 8, 8, T>     tile_B;
-    typedef tile<16, 8, float> tile_C;
-#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+    constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported();
+    constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported();
+
+    if (!I_16_supported && !I_32_supported) {
+        NO_DEVICE_CODE;
+        return;
+    }
+
+    constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster.
+
+    typedef tile<I_preferred, 8, T>     tile_A;
+    typedef tile<8,           8, T>     tile_B;
+    typedef tile<I_preferred, 8, float> tile_C;
 
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     constexpr int tile_k_padded = warp_size + 4;
@@ -521,46 +495,6 @@ static __device__ void mul_mat_f_ids_impl(
 #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
 }
 
-template <typename T, int rows_per_block, int cols_per_block, int nwarps>
-__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
-static __global__ void mul_mat_f_ids(
-        const T * __restrict__ x, const float * __restrict__ y,
-        const int32_t * __restrict__ ids_src_compact, const int32_t * __restrict__ ids_dst_compact,
-        const int32_t * __restrict__ expert_bounds, float * __restrict__ dst,
-        const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst,
-        const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
-        const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
-        const uint3 sis1_fd, const uint3 nch_fd) {
-    if constexpr (std::is_same_v<T, half2>) {
-#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
-        mul_mat_f_ids_impl<T, rows_per_block, cols_per_block, nwarps>(
-            x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst,
-            ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
-            channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-            sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd);
-#else
-        NO_DEVICE_CODE;
-#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
-    } else if constexpr (std::is_same_v<T, float> || std::is_same_v<T, nv_bfloat162>) {
-#ifdef AMPERE_MMA_AVAILABLE
-        mul_mat_f_ids_impl<T, rows_per_block, cols_per_block, nwarps>(
-            x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst,
-            ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
-            channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-            sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd);
-#else
-        NO_DEVICE_CODE;
-#endif // AMPERE_MMA_AVAILABLE
-    } else {
-        static_assert(std::is_same_v<T, void>, "bad type");
-    }
-    GGML_UNUSED_VARS(
-        x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst,
-        ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
-        channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-        sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd);
-}
-
 template <typename T, int cols_per_block, int nwarps>
 static inline void mul_mat_f_switch_ids(
         const T * x, const float * y, const int32_t * ids, float * dst,
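
Both wrapper kernels deleted above (`mul_mat_f` in the earlier hunk and `mul_mat_f_ids` here) implemented the same dispatch: branch on the element type with `if constexpr`, call the `_impl` function when the matching MMA feature macro was defined, and otherwise emit NO_DEVICE_CODE. Once the kernels pick their tiles via `tile<...>::supported()` themselves, that layer guards nothing, which is why it can be dropped. A condensed, host-compilable sketch of the retired pattern (plain functions, a `_SKETCH` macro, and printf stand in for the real kernels, feature macros, and NO_DEVICE_CODE):

```cpp
#include <cstdio>
#include <type_traits>

// Simplified feature flag: only the "Turing" path is enabled in this sketch.
#define TURING_MMA_AVAILABLE_SKETCH 1

template <typename T>
void impl() { std::printf("dispatched to impl\n"); }

template <typename T>
void wrapper() {
    if constexpr (std::is_same_v<T, short>) {          // stands in for half2
#if defined(TURING_MMA_AVAILABLE_SKETCH)
        impl<T>();
#else
        std::printf("NO_DEVICE_CODE\n");
#endif
    } else if constexpr (std::is_same_v<T, float>) {   // stands in for float / nv_bfloat162
#if defined(AMPERE_MMA_AVAILABLE_SKETCH)
        impl<T>();
#else
        std::printf("NO_DEVICE_CODE\n");
#endif
    } else {
        static_assert(std::is_same_v<T, void>, "bad type");
    }
}

int main() {
    wrapper<short>();  // feature flag defined above, so this reaches impl
    wrapper<float>();  // AMPERE_MMA_AVAILABLE_SKETCH is not defined, so it falls back
    return 0;
}
```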
@@ -618,7 +552,7 @@ void mul_mat_f_cuda(
         const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
         const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
         cudaStream_t stream, const mmf_ids_data * ids_data) {
-    typedef tile<32, 8, T>     tile_A_16;
+    typedef tile<16, 8, T>     tile_A_16;
     typedef tile<32, 8, T>     tile_A_32;
     typedef tile< 8, 8, T>     tile_B;
 
@@ -630,7 +564,8 @@ void mul_mat_f_cuda(
     const int64_t channel_ratio = nchannels_dst / nchannels_x;
     const int64_t sample_ratio  = nsamples_dst  / nsamples_x;
 
-    const int device = ggml_cuda_get_device();
+    const int device    = ggml_cuda_get_device();
+    const int cc        = ggml_cuda_info().devices[device].cc;
     const int warp_size = ggml_cuda_info().devices[device].warp_size;
 
     int64_t nwarps_best     = 1;
@@ -645,7 +580,7 @@ void mul_mat_f_cuda(
     }
 
     constexpr int rows_per_block = MMF_ROWS_PER_BLOCK;
-    const int nbytes_shared_iter = nwarps_best * (volta_mma_available ? tile_A_32::I : tile_A_16::I) * (warp_size + 4) * 4;
+    const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + 4) * 4;
     const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4;
     const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
     const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0;
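
Besides turning the Volta check into an actual call, `volta_mma_available(cc)`, using the `cc` queried in the previous hunk, this is where the shared-memory budget is decided, and the A-tile height is the term that differs between the Volta and non-Volta paths. A standalone calculation with assumed launch parameters (nwarps_best = 4, warp_size = 32, rows_per_block = 32 and cols_per_block = 16 are illustrative values, not ones read out of the launcher):

```cpp
#include <algorithm>
#include <cstdio>

int main() {
    // Illustrative values only; the real launcher picks them per device and per op.
    const int  nwarps_best    = 4;
    const int  warp_size      = 32;
    const int  rows_per_block = 32;   // stands in for MMF_ROWS_PER_BLOCK
    const int  cols_per_block = 16;
    const int  tile_B_I       = 8;
    const bool volta          = false;

    const int tile_A_I = volta ? 32 : 16;

    // GGML_PAD: round cols_per_block up to a multiple of tile_B_I.
    const int cols_padded = (cols_per_block + tile_B_I - 1) / tile_B_I * tile_B_I;

    // Per-iteration staging: nwarps A-tiles, each row padded to warp_size + 4
    // elements, 4 bytes per element (half2, float and nv_bfloat162 are all 4 bytes).
    const int nbytes_shared_iter    = nwarps_best * tile_A_I * (warp_size + 4) * 4;
    // Buffer used when the per-warp partial results are combined.
    const int nbytes_shared_combine = cols_padded * (nwarps_best*rows_per_block + 4) * 4;
    const int nbytes_shared         = std::max(nbytes_shared_iter, nbytes_shared_combine);

    // With these numbers: iter = 4*16*36*4 = 9216, combine = 16*(4*32+4)*4 = 8448.
    std::printf("iter=%d combine=%d shared=%d bytes\n",
                nbytes_shared_iter, nbytes_shared_combine, nbytes_shared);
    return 0;
}
```

On the Volta path the iteration term doubles to 18432 bytes under the same assumptions, since tile_A_32::I is twice tile_A_16::I.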