Skip to content

Commit afb6b78

Browse files
committed
mmvq: dynamic nwarps based on matrix width for MoE models
1 parent 177c758 commit afb6b78

File tree

1 file changed

+36
-18
lines changed

1 file changed

+36
-18
lines changed

ggml/src/ggml-cuda/mmvq.cu

Lines changed: 36 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -207,16 +207,19 @@ static __global__ void mul_mat_vec_q(
207207
constexpr int qi = ggml_cuda_type_traits<type>::qi;
208208
constexpr int vdr = get_vdr_mmvq(type);
209209
constexpr mmvq_parameter_table_id table_id = get_device_table_id();
210-
constexpr int nwarps = calc_nwarps(type, ncols_dst, table_id);
211-
constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id, small_k, nwarps);
210+
constexpr int max_nwarps = calc_nwarps(type, ncols_dst, table_id);
211+
constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id, small_k, max_nwarps);
212212
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
213213

214+
// actual_nwarps is set by the host based on ncols_x; may be < max_nwarps for narrow matrices (e.g. MoE experts).
215+
const int nwarps = max_nwarps > 1 ? static_cast<int>(blockDim.y) : 1;
216+
214217
constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
215218

216219
const int tid = warp_size*threadIdx.y + threadIdx.x;
217220
const int row0 = rows_per_cuda_block*blockIdx.x;
218221
const int blocks_per_row_x = ncols_x / qk;
219-
constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
222+
const int blocks_per_iter = vdr * nwarps*warp_size / qi;
220223

221224
const uint32_t channel_dst = blockIdx.y;
222225

@@ -319,8 +322,8 @@ static __global__ void mul_mat_vec_q(
319322
}
320323
}
321324

322-
__shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size];
323-
__shared__ float tmp_shared_gate[(has_fusion && (nwarps-1 > 0)) ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size];
325+
__shared__ float tmp_shared[max_nwarps-1 > 0 ? max_nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size];
326+
__shared__ float tmp_shared_gate[(has_fusion && (max_nwarps-1 > 0)) ? max_nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size];
324327
if constexpr (!has_fusion) {
325328
(void) tmp_shared_gate;
326329
} else if (!use_gate) {
@@ -357,7 +360,6 @@ static __global__ void mul_mat_vec_q(
357360
for (int j = 0; j < ncols_dst; ++j) {
358361
#pragma unroll
359362
for (int i = 0; i < rows_per_cuda_block; ++i) {
360-
#pragma unroll
361363
for (int l = 0; l < nwarps-1; ++l) {
362364
tmp[j][i] += tmp_shared[l][j][i][threadIdx.x];
363365
if constexpr (has_fusion) {
@@ -413,9 +415,25 @@ static __global__ void mul_mat_vec_q(
413415

414416
template<ggml_type type>
415417
static std::pair<dim3, dim3> calc_launch_params(
416-
const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens,
418+
const int ncols_dst, const int nrows_x, const int ncols_x, const int nchannels_dst, const int nsamples_or_ntokens,
417419
const int warp_size, const mmvq_parameter_table_id table_id, const bool small_k = false) {
418-
const int nwarps = calc_nwarps(type, ncols_dst, table_id);
420+
int nwarps = calc_nwarps(type, ncols_dst, table_id);
421+
422+
// Dynamically reduce nwarps when the matrix is too narrow for full utilization.
423+
if (nwarps > 1) {
424+
constexpr int qk = ggml_cuda_type_traits<type>::qk;
425+
constexpr int qi = ggml_cuda_type_traits<type>::qi;
426+
constexpr int vdr = get_vdr_mmvq(type);
427+
const int blocks_per_row = ncols_x / qk;
428+
const int max_useful_nwarps = (blocks_per_row * qi) / (vdr * warp_size);
429+
if (max_useful_nwarps < nwarps) {
430+
nwarps = 1;
431+
for (int w = 2; w <= max_useful_nwarps; w *= 2) {
432+
nwarps = w;
433+
}
434+
}
435+
}
436+
419437
const int rpb = calc_rows_per_block(ncols_dst, table_id, small_k, nwarps);
420438
const int64_t nblocks = (nrows_x + rpb - 1) / rpb;
421439
const dim3 block_nums(nblocks, nchannels_dst, nsamples_or_ntokens);
@@ -479,7 +497,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
479497
if (has_ids && ncols_dst > 1) {
480498
// Multi-token MUL_MAT_ID path only - single-token goes through regular path below
481499
constexpr int c_ncols_dst = 1;
482-
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, ncols_dst, warp_size, table_id);
500+
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, ncols_dst, warp_size, table_id);
483501
mul_mat_vec_q_switch_fusion<type, c_ncols_dst, true>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
484502
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
485503
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
@@ -501,15 +519,15 @@ static void mul_mat_vec_q_switch_ncols_dst(
501519
const int nwarps = calc_nwarps(type, c_ncols_dst, table_id);
502520
const bool use_small_k = nwarps > 1 && blocks_per_row_x < nwarps * blocks_per_iter_1warp;
503521
if (use_small_k) {
504-
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst,
522+
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst,
505523
warp_size, table_id, true);
506524
mul_mat_vec_q_switch_fusion<type, c_ncols_dst, false, true>(
507525
vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
508526
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
509527
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
510528
dims.first, dims.second, 0, ids_stride, stream);
511529
} else {
512-
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst,
530+
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst,
513531
warp_size, table_id);
514532
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(
515533
vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
@@ -520,55 +538,55 @@ static void mul_mat_vec_q_switch_ncols_dst(
520538
} break;
521539
case 2: {
522540
constexpr int c_ncols_dst = 2;
523-
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
541+
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id);
524542
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
525543
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
526544
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
527545
dims.first, dims.second, 0, ids_stride, stream);
528546
} break;
529547
case 3: {
530548
constexpr int c_ncols_dst = 3;
531-
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
549+
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id);
532550
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
533551
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
534552
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
535553
dims.first, dims.second, 0, ids_stride, stream);
536554
} break;
537555
case 4: {
538556
constexpr int c_ncols_dst = 4;
539-
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
557+
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id);
540558
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
541559
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
542560
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
543561
dims.first, dims.second, 0, ids_stride, stream);
544562
} break;
545563
case 5: {
546564
constexpr int c_ncols_dst = 5;
547-
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
565+
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id);
548566
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
549567
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
550568
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
551569
dims.first, dims.second, 0, ids_stride, stream);
552570
} break;
553571
case 6: {
554572
constexpr int c_ncols_dst = 6;
555-
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
573+
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id);
556574
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
557575
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
558576
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
559577
dims.first, dims.second, 0, ids_stride, stream);
560578
} break;
561579
case 7: {
562580
constexpr int c_ncols_dst = 7;
563-
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
581+
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id);
564582
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
565583
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
566584
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
567585
dims.first, dims.second, 0, ids_stride, stream);
568586
} break;
569587
case 8: {
570588
constexpr int c_ncols_dst = 8;
571-
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
589+
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id);
572590
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
573591
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
574592
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,

0 commit comments

Comments
 (0)