Skip to content

Commit 23bae61

Browse files
committed
mmvq: dynamic nwarps based on matrix width for MoE models
1 parent c1b9116 commit 23bae61

File tree

1 file changed

+40
-17
lines changed

1 file changed

+40
-17
lines changed

ggml/src/ggml-cuda/mmvq.cu

Lines changed: 40 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type)
3333
}
3434
}
3535

36-
static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
36+
static constexpr __host__ __device__ int get_vdr_mmvq(ggml_type type) {
3737
switch (type) {
3838
case GGML_TYPE_Q4_0: return VDR_Q4_0_Q8_1_MMVQ;
3939
case GGML_TYPE_Q4_1: return VDR_Q4_1_Q8_1_MMVQ;
@@ -207,16 +207,19 @@ static __global__ void mul_mat_vec_q(
207207
constexpr int qi = ggml_cuda_type_traits<type>::qi;
208208
constexpr int vdr = get_vdr_mmvq(type);
209209
constexpr mmvq_parameter_table_id table_id = get_device_table_id();
210-
constexpr int nwarps = calc_nwarps(type, ncols_dst, table_id);
210+
constexpr int max_nwarps = calc_nwarps(type, ncols_dst, table_id);
211211
constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id);
212212
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
213213

214+
// nwarps is the warp count the host actually launched (read back from blockDim.y), chosen from ncols_x; it may be < max_nwarps for narrow matrices (e.g. MoE experts).
215+
const int nwarps = max_nwarps > 1 ? static_cast<int>(blockDim.y) : 1;
216+
214217
constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
215218

216219
const int tid = warp_size*threadIdx.y + threadIdx.x;
217220
const int row0 = rows_per_cuda_block*blockIdx.x;
218221
const int blocks_per_row_x = ncols_x / qk;
219-
constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
222+
const int blocks_per_iter = vdr * nwarps*warp_size / qi;
220223

221224
const uint32_t channel_dst = blockIdx.y;
222225

@@ -319,8 +322,8 @@ static __global__ void mul_mat_vec_q(
319322
}
320323
}
321324

322-
__shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size];
323-
__shared__ float tmp_shared_gate[(has_fusion && (nwarps-1 > 0)) ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size];
325+
__shared__ float tmp_shared[max_nwarps-1 > 0 ? max_nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size];
326+
__shared__ float tmp_shared_gate[(has_fusion && (max_nwarps-1 > 0)) ? max_nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size];
324327
if constexpr (!has_fusion) {
325328
(void) tmp_shared_gate;
326329
} else if (!use_gate) {
@@ -357,7 +360,6 @@ static __global__ void mul_mat_vec_q(
357360
for (int j = 0; j < ncols_dst; ++j) {
358361
#pragma unroll
359362
for (int i = 0; i < rows_per_cuda_block; ++i) {
360-
#pragma unroll
361363
for (int l = 0; l < nwarps-1; ++l) {
362364
tmp[j][i] += tmp_shared[l][j][i][threadIdx.x];
363365
if constexpr (has_fusion) {
@@ -413,11 +415,32 @@ static __global__ void mul_mat_vec_q(
413415

414416
template<ggml_type type>
415417
static std::pair<dim3, dim3> calc_launch_params(
416-
const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens,
418+
const int ncols_dst, const int nrows_x, const int ncols_x, const int nchannels_dst, const int nsamples_or_ntokens,
417419
const int warp_size, const mmvq_parameter_table_id table_id) {
420+
int nwarps = calc_nwarps(type, ncols_dst, table_id);
421+
422+
// Dynamically reduce nwarps when the matrix is too narrow for full utilization.
423+
// blocks_per_iter = vdr * nwarps * warp_size / qi must not exceed blocks_per_row = ncols_x / qk,
424+
// otherwise warps have no work but still pay reduction and syncthreads overhead.
425+
// This is critical for MoE models where expert FFN matrices are narrow.
426+
if (nwarps > 1) {
427+
constexpr int qk = ggml_cuda_type_traits<type>::qk;
428+
constexpr int qi = ggml_cuda_type_traits<type>::qi;
429+
constexpr int vdr = get_vdr_mmvq(type);
430+
const int blocks_per_row = ncols_x / qk;
431+
const int max_useful_nwarps = (blocks_per_row * qi) / (vdr * warp_size);
432+
if (max_useful_nwarps < nwarps) {
433+
// Clamp to largest power-of-2 that fits, minimum 1
434+
nwarps = 1;
435+
for (int w = 2; w <= max_useful_nwarps; w *= 2) {
436+
nwarps = w;
437+
}
438+
}
439+
}
440+
418441
const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_dst, table_id) - 1) / calc_rows_per_block(ncols_dst, table_id);
419442
const dim3 block_nums(nblocks, nchannels_dst, nsamples_or_ntokens);
420-
const dim3 block_dims(warp_size, calc_nwarps(type, ncols_dst, table_id), 1);
443+
const dim3 block_dims(warp_size, nwarps, 1);
421444
return {block_nums, block_dims};
422445
}
423446

@@ -477,7 +500,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
477500
if (has_ids && ncols_dst > 1) {
478501
// Multi-token MUL_MAT_ID path only - single-token goes through regular path below
479502
constexpr int c_ncols_dst = 1;
480-
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, ncols_dst, warp_size, table_id);
503+
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, ncols_dst, warp_size, table_id);
481504
mul_mat_vec_q_switch_fusion<type, c_ncols_dst, true>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
482505
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
483506
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
@@ -488,63 +511,63 @@ static void mul_mat_vec_q_switch_ncols_dst(
488511
switch (ncols_dst) {
489512
case 1: {
490513
constexpr int c_ncols_dst = 1;
491-
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
514+
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id);
492515
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
493516
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
494517
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
495518
dims.first, dims.second, 0, ids_stride, stream);
496519
} break;
497520
case 2: {
498521
constexpr int c_ncols_dst = 2;
499-
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
522+
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id);
500523
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
501524
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
502525
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
503526
dims.first, dims.second, 0, ids_stride, stream);
504527
} break;
505528
case 3: {
506529
constexpr int c_ncols_dst = 3;
507-
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
530+
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id);
508531
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
509532
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
510533
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
511534
dims.first, dims.second, 0, ids_stride, stream);
512535
} break;
513536
case 4: {
514537
constexpr int c_ncols_dst = 4;
515-
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
538+
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id);
516539
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
517540
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
518541
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
519542
dims.first, dims.second, 0, ids_stride, stream);
520543
} break;
521544
case 5: {
522545
constexpr int c_ncols_dst = 5;
523-
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
546+
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id);
524547
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
525548
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
526549
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
527550
dims.first, dims.second, 0, ids_stride, stream);
528551
} break;
529552
case 6: {
530553
constexpr int c_ncols_dst = 6;
531-
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
554+
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id);
532555
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
533556
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
534557
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
535558
dims.first, dims.second, 0, ids_stride, stream);
536559
} break;
537560
case 7: {
538561
constexpr int c_ncols_dst = 7;
539-
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
562+
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id);
540563
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
541564
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
542565
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
543566
dims.first, dims.second, 0, ids_stride, stream);
544567
} break;
545568
case 8: {
546569
constexpr int c_ncols_dst = 8;
547-
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
570+
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id);
548571
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
549572
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
550573
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,

0 commit comments

Comments
 (0)