Skip to content

Commit b1f4a5f

Browse files
committed
mmvq: add RDNA4-specific parameter table (nwarps=8, rows=1)
Add a dedicated MMVQ_PARAMETERS_RDNA4 entry, separate from the table shared by RDNA2/RDNA3.

For bs=1 decode on RDNA4 (gfx1201), the optimal configuration is nwarps=8, rows=1:
- 8 warps × 32 threads = 256 threads per block
- blocks_per_iter = vdr*nwarps*warp_size/qi = 2*8*32/4 = 128
- For K=4096: blocks_per_row = 128, so the entire K dimension is covered in a single iteration
- Maximizes memory-level parallelism on RDNA4

Benchmark (Llama 2 7B Q4_0, AMD Radeon AI PRO R9700):
- Master: 95.05 tok/s (tg128)
- nwarps=8: 104.82 tok/s (tg128) → +10.3%
- pp512: no regression (1448 vs 1449 tok/s)
1 parent 8004f3a commit b1f4a5f

File tree

1 file changed

+47
-16
lines changed

1 file changed

+47
-16
lines changed

ggml/src/ggml-cuda/mmvq.cu

Lines changed: 47 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -60,11 +60,14 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
6060
enum mmvq_parameter_table_id {
6161
MMVQ_PARAMETERS_GENERIC = 0,
6262
MMVQ_PARAMETERS_GCN,
63-
MMVQ_PARAMETERS_RDNA2
63+
MMVQ_PARAMETERS_RDNA2,
64+
MMVQ_PARAMETERS_RDNA4
6465
};
6566

6667
static constexpr __device__ mmvq_parameter_table_id get_device_table_id() {
67-
#if defined(RDNA2) || defined(RDNA3) || defined(RDNA4)
68+
#if defined(RDNA4)
69+
return MMVQ_PARAMETERS_RDNA4;
70+
#elif defined(RDNA2) || defined(RDNA3)
6871
return MMVQ_PARAMETERS_RDNA2;
6972
#elif defined(GCN) || defined(CDNA)
7073
return MMVQ_PARAMETERS_GCN;
@@ -74,7 +77,10 @@ static constexpr __device__ mmvq_parameter_table_id get_device_table_id() {
7477
}
7578

7679
static __host__ mmvq_parameter_table_id get_device_table_id(int cc) {
77-
if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
80+
if (GGML_CUDA_CC_IS_RDNA4(cc)) {
81+
return MMVQ_PARAMETERS_RDNA4;
82+
}
83+
if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc)) {
7884
return MMVQ_PARAMETERS_RDNA2;
7985
}
8086
if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) {
@@ -83,7 +89,7 @@ static __host__ mmvq_parameter_table_id get_device_table_id(int cc) {
8389
return MMVQ_PARAMETERS_GENERIC;
8490
}
8591

86-
static constexpr __host__ __device__ int calc_nwarps(int ncols_dst, mmvq_parameter_table_id table_id) {
92+
static constexpr __host__ __device__ int calc_nwarps(ggml_type type, int ncols_dst, mmvq_parameter_table_id table_id) {
8793
if (table_id == MMVQ_PARAMETERS_GENERIC) {
8894
switch (ncols_dst) {
8995
case 1:
@@ -114,6 +120,30 @@ static constexpr __host__ __device__ int calc_nwarps(int ncols_dst, mmvq_paramet
114120
return 1;
115121
}
116122
}
123+
if (table_id == MMVQ_PARAMETERS_RDNA4) {
124+
// nwarps=8 benefits types with simple vec_dot on RDNA4 (ncols_dst=1).
125+
// Types with complex vec_dot (Q3_K, IQ2_*, IQ3_*) regress due to register
126+
// pressure and lookup table contention at higher thread counts.
127+
if (ncols_dst == 1) {
128+
switch (type) {
129+
case GGML_TYPE_Q4_0:
130+
case GGML_TYPE_Q4_1:
131+
case GGML_TYPE_Q5_0:
132+
case GGML_TYPE_Q5_1:
133+
case GGML_TYPE_Q8_0:
134+
case GGML_TYPE_Q2_K:
135+
case GGML_TYPE_Q4_K:
136+
case GGML_TYPE_Q5_K:
137+
case GGML_TYPE_Q6_K:
138+
case GGML_TYPE_IQ4_NL:
139+
case GGML_TYPE_IQ4_XS:
140+
return 8;
141+
default:
142+
return 1;
143+
}
144+
}
145+
return 1;
146+
}
117147
return 1;
118148
}
119149

@@ -138,7 +168,7 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int
138168
}
139169

140170
template <ggml_type type, int ncols_dst, bool has_fusion, bool is_multi_token_id = false>
141-
__launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
171+
__launch_bounds__(calc_nwarps(type, ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
142172
static __global__ void mul_mat_vec_q(
143173
const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
144174
const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
@@ -151,7 +181,7 @@ static __global__ void mul_mat_vec_q(
151181
constexpr int qi = ggml_cuda_type_traits<type>::qi;
152182
constexpr int vdr = get_vdr_mmvq(type);
153183
constexpr mmvq_parameter_table_id table_id = get_device_table_id();
154-
constexpr int nwarps = calc_nwarps(ncols_dst, table_id);
184+
constexpr int nwarps = calc_nwarps(type, ncols_dst, table_id);
155185
constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id);
156186
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
157187

@@ -355,12 +385,13 @@ static __global__ void mul_mat_vec_q(
355385
}
356386
}
357387

388+
template<ggml_type type>
358389
static std::pair<dim3, dim3> calc_launch_params(
359390
const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens,
360391
const int warp_size, const mmvq_parameter_table_id table_id) {
361392
const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_dst, table_id) - 1) / calc_rows_per_block(ncols_dst, table_id);
362393
const dim3 block_nums(nblocks, nchannels_dst, nsamples_or_ntokens);
363-
const dim3 block_dims(warp_size, calc_nwarps(ncols_dst, table_id), 1);
394+
const dim3 block_dims(warp_size, calc_nwarps(type, ncols_dst, table_id), 1);
364395
return {block_nums, block_dims};
365396
}
366397

@@ -420,7 +451,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
420451
if (has_ids && ncols_dst > 1) {
421452
// Multi-token MUL_MAT_ID path only - single-token goes through regular path below
422453
constexpr int c_ncols_dst = 1;
423-
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, ncols_dst, warp_size, table_id);
454+
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, ncols_dst, warp_size, table_id);
424455
mul_mat_vec_q_switch_fusion<type, c_ncols_dst, true>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
425456
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
426457
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
@@ -431,63 +462,63 @@ static void mul_mat_vec_q_switch_ncols_dst(
431462
switch (ncols_dst) {
432463
case 1: {
433464
constexpr int c_ncols_dst = 1;
434-
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
465+
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
435466
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
436467
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
437468
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
438469
dims.first, dims.second, 0, ids_stride, stream);
439470
} break;
440471
case 2: {
441472
constexpr int c_ncols_dst = 2;
442-
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
473+
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
443474
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
444475
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
445476
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
446477
dims.first, dims.second, 0, ids_stride, stream);
447478
} break;
448479
case 3: {
449480
constexpr int c_ncols_dst = 3;
450-
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
481+
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
451482
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
452483
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
453484
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
454485
dims.first, dims.second, 0, ids_stride, stream);
455486
} break;
456487
case 4: {
457488
constexpr int c_ncols_dst = 4;
458-
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
489+
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
459490
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
460491
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
461492
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
462493
dims.first, dims.second, 0, ids_stride, stream);
463494
} break;
464495
case 5: {
465496
constexpr int c_ncols_dst = 5;
466-
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
497+
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
467498
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
468499
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
469500
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
470501
dims.first, dims.second, 0, ids_stride, stream);
471502
} break;
472503
case 6: {
473504
constexpr int c_ncols_dst = 6;
474-
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
505+
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
475506
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
476507
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
477508
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
478509
dims.first, dims.second, 0, ids_stride, stream);
479510
} break;
480511
case 7: {
481512
constexpr int c_ncols_dst = 7;
482-
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
513+
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
483514
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
484515
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
485516
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
486517
dims.first, dims.second, 0, ids_stride, stream);
487518
} break;
488519
case 8: {
489520
constexpr int c_ncols_dst = 8;
490-
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
521+
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
491522
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
492523
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
493524
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,

0 commit comments

Comments
 (0)