@@ -288,7 +288,7 @@ static constexpr __device__ int get_mmvq_mmid_max_batch_for_device() {
288288#endif
289289}
290290
291- static constexpr __host__ __device__ int calc_nwarps (ggml_type type, int ncols_dst, mmvq_parameter_table_id table_id) {
291+ static constexpr __host__ __device__ int calc_max_nwarps (ggml_type type, int ncols_dst, mmvq_parameter_table_id table_id) {
292292 if (table_id == MMVQ_PARAMETERS_GENERIC) {
293293 switch (ncols_dst) {
294294 case 1 :
@@ -387,7 +387,7 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int
387387}
388388
389389template <ggml_type type, int ncols_dst, bool has_fusion, bool small_k = false >
390- __launch_bounds__ (calc_nwarps (type, ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
390+ __launch_bounds__ (calc_max_nwarps (type, ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
391391static __global__ void mul_mat_vec_q(
392392 const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
393393 const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
@@ -400,16 +400,24 @@ static __global__ void mul_mat_vec_q(
400400 constexpr int qi = ggml_cuda_type_traits<type>::qi;
401401 constexpr int vdr = get_vdr_mmvq (type);
402402 constexpr mmvq_parameter_table_id table_id = get_device_table_id ();
403- constexpr int nwarps = calc_nwarps (type, ncols_dst, table_id);
404- constexpr int rows_per_cuda_block = calc_rows_per_block (ncols_dst, table_id, small_k, nwarps );
403+ constexpr int max_nwarps = calc_max_nwarps (type, ncols_dst, table_id);
404+ constexpr int rows_per_cuda_block = calc_rows_per_block (ncols_dst, table_id, small_k, max_nwarps );
405405 constexpr int warp_size = ggml_cuda_get_physical_warp_size ();
406406
407+ // RDNA3/RDNA4: the actual nwarps is chosen by the host based on ncols_x and read back from blockDim.y; it may be < max_nwarps for narrow matrices (e.g. MoE experts).
408+ // Other architectures: nwarps == max_nwarps and stays constexpr, which preserves #pragma unroll.
409+ #if defined(RDNA3_0) || defined(RDNA4)
410+ const int nwarps = max_nwarps > 1 ? static_cast <int >(blockDim .y ) : 1 ;
411+ #else
412+ constexpr int nwarps = max_nwarps;
413+ #endif
414+
407415 constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda (type);
408416
409417 const int tid = warp_size*threadIdx .y + threadIdx .x ;
410418 const int row0 = rows_per_cuda_block*blockIdx .x ;
411419 const int blocks_per_row_x = ncols_x / qk;
412- constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
420+ const int blocks_per_iter = vdr * nwarps*warp_size / qi;
413421
414422 const uint32_t channel_dst = blockIdx .y ;
415423
@@ -500,8 +508,8 @@ static __global__ void mul_mat_vec_q(
500508 }
501509 }
502510
503- __shared__ float tmp_shared[nwarps -1 > 0 ? nwarps -1 : 1 ][ncols_dst][rows_per_cuda_block][warp_size];
504- __shared__ float tmp_shared_gate[(has_fusion && (nwarps -1 > 0 )) ? nwarps -1 : 1 ][ncols_dst][rows_per_cuda_block][warp_size];
511+ __shared__ float tmp_shared[max_nwarps -1 > 0 ? max_nwarps -1 : 1 ][ncols_dst][rows_per_cuda_block][warp_size];
512+ __shared__ float tmp_shared_gate[(has_fusion && (max_nwarps -1 > 0 )) ? max_nwarps -1 : 1 ][ncols_dst][rows_per_cuda_block][warp_size];
505513 if constexpr (!has_fusion) {
506514 (void ) tmp_shared_gate;
507515 } else if (!use_gate) {
@@ -653,9 +661,26 @@ static __global__ void mul_mat_vec_q_moe(
653661
654662template <ggml_type type>
655663static std::pair<dim3 , dim3 > calc_launch_params (
656- const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens,
664+ const int ncols_dst, const int nrows_x, const int ncols_x, const int nchannels_dst, const int nsamples_or_ntokens,
657665 const int warp_size, const mmvq_parameter_table_id table_id, const bool small_k = false ) {
658- const int nwarps = calc_nwarps (type, ncols_dst, table_id);
666+ int nwarps = calc_max_nwarps (type, ncols_dst, table_id);
667+
668+ // Dynamically reduce nwarps when the matrix is too narrow to keep all warps busy.
669+ // Only applied on RDNA3/RDNA4, where narrow MoE expert matrices otherwise cause a significant regression.
670+ if (nwarps > 1 && (table_id == MMVQ_PARAMETERS_RDNA3_0 || table_id == MMVQ_PARAMETERS_RDNA4)) {
671+ constexpr int qk = ggml_cuda_type_traits<type>::qk;
672+ constexpr int qi = ggml_cuda_type_traits<type>::qi;
673+ constexpr int vdr = get_vdr_mmvq (type);
674+ const int blocks_per_row = ncols_x / qk;
675+ const int max_useful_nwarps = (blocks_per_row * qi) / (vdr * warp_size);
676+ if (max_useful_nwarps < nwarps) {
677+ nwarps = 1 ;
678+ for (int w = 2 ; w <= max_useful_nwarps; w *= 2 ) {
679+ nwarps = w;
680+ }
681+ }
682+ }
683+
659684 const int rpb = calc_rows_per_block (ncols_dst, table_id, small_k, nwarps);
660685 const int64_t nblocks = (nrows_x + rpb - 1 ) / rpb;
661686 const dim3 block_nums (nblocks, nchannels_dst, nsamples_or_ntokens);
@@ -746,7 +771,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
746771 constexpr int vdr = get_vdr_mmvq (type);
747772 const int blocks_per_row_x = ncols_x / qk;
748773 const int blocks_per_iter_1warp = vdr * warp_size / qi;
749- const int nwarps = calc_nwarps (type, c_ncols_dst, table_id);
774+ const int nwarps = calc_max_nwarps (type, c_ncols_dst, table_id);
750775 bool use = nwarps > 1 && blocks_per_row_x < nwarps * blocks_per_iter_1warp;
751776
752777 constexpr std::array<ggml_type, 2 > iq_slow_turing = {
@@ -797,16 +822,16 @@ static void mul_mat_vec_q_switch_ncols_dst(
797822 bool use_small_k = should_use_small_k (c_ncols_dst);
798823
799824 if (use_small_k) {
800- std::pair<dim3 , dim3 > dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst,
825+ std::pair<dim3 , dim3 > dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst,
801826 nsamples_dst, warp_size, table_id, true );
802827 mul_mat_vec_q_switch_fusion<type, c_ncols_dst, true >(
803828 vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
804829 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd,
805830 stride_sample_x, stride_sample_y, stride_sample_dst, dims.first , dims.second , 0 , ids_stride,
806831 stream);
807832 } else {
808- std::pair<dim3 , dim3 > dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst,
809- nsamples_dst, warp_size, table_id);
833+ std::pair<dim3 , dim3 > dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst ,
834+ warp_size, table_id);
810835 mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(
811836 vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
812837 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd,
@@ -816,55 +841,55 @@ static void mul_mat_vec_q_switch_ncols_dst(
816841 } break ;
817842 case 2 : {
818843 constexpr int c_ncols_dst = 2 ;
819- std::pair<dim3 , dim3 > dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
844+ std::pair<dim3 , dim3 > dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id);
820845 mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
821846 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
822847 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
823848 dims.first , dims.second , 0 , ids_stride, stream);
824849 } break ;
825850 case 3 : {
826851 constexpr int c_ncols_dst = 3 ;
827- std::pair<dim3 , dim3 > dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
852+ std::pair<dim3 , dim3 > dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id);
828853 mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
829854 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
830855 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
831856 dims.first , dims.second , 0 , ids_stride, stream);
832857 } break ;
833858 case 4 : {
834859 constexpr int c_ncols_dst = 4 ;
835- std::pair<dim3 , dim3 > dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
860+ std::pair<dim3 , dim3 > dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id);
836861 mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
837862 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
838863 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
839864 dims.first , dims.second , 0 , ids_stride, stream);
840865 } break ;
841866 case 5 : {
842867 constexpr int c_ncols_dst = 5 ;
843- std::pair<dim3 , dim3 > dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
868+ std::pair<dim3 , dim3 > dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id);
844869 mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
845870 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
846871 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
847872 dims.first , dims.second , 0 , ids_stride, stream);
848873 } break ;
849874 case 6 : {
850875 constexpr int c_ncols_dst = 6 ;
851- std::pair<dim3 , dim3 > dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
876+ std::pair<dim3 , dim3 > dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id);
852877 mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
853878 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
854879 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
855880 dims.first , dims.second , 0 , ids_stride, stream);
856881 } break ;
857882 case 7 : {
858883 constexpr int c_ncols_dst = 7 ;
859- std::pair<dim3 , dim3 > dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
884+ std::pair<dim3 , dim3 > dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id);
860885 mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
861886 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
862887 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
863888 dims.first , dims.second , 0 , ids_stride, stream);
864889 } break ;
865890 case 8 : {
866891 constexpr int c_ncols_dst = 8 ;
867- std::pair<dim3 , dim3 > dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
892+ std::pair<dim3 , dim3 > dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id);
868893 mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
869894 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
870895 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
0 commit comments