@@ -60,11 +60,14 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
// Identifier of the MMVQ parameter table to use for a device family. The value is
// passed as `table_id` to calc_nwarps() / calc_rows_per_block() to pick launch parameters.
// Diff note: the '-'/'+' lines below append MMVQ_PARAMETERS_RDNA4 after
// MMVQ_PARAMETERS_RDNA2 (the old last enumerator only gains a trailing comma).
6060enum mmvq_parameter_table_id {
6161 MMVQ_PARAMETERS_GENERIC = 0 ,
6262 MMVQ_PARAMETERS_GCN,
63- MMVQ_PARAMETERS_RDNA2
63+ MMVQ_PARAMETERS_RDNA2,
64+ MMVQ_PARAMETERS_RDNA4
6465};
6566
// Device-side, compile-time selection of the parameter table from architecture macros.
// After this change RDNA4 is tested first and gets its own table; RDNA2 and RDNA3 keep
// sharing MMVQ_PARAMETERS_RDNA2, and GCN/CDNA map to MMVQ_PARAMETERS_GCN.
// The remaining preprocessor branch(es) are elided by the hunk boundary below.
6667static constexpr __device__ mmvq_parameter_table_id get_device_table_id () {
67- #if defined(RDNA2) || defined(RDNA3) || defined(RDNA4)
// RDNA4 must be checked before the RDNA2/RDNA3 branch now that it is no longer part of it.
68+ #if defined(RDNA4)
69+ return MMVQ_PARAMETERS_RDNA4;
70+ #elif defined(RDNA2) || defined(RDNA3)
6871 return MMVQ_PARAMETERS_RDNA2;
6972#elif defined(GCN) || defined(CDNA)
7073 return MMVQ_PARAMETERS_GCN;
@@ -74,7 +77,10 @@ static constexpr __device__ mmvq_parameter_table_id get_device_table_id() {
7477}
7578
// Host-side overload: selects the parameter table from the value `cc` using the
// GGML_CUDA_CC_IS_* predicates. The branch order matches the macro-based device overload:
// RDNA4 first (own table), then RDNA2/RDNA3, then GCN/CDNA.
// Returns MMVQ_PARAMETERS_GENERIC when no predicate matches.
// The body of the GCN/CDNA branch is elided by the hunk boundary below.
7679static __host__ mmvq_parameter_table_id get_device_table_id (int cc) {
77- if (GGML_CUDA_CC_IS_RDNA2 (cc) || GGML_CUDA_CC_IS_RDNA3 (cc) || GGML_CUDA_CC_IS_RDNA4 (cc)) {
// RDNA4 is split out of the shared RDNA2/RDNA3 condition so it can return its own table id.
80+ if (GGML_CUDA_CC_IS_RDNA4 (cc)) {
81+ return MMVQ_PARAMETERS_RDNA4;
82+ }
83+ if (GGML_CUDA_CC_IS_RDNA2 (cc) || GGML_CUDA_CC_IS_RDNA3 (cc)) {
7884 return MMVQ_PARAMETERS_RDNA2;
7985 }
8086 if (GGML_CUDA_CC_IS_GCN (cc) || GGML_CUDA_CC_IS_CDNA (cc)) {
@@ -83,7 +89,7 @@ static __host__ mmvq_parameter_table_id get_device_table_id(int cc) {
8389 return MMVQ_PARAMETERS_GENERIC;
8490}
8591
// Returns the nwarps value for a given (type, ncols_dst, table_id) combination.
// The result is used for the kernel's __launch_bounds__ / `nwarps` constant and for the
// y dimension of the host launch block, so all call sites must pass the same arguments.
// Diff note: the new signature adds a leading `ggml_type type` parameter. In the lines
// visible here, `type` is only consulted in the MMVQ_PARAMETERS_RDNA4 branch; the
// bodies of the other table branches are partly elided by the hunk boundary below.
// Falls back to 1 when no table-specific rule applies.
86- static constexpr __host__ __device__ int calc_nwarps (int ncols_dst, mmvq_parameter_table_id table_id) {
92+ static constexpr __host__ __device__ int calc_nwarps (ggml_type type, int ncols_dst, mmvq_parameter_table_id table_id) {
8793 if (table_id == MMVQ_PARAMETERS_GENERIC) {
8894 switch (ncols_dst) {
8995 case 1 :
@@ -114,6 +120,30 @@ static constexpr __host__ __device__ int calc_nwarps(int ncols_dst, mmvq_paramet
114120 return 1 ;
115121 }
116122 }
// RDNA4 table: 8 warps only for ncols_dst == 1 and the quantization types listed below;
// every other type or ncols_dst value yields 1.
123+ if (table_id == MMVQ_PARAMETERS_RDNA4) {
124+ // nwarps=8 benefits types with simple vec_dot on RDNA4 (ncols_dst=1).
125+ // Types with complex vec_dot (Q3_K, IQ2_*, IQ3_*) regress due to register
126+ // pressure and lookup table contention at higher thread counts.
127+ if (ncols_dst == 1 ) {
128+ switch (type) {
129+ case GGML_TYPE_Q4_0:
130+ case GGML_TYPE_Q4_1:
131+ case GGML_TYPE_Q5_0:
132+ case GGML_TYPE_Q5_1:
133+ case GGML_TYPE_Q8_0:
134+ case GGML_TYPE_Q2_K:
135+ case GGML_TYPE_Q4_K:
136+ case GGML_TYPE_Q5_K:
137+ case GGML_TYPE_Q6_K:
138+ case GGML_TYPE_IQ4_NL:
139+ case GGML_TYPE_IQ4_XS:
140+ return 8 ;
141+ default :
142+ return 1 ;
143+ }
144+ }
145+ return 1 ;
146+ }
117147 return 1 ;
118148}
119149
@@ -138,7 +168,7 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int
138168}
139169
// Kernel template parameterized on the quantization type, ncols_dst, fusion and the
// multi-token-id flag. Its __launch_bounds__ thread limit is
// calc_nwarps(...) * ggml_cuda_get_physical_warp_size(), and the body recomputes `nwarps`
// from the same calc_nwarps() call, so both must receive identical arguments.
// Diff note: both calc_nwarps() calls now pass `type` as the first argument.
// Most of the parameter list and body is elided by the hunk boundaries below.
140170template <ggml_type type, int ncols_dst, bool has_fusion, bool is_multi_token_id = false >
141- __launch_bounds__ (calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
171+ __launch_bounds__ (calc_nwarps(type, ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
142172static __global__ void mul_mat_vec_q(
143173 const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
144174 const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
@@ -151,7 +181,7 @@ static __global__ void mul_mat_vec_q(
151181 constexpr int qi = ggml_cuda_type_traits<type>::qi;
152182 constexpr int vdr = get_vdr_mmvq (type);
153183 constexpr mmvq_parameter_table_id table_id = get_device_table_id ();
154- constexpr int nwarps = calc_nwarps (ncols_dst, table_id);
184+ constexpr int nwarps = calc_nwarps (type, ncols_dst, table_id);
155185 constexpr int rows_per_cuda_block = calc_rows_per_block (ncols_dst, table_id);
156186 constexpr int warp_size = ggml_cuda_get_physical_warp_size ();
157187
@@ -355,12 +385,13 @@ static __global__ void mul_mat_vec_q(
355385 }
356386}
357387
// Builds the host-side launch configuration and returns it as {block_nums, block_dims}.
//   block_nums = (ceil(nrows_x / rows_per_block), nchannels_dst, nsamples_or_ntokens)
//   block_dims = (warp_size, nwarps, 1)
// rows_per_block comes from calc_rows_per_block(ncols_dst, table_id) and nwarps from
// calc_nwarps().
// Diff note: the function becomes a template on `type` so that calc_nwarps() can be
// called with the quantization type; callers must now write calc_launch_params<type>(...).
388+ template <ggml_type type>
358389static std::pair<dim3 , dim3 > calc_launch_params (
359390 const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens,
360391 const int warp_size, const mmvq_parameter_table_id table_id) {
// Ceiling division: enough blocks to cover all nrows_x rows.
361392 const int64_t nblocks = (nrows_x + calc_rows_per_block (ncols_dst, table_id) - 1 ) / calc_rows_per_block (ncols_dst, table_id);
362393 const dim3 block_nums (nblocks, nchannels_dst, nsamples_or_ntokens);
363- const dim3 block_dims (warp_size, calc_nwarps (ncols_dst, table_id), 1 );
394+ const dim3 block_dims (warp_size, calc_nwarps (type, ncols_dst, table_id), 1 );
364395 return {block_nums, block_dims};
365396}
366397
@@ -420,7 +451,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
420451 if (has_ids && ncols_dst > 1 ) {
421452 // Multi-token MUL_MAT_ID path only - single-token goes through regular path below
422453 constexpr int c_ncols_dst = 1 ;
423- std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_dst, nrows_x, nchannels_dst, ncols_dst, warp_size, table_id);
454+ std::pair<dim3 , dim3 > dims = calc_launch_params<type> (c_ncols_dst, nrows_x, nchannels_dst, ncols_dst, warp_size, table_id);
424455 mul_mat_vec_q_switch_fusion<type, c_ncols_dst, true >(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
425456 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
426457 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
@@ -431,63 +462,63 @@ static void mul_mat_vec_q_switch_ncols_dst(
431462 switch (ncols_dst) {
432463 case 1 : {
433464 constexpr int c_ncols_dst = 1 ;
434- std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
465+ std::pair<dim3 , dim3 > dims = calc_launch_params<type> (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
435466 mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
436467 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
437468 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
438469 dims.first , dims.second , 0 , ids_stride, stream);
439470 } break ;
440471 case 2 : {
441472 constexpr int c_ncols_dst = 2 ;
442- std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
473+ std::pair<dim3 , dim3 > dims = calc_launch_params<type> (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
443474 mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
444475 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
445476 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
446477 dims.first , dims.second , 0 , ids_stride, stream);
447478 } break ;
448479 case 3 : {
449480 constexpr int c_ncols_dst = 3 ;
450- std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
481+ std::pair<dim3 , dim3 > dims = calc_launch_params<type> (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
451482 mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
452483 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
453484 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
454485 dims.first , dims.second , 0 , ids_stride, stream);
455486 } break ;
456487 case 4 : {
457488 constexpr int c_ncols_dst = 4 ;
458- std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
489+ std::pair<dim3 , dim3 > dims = calc_launch_params<type> (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
459490 mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
460491 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
461492 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
462493 dims.first , dims.second , 0 , ids_stride, stream);
463494 } break ;
464495 case 5 : {
465496 constexpr int c_ncols_dst = 5 ;
466- std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
497+ std::pair<dim3 , dim3 > dims = calc_launch_params<type> (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
467498 mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
468499 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
469500 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
470501 dims.first , dims.second , 0 , ids_stride, stream);
471502 } break ;
472503 case 6 : {
473504 constexpr int c_ncols_dst = 6 ;
474- std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
505+ std::pair<dim3 , dim3 > dims = calc_launch_params<type> (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
475506 mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
476507 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
477508 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
478509 dims.first , dims.second , 0 , ids_stride, stream);
479510 } break ;
480511 case 7 : {
481512 constexpr int c_ncols_dst = 7 ;
482- std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
513+ std::pair<dim3 , dim3 > dims = calc_launch_params<type> (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
483514 mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
484515 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
485516 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
486517 dims.first , dims.second , 0 , ids_stride, stream);
487518 } break ;
488519 case 8 : {
489520 constexpr int c_ncols_dst = 8 ;
490- std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
521+ std::pair<dim3 , dim3 > dims = calc_launch_params<type> (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
491522 mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
492523 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
493524 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
0 commit comments