@@ -33,7 +33,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type)
3333 }
3434}
3535
36- static constexpr __device__ int get_vdr_mmvq (ggml_type type) {
36+ static constexpr __host__ __device__ int get_vdr_mmvq (ggml_type type) {
3737 switch (type) {
3838 case GGML_TYPE_Q4_0: return VDR_Q4_0_Q8_1_MMVQ;
3939 case GGML_TYPE_Q4_1: return VDR_Q4_1_Q8_1_MMVQ;
@@ -207,16 +207,19 @@ static __global__ void mul_mat_vec_q(
207207 constexpr int qi = ggml_cuda_type_traits<type>::qi;
208208 constexpr int vdr = get_vdr_mmvq (type);
209209 constexpr mmvq_parameter_table_id table_id = get_device_table_id ();
210- constexpr int nwarps = calc_nwarps (type, ncols_dst, table_id);
210+ constexpr int max_nwarps = calc_nwarps (type, ncols_dst, table_id);
211211 constexpr int rows_per_cuda_block = calc_rows_per_block (ncols_dst, table_id);
212212 constexpr int warp_size = ggml_cuda_get_physical_warp_size ();
213213
214+ // actual_nwarps is set by the host based on ncols_x; may be < max_nwarps for narrow matrices (e.g. MoE experts).
215+ const int nwarps = max_nwarps > 1 ? static_cast <int >(blockDim .y ) : 1 ;
216+
214217 constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda (type);
215218
216219 const int tid = warp_size*threadIdx .y + threadIdx .x ;
217220 const int row0 = rows_per_cuda_block*blockIdx .x ;
218221 const int blocks_per_row_x = ncols_x / qk;
219- constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
222+ const int blocks_per_iter = vdr * nwarps*warp_size / qi;
220223
221224 const uint32_t channel_dst = blockIdx .y ;
222225
@@ -319,8 +322,8 @@ static __global__ void mul_mat_vec_q(
319322 }
320323 }
321324
322- __shared__ float tmp_shared[nwarps -1 > 0 ? nwarps -1 : 1 ][ncols_dst][rows_per_cuda_block][warp_size];
323- __shared__ float tmp_shared_gate[(has_fusion && (nwarps -1 > 0 )) ? nwarps -1 : 1 ][ncols_dst][rows_per_cuda_block][warp_size];
325+ __shared__ float tmp_shared[max_nwarps -1 > 0 ? max_nwarps -1 : 1 ][ncols_dst][rows_per_cuda_block][warp_size];
326+ __shared__ float tmp_shared_gate[(has_fusion && (max_nwarps -1 > 0 )) ? max_nwarps -1 : 1 ][ncols_dst][rows_per_cuda_block][warp_size];
324327 if constexpr (!has_fusion) {
325328 (void ) tmp_shared_gate;
326329 } else if (!use_gate) {
@@ -357,7 +360,6 @@ static __global__ void mul_mat_vec_q(
357360 for (int j = 0 ; j < ncols_dst; ++j) {
358361#pragma unroll
359362 for (int i = 0 ; i < rows_per_cuda_block; ++i) {
360- #pragma unroll
361363 for (int l = 0 ; l < nwarps-1 ; ++l) {
362364 tmp[j][i] += tmp_shared[l][j][i][threadIdx .x ];
363365 if constexpr (has_fusion) {
@@ -413,11 +415,32 @@ static __global__ void mul_mat_vec_q(
413415
414416template <ggml_type type>
415417static std::pair<dim3 , dim3 > calc_launch_params (
416- const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens,
418+ const int ncols_dst, const int nrows_x, const int ncols_x, const int nchannels_dst, const int nsamples_or_ntokens,
417419 const int warp_size, const mmvq_parameter_table_id table_id) {
420+ int nwarps = calc_nwarps (type, ncols_dst, table_id);
421+
422+ // Dynamically reduce nwarps when the matrix is too narrow for full utilization.
423+ // blocks_per_iter = vdr * nwarps * warp_size / qi must not exceed blocks_per_row = ncols_x / qk,
424+ // otherwise warps have no work but still pay reduction and syncthreads overhead.
425+ // This is critical for MoE models where expert FFN matrices are narrow.
426+ if (nwarps > 1 ) {
427+ constexpr int qk = ggml_cuda_type_traits<type>::qk;
428+ constexpr int qi = ggml_cuda_type_traits<type>::qi;
429+ constexpr int vdr = get_vdr_mmvq (type);
430+ const int blocks_per_row = ncols_x / qk;
431+ const int max_useful_nwarps = (blocks_per_row * qi) / (vdr * warp_size);
432+ if (max_useful_nwarps < nwarps) {
433+ // Clamp to largest power-of-2 that fits, minimum 1
434+ nwarps = 1 ;
435+ for (int w = 2 ; w <= max_useful_nwarps; w *= 2 ) {
436+ nwarps = w;
437+ }
438+ }
439+ }
440+
418441 const int64_t nblocks = (nrows_x + calc_rows_per_block (ncols_dst, table_id) - 1 ) / calc_rows_per_block (ncols_dst, table_id);
419442 const dim3 block_nums (nblocks, nchannels_dst, nsamples_or_ntokens);
420- const dim3 block_dims (warp_size, calc_nwarps (type, ncols_dst, table_id) , 1 );
443+ const dim3 block_dims (warp_size, nwarps , 1 );
421444 return {block_nums, block_dims};
422445}
423446
@@ -477,7 +500,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
477500 if (has_ids && ncols_dst > 1 ) {
478501 // Multi-token MUL_MAT_ID path only - single-token goes through regular path below
479502 constexpr int c_ncols_dst = 1 ;
480- std::pair<dim3 , dim3 > dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, ncols_dst, warp_size, table_id);
503+ std::pair<dim3 , dim3 > dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, ncols_dst, warp_size, table_id);
481504 mul_mat_vec_q_switch_fusion<type, c_ncols_dst, true >(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
482505 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
483506 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
@@ -488,63 +511,63 @@ static void mul_mat_vec_q_switch_ncols_dst(
488511 switch (ncols_dst) {
489512 case 1 : {
490513 constexpr int c_ncols_dst = 1 ;
491- std::pair<dim3 , dim3 > dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
514+ std::pair<dim3 , dim3 > dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id);
492515 mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
493516 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
494517 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
495518 dims.first , dims.second , 0 , ids_stride, stream);
496519 } break ;
497520 case 2 : {
498521 constexpr int c_ncols_dst = 2 ;
499- std::pair<dim3 , dim3 > dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
522+ std::pair<dim3 , dim3 > dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id);
500523 mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
501524 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
502525 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
503526 dims.first , dims.second , 0 , ids_stride, stream);
504527 } break ;
505528 case 3 : {
506529 constexpr int c_ncols_dst = 3 ;
507- std::pair<dim3 , dim3 > dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
530+ std::pair<dim3 , dim3 > dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id);
508531 mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
509532 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
510533 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
511534 dims.first , dims.second , 0 , ids_stride, stream);
512535 } break ;
513536 case 4 : {
514537 constexpr int c_ncols_dst = 4 ;
515- std::pair<dim3 , dim3 > dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
538+ std::pair<dim3 , dim3 > dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id);
516539 mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
517540 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
518541 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
519542 dims.first , dims.second , 0 , ids_stride, stream);
520543 } break ;
521544 case 5 : {
522545 constexpr int c_ncols_dst = 5 ;
523- std::pair<dim3 , dim3 > dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
546+ std::pair<dim3 , dim3 > dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id);
524547 mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
525548 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
526549 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
527550 dims.first , dims.second , 0 , ids_stride, stream);
528551 } break ;
529552 case 6 : {
530553 constexpr int c_ncols_dst = 6 ;
531- std::pair<dim3 , dim3 > dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
554+ std::pair<dim3 , dim3 > dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id);
532555 mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
533556 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
534557 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
535558 dims.first , dims.second , 0 , ids_stride, stream);
536559 } break ;
537560 case 7 : {
538561 constexpr int c_ncols_dst = 7 ;
539- std::pair<dim3 , dim3 > dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
562+ std::pair<dim3 , dim3 > dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id);
540563 mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
541564 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
542565 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
543566 dims.first , dims.second , 0 , ids_stride, stream);
544567 } break ;
545568 case 8 : {
546569 constexpr int c_ncols_dst = 8 ;
547- std::pair<dim3 , dim3 > dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
570+ std::pair<dim3 , dim3 > dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id);
548571 mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
549572 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
550573 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
0 commit comments