@@ -207,16 +207,19 @@ static __global__ void mul_mat_vec_q(
207207 constexpr int qi = ggml_cuda_type_traits<type>::qi;
208208 constexpr int vdr = get_vdr_mmvq (type);
209209 constexpr mmvq_parameter_table_id table_id = get_device_table_id ();
210- constexpr int nwarps = calc_nwarps (type, ncols_dst, table_id);
211- constexpr int rows_per_cuda_block = calc_rows_per_block (ncols_dst, table_id, small_k, nwarps );
210+ constexpr int max_nwarps = calc_nwarps (type, ncols_dst, table_id);
211+ constexpr int rows_per_cuda_block = calc_rows_per_block (ncols_dst, table_id, small_k, max_nwarps );
212212 constexpr int warp_size = ggml_cuda_get_physical_warp_size ();
213213
214+ // actual_nwarps is set by the host based on ncols_x; may be < max_nwarps for narrow matrices (e.g. MoE experts).
215+ const int nwarps = max_nwarps > 1 ? static_cast <int >(blockDim .y ) : 1 ;
216+
214217 constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda (type);
215218
216219 const int tid = warp_size*threadIdx .y + threadIdx .x ;
217220 const int row0 = rows_per_cuda_block*blockIdx .x ;
218221 const int blocks_per_row_x = ncols_x / qk;
219- constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
222+ const int blocks_per_iter = vdr * nwarps*warp_size / qi;
220223
221224 const uint32_t channel_dst = blockIdx .y ;
222225
@@ -319,8 +322,8 @@ static __global__ void mul_mat_vec_q(
319322 }
320323 }
321324
322- __shared__ float tmp_shared[nwarps -1 > 0 ? nwarps -1 : 1 ][ncols_dst][rows_per_cuda_block][warp_size];
323- __shared__ float tmp_shared_gate[(has_fusion && (nwarps -1 > 0 )) ? nwarps -1 : 1 ][ncols_dst][rows_per_cuda_block][warp_size];
325+ __shared__ float tmp_shared[max_nwarps -1 > 0 ? max_nwarps -1 : 1 ][ncols_dst][rows_per_cuda_block][warp_size];
326+ __shared__ float tmp_shared_gate[(has_fusion && (max_nwarps -1 > 0 )) ? max_nwarps -1 : 1 ][ncols_dst][rows_per_cuda_block][warp_size];
324327 if constexpr (!has_fusion) {
325328 (void ) tmp_shared_gate;
326329 } else if (!use_gate) {
@@ -357,7 +360,6 @@ static __global__ void mul_mat_vec_q(
357360 for (int j = 0 ; j < ncols_dst; ++j) {
358361#pragma unroll
359362 for (int i = 0 ; i < rows_per_cuda_block; ++i) {
360- #pragma unroll
361363 for (int l = 0 ; l < nwarps-1 ; ++l) {
362364 tmp[j][i] += tmp_shared[l][j][i][threadIdx .x ];
363365 if constexpr (has_fusion) {
@@ -413,9 +415,25 @@ static __global__ void mul_mat_vec_q(
413415
414416template <ggml_type type>
415417static std::pair<dim3 , dim3 > calc_launch_params (
416- const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens,
418+ const int ncols_dst, const int nrows_x, const int ncols_x, const int nchannels_dst, const int nsamples_or_ntokens,
417419 const int warp_size, const mmvq_parameter_table_id table_id, const bool small_k = false ) {
418- const int nwarps = calc_nwarps (type, ncols_dst, table_id);
420+ int nwarps = calc_nwarps (type, ncols_dst, table_id);
421+
422+ // Dynamically reduce nwarps when the matrix is too narrow for full utilization.
423+ if (nwarps > 1 ) {
424+ constexpr int qk = ggml_cuda_type_traits<type>::qk;
425+ constexpr int qi = ggml_cuda_type_traits<type>::qi;
426+ constexpr int vdr = get_vdr_mmvq (type);
427+ const int blocks_per_row = ncols_x / qk;
428+ const int max_useful_nwarps = (blocks_per_row * qi) / (vdr * warp_size);
429+ if (max_useful_nwarps < nwarps) {
430+ nwarps = 1 ;
431+ for (int w = 2 ; w <= max_useful_nwarps; w *= 2 ) {
432+ nwarps = w;
433+ }
434+ }
435+ }
436+
419437 const int rpb = calc_rows_per_block (ncols_dst, table_id, small_k, nwarps);
420438 const int64_t nblocks = (nrows_x + rpb - 1 ) / rpb;
421439 const dim3 block_nums (nblocks, nchannels_dst, nsamples_or_ntokens);
@@ -479,7 +497,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
479497 if (has_ids && ncols_dst > 1 ) {
480498 // Multi-token MUL_MAT_ID path only - single-token goes through regular path below
481499 constexpr int c_ncols_dst = 1 ;
482- std::pair<dim3 , dim3 > dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, ncols_dst, warp_size, table_id);
500+ std::pair<dim3 , dim3 > dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, ncols_dst, warp_size, table_id);
483501 mul_mat_vec_q_switch_fusion<type, c_ncols_dst, true >(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
484502 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
485503 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
@@ -501,15 +519,15 @@ static void mul_mat_vec_q_switch_ncols_dst(
501519 const int nwarps = calc_nwarps (type, c_ncols_dst, table_id);
502520 const bool use_small_k = nwarps > 1 && blocks_per_row_x < nwarps * blocks_per_iter_1warp;
503521 if (use_small_k) {
504- std::pair<dim3 , dim3 > dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst,
522+ std::pair<dim3 , dim3 > dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst,
505523 warp_size, table_id, true );
506524 mul_mat_vec_q_switch_fusion<type, c_ncols_dst, false , true >(
507525 vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
508526 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
509527 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
510528 dims.first , dims.second , 0 , ids_stride, stream);
511529 } else {
512- std::pair<dim3 , dim3 > dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst,
530+ std::pair<dim3 , dim3 > dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst,
513531 warp_size, table_id);
514532 mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(
515533 vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
@@ -520,55 +538,55 @@ static void mul_mat_vec_q_switch_ncols_dst(
520538 } break ;
521539 case 2 : {
522540 constexpr int c_ncols_dst = 2 ;
523- std::pair<dim3 , dim3 > dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
541+ std::pair<dim3 , dim3 > dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id);
524542 mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
525543 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
526544 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
527545 dims.first , dims.second , 0 , ids_stride, stream);
528546 } break ;
529547 case 3 : {
530548 constexpr int c_ncols_dst = 3 ;
531- std::pair<dim3 , dim3 > dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
549+ std::pair<dim3 , dim3 > dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id);
532550 mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
533551 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
534552 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
535553 dims.first , dims.second , 0 , ids_stride, stream);
536554 } break ;
537555 case 4 : {
538556 constexpr int c_ncols_dst = 4 ;
539- std::pair<dim3 , dim3 > dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
557+ std::pair<dim3 , dim3 > dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id);
540558 mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
541559 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
542560 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
543561 dims.first , dims.second , 0 , ids_stride, stream);
544562 } break ;
545563 case 5 : {
546564 constexpr int c_ncols_dst = 5 ;
547- std::pair<dim3 , dim3 > dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
565+ std::pair<dim3 , dim3 > dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id);
548566 mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
549567 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
550568 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
551569 dims.first , dims.second , 0 , ids_stride, stream);
552570 } break ;
553571 case 6 : {
554572 constexpr int c_ncols_dst = 6 ;
555- std::pair<dim3 , dim3 > dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
573+ std::pair<dim3 , dim3 > dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id);
556574 mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
557575 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
558576 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
559577 dims.first , dims.second , 0 , ids_stride, stream);
560578 } break ;
561579 case 7 : {
562580 constexpr int c_ncols_dst = 7 ;
563- std::pair<dim3 , dim3 > dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
581+ std::pair<dim3 , dim3 > dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id);
564582 mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
565583 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
566584 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
567585 dims.first , dims.second , 0 , ids_stride, stream);
568586 } break ;
569587 case 8 : {
570588 constexpr int c_ncols_dst = 8 ;
571- std::pair<dim3 , dim3 > dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
589+ std::pair<dim3 , dim3 > dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id);
572590 mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
573591 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
574592 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
0 commit comments