@@ -253,6 +253,8 @@ static int mmq_get_granularity_host(ggml_type type, const int mmq_x, const int c
253253 case GGML_TYPE_IQ2_XS: // vec_dot_q8_0_16_q8_1_mma
254254 case GGML_TYPE_IQ2_S: // vec_dot_q8_0_16_q8_1_mma
255255 return mmq_x >= 192 ? 64 : 32 ;
256+ default :
257+ return 0 ;
256258 }
257259 } else if (new_mma_available (cc) && mmq_x >= 48 ) {
258260 return 16 ;
@@ -285,6 +287,8 @@ static constexpr __device__ int mmq_get_granularity_device(ggml_type type, const
285287 case GGML_TYPE_IQ2_XS: // vec_dot_q8_0_16_q8_1_mma
286288 case GGML_TYPE_IQ2_S: // vec_dot_q8_0_16_q8_1_mma
287289 return mmq_x >= 192 ? 64 : 32 ;
290+ default :
291+ return 0 ;
288292 }
289293}
290294#elif defined(NEW_MMA_AVAILABLE)
@@ -323,6 +327,8 @@ static int get_mmq_nwarps_host(ggml_type type, const int cc) {
323327 case GGML_TYPE_IQ2_XS: // vec_dot_q8_0_16_q8_1_mma
324328 case GGML_TYPE_IQ2_S: // vec_dot_q8_0_16_q8_1_mma
325329 return 4 ;
330+ default :
331+ return 0 ;
326332 }
327333 } else {
328334 return 8 ;
@@ -355,6 +361,8 @@ static constexpr __device__ int get_mmq_nwarps_device(ggml_type type) {
355361 case GGML_TYPE_IQ2_XS: // vec_dot_q8_0_16_q8_1_mma
356362 case GGML_TYPE_IQ2_S: // vec_dot_q8_0_16_q8_1_mma
357363 return 4 ;
364+ default :
365+ return 0 ;
358366 }
359367}
360368#else
@@ -3123,16 +3131,16 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
31233131
31243132// The mul_mat_q kernel implements "stream-k" work partitioning as described in https://arxiv.org/abs/2301.03598
31253133
3126- template <ggml_type type, int mmq_x, int warp_size, bool need_check>
3134+ template <ggml_type type, int mmq_x, bool need_check>
31273135#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
3128- #if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA3) || defined(CDNA2) || defined(CDNA1 ) || defined(GCN)
3129- __launch_bounds__ (warp_size *get_mmq_nwarps_device (type), 2)
3136+ #if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA ) || defined(GCN)
3137+ __launch_bounds__ (ggml_cuda_get_physical_warp_size() *get_mmq_nwarps_device(type), 2)
31303138#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
31313139#else
31323140#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
3133- __launch_bounds__ (warp_size *get_mmq_nwarps_device (type), 1)
3141+ __launch_bounds__ (ggml_cuda_get_physical_warp_size() *get_mmq_nwarps_device(type), 1)
31343142#else
3135- __launch_bounds__ (warp_size *get_mmq_nwarps_device (type), 2)
3143+ __launch_bounds__ (ggml_cuda_get_physical_warp_size() *get_mmq_nwarps_device(type), 2)
31363144#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
31373145#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
31383146static __global__ void mul_mat_q (
@@ -3149,6 +3157,7 @@ static __global__ void mul_mat_q(
31493157 }
31503158
31513159 constexpr int nwarps = get_mmq_nwarps_device (type);
3160+ constexpr int warp_size = ggml_cuda_get_physical_warp_size ();
31523161
31533162 constexpr int qk = ggml_cuda_type_traits<type>::qk;
31543163 constexpr int mmq_y = get_mmq_y_device ();
@@ -3373,7 +3382,7 @@ static __global__ void mul_mat_q(
33733382}
33743383
33753384
3376- template <ggml_type type, int mmq_x, int warp_size, bool need_check>
3385+ template <ggml_type type, int mmq_x, bool need_check>
33773386static __global__ void mul_mat_q_stream_k_fixup (
33783387 const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile,
33793388 const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_col_dst,
@@ -3384,6 +3393,7 @@ static __global__ void mul_mat_q_stream_k_fixup(
33843393 const int64_t blocks_per_ne00 = ncols_x / qk;
33853394
33863395 constexpr int nwarps = get_mmq_nwarps_device (type);
3396+ constexpr int warp_size = ggml_cuda_get_physical_warp_size ();
33873397
33883398 float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0 .0f };
33893399
@@ -3531,8 +3541,8 @@ struct mmq_args {
35313541 bool use_stream_k;
35323542};
35333543
3534- template <ggml_type type, int warp_size >
3535- static size_t mmq_get_nbytes_shared (const int mmq_x, const int mmq_y, const int cc, const int nwarps) {
3544+ template <ggml_type type>
3545+ static size_t mmq_get_nbytes_shared (const int mmq_x, const int mmq_y, const int cc, const int warp_size, const int nwarps) {
35363546 const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes (type, mmq_y);
35373547 const int mmq_tile_x_k = mmq_get_mma_tile_x_k (type);
35383548 const size_t nbs_ids = mmq_x*sizeof (int );
@@ -3546,19 +3556,19 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
35463556 const int id = ggml_cuda_get_device ();
35473557 const int cc = ggml_cuda_info ().devices [id].cc ;
35483558 const int nsm = ggml_cuda_info ().devices [id].nsm ;
3549- constexpr int warp_size = ggml_cuda_get_physical_warp_size () ;
3559+ const int warp_size = ggml_cuda_info (). devices [id]. warp_size ;
35503560 const int nwarps = get_mmq_nwarps_host (type, cc);
35513561 const int mmq_y = get_mmq_y_host (cc);
35523562
35533563 const dim3 block_dims (warp_size, nwarps, 1 );
35543564
3555- const int nbytes_shared = mmq_get_nbytes_shared<type, warp_size >(mmq_x, mmq_y, cc, nwarps);
3565+ const int nbytes_shared = mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc, warp_size , nwarps);
35563566
35573567#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
35583568 static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false };
35593569 if (!shared_memory_limit_raised[id]) {
3560- CUDA_CHECK (cudaFuncSetAttribute (mul_mat_q<type, mmq_x, warp_size, false >, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared));
3561- CUDA_CHECK (cudaFuncSetAttribute (mul_mat_q<type, mmq_x, warp_size, true >, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared));
3570+ CUDA_CHECK (cudaFuncSetAttribute (mul_mat_q<type, mmq_x, false >, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared));
3571+ CUDA_CHECK (cudaFuncSetAttribute (mul_mat_q<type, mmq_x, true >, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared));
35623572 shared_memory_limit_raised[id] = true ;
35633573 }
35643574#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
@@ -3576,14 +3586,14 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
35763586 if (!args.use_stream_k ) {
35773587 if (args.nrows_x % mmq_y == 0 ) {
35783588 constexpr bool need_check = false ;
3579- mul_mat_q<type, mmq_x, warp_size, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
3589+ mul_mat_q<type, mmq_x, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
35803590 (args.x , args.y , args.ids_dst , args.expert_bounds , args.dst , nullptr ,
35813591 args.ncols_x , args.nrows_x , args.ncols_dst , args.stride_row_x , args.ncols_y , args.nrows_dst ,
35823592 channel_ratio, args.nchannels_y , args.stride_channel_x , args.stride_channel_y , args.stride_channel_dst ,
35833593 sample_ratio, args.nsamples_y , args.stride_sample_x , args.stride_sample_y , args.stride_sample_dst );
35843594 } else {
35853595 constexpr bool need_check = true ;
3586- mul_mat_q<type, mmq_x, warp_size, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
3596+ mul_mat_q<type, mmq_x, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
35873597 (args.x , args.y , args.ids_dst , args.expert_bounds , args.dst , nullptr ,
35883598 args.ncols_x , args.nrows_x , args.ncols_dst , args.stride_row_x , args.ncols_y , args.nrows_dst ,
35893599 channel_ratio, args.nchannels_y , args.stride_channel_x , args.stride_channel_y , args.stride_channel_dst ,
@@ -3603,7 +3613,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
36033613
36043614 if (args.nrows_x % mmq_y == 0 ) {
36053615 constexpr bool need_check = false ;
3606- mul_mat_q<type, mmq_x, warp_size, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
3616+ mul_mat_q<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
36073617 (args.x , args.y , args.ids_dst , args.expert_bounds , args.dst , tmp_fixup.ptr ,
36083618 args.ncols_x , args.nrows_x , args.ncols_dst , args.stride_row_x , args.ncols_y , args.nrows_dst ,
36093619 channel_ratio, args.nchannels_y , args.stride_channel_x , args.stride_channel_y , args.stride_channel_dst ,
@@ -3613,12 +3623,12 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
36133623 return ;
36143624 }
36153625
3616- mul_mat_q_stream_k_fixup<type, mmq_x, warp_size, need_check><<<block_nums_stream_k, block_dims, 0 , stream>>>
3626+ mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0 , stream>>>
36173627 (args.ids_dst , args.expert_bounds , args.dst , tmp_fixup.ptr , args.ncols_x , args.nrows_x , args.ncols_dst ,
36183628 args.nrows_dst , args.nchannels_y , args.stride_channel_dst , args.nsamples_y , args.stride_sample_dst );
36193629 } else {
36203630 constexpr bool need_check = true ;
3621- mul_mat_q<type, mmq_x, warp_size, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
3631+ mul_mat_q<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
36223632 (args.x , args.y , args.ids_dst , args.expert_bounds , args.dst , tmp_fixup.ptr ,
36233633 args.ncols_x , args.nrows_x , args.ncols_dst , args.stride_row_x , args.ncols_y , args.nrows_dst ,
36243634 channel_ratio, args.nchannels_y , args.stride_channel_x , args.stride_channel_y , args.stride_channel_dst ,
@@ -3628,19 +3638,19 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
36283638 return ;
36293639 }
36303640
3631- mul_mat_q_stream_k_fixup<type, mmq_x, warp_size, need_check><<<block_nums_stream_k, block_dims, 0 , stream>>>
3641+ mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0 , stream>>>
36323642 (args.ids_dst , args.expert_bounds , args.dst , tmp_fixup.ptr , args.ncols_x , args.nrows_x , args.ncols_dst ,
36333643 args.nrows_dst , args.nchannels_y , args.stride_channel_dst , args.nsamples_y , args.stride_sample_dst );
36343644 }
36353645}
36363646
36373647template <ggml_type type>
36383648void mul_mat_q_case (ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
3639- const int id = ggml_cuda_get_device ();
3640- const int cc = ggml_cuda_info ().devices [id].cc ;
3641- const size_t smpbo = ggml_cuda_info ().devices [id].smpbo ;
3642- constexpr int warp_size = ggml_cuda_get_physical_warp_size () ;
3643- const int nwarps = get_mmq_nwarps_host (type, cc);
3649+ const int id = ggml_cuda_get_device ();
3650+ const int cc = ggml_cuda_info ().devices [id].cc ;
3651+ const size_t smpbo = ggml_cuda_info ().devices [id].smpbo ;
3652+ const int warp_size = ggml_cuda_info (). devices [id]. warp_size ;
3653+ const int nwarps = get_mmq_nwarps_host (type, cc);
36443654
36453655 const int mmq_x_max = get_mmq_x_max_host (cc);
36463656 const int mmq_y = get_mmq_y_host (cc);
@@ -3651,7 +3661,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
36513661 for (int mmq_x = 8 ; mmq_x <= mmq_x_max && ntiles_x_best > 1 ; mmq_x += 8 ) {
36523662 const int granularity = mmq_get_granularity_host (type, mmq_x, cc);
36533663
3654- if (mmq_x % granularity != 0 || mmq_get_nbytes_shared<type, warp_size >(mmq_x, mmq_y, cc, nwarps) > smpbo) {
3664+ if (mmq_x % granularity != 0 || mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc, warp_size , nwarps) > smpbo) {
36553665 continue ;
36563666 }
36573667
0 commit comments