Skip to content

Commit 8cd2a84

Browse files
committed
mmvq: dynamic nwarps based on matrix width for MoE models
1 parent 84f82e8 commit 8cd2a84

File tree

1 file changed

+45
-20
lines changed

1 file changed

+45
-20
lines changed

ggml/src/ggml-cuda/mmvq.cu

Lines changed: 45 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ static constexpr __device__ int get_mmvq_mmid_max_batch_for_device() {
288288
#endif
289289
}
290290

291-
static constexpr __host__ __device__ int calc_nwarps(ggml_type type, int ncols_dst, mmvq_parameter_table_id table_id) {
291+
static constexpr __host__ __device__ int calc_max_nwarps(ggml_type type, int ncols_dst, mmvq_parameter_table_id table_id) {
292292
if (table_id == MMVQ_PARAMETERS_GENERIC) {
293293
switch (ncols_dst) {
294294
case 1:
@@ -387,7 +387,7 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int
387387
}
388388

389389
template <ggml_type type, int ncols_dst, bool has_fusion, bool small_k = false>
390-
__launch_bounds__(calc_nwarps(type, ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
390+
__launch_bounds__(calc_max_nwarps(type, ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
391391
static __global__ void mul_mat_vec_q(
392392
const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
393393
const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
@@ -400,16 +400,24 @@ static __global__ void mul_mat_vec_q(
400400
constexpr int qi = ggml_cuda_type_traits<type>::qi;
401401
constexpr int vdr = get_vdr_mmvq(type);
402402
constexpr mmvq_parameter_table_id table_id = get_device_table_id();
403-
constexpr int nwarps = calc_nwarps(type, ncols_dst, table_id);
404-
constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id, small_k, nwarps);
403+
constexpr int max_nwarps = calc_max_nwarps(type, ncols_dst, table_id);
404+
constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id, small_k, max_nwarps);
405405
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
406406

407+
// RDNA3/RDNA4: actual nwarps is set by the host based on ncols_x; may be < max_nwarps for narrow matrices (e.g. MoE experts).
408+
// Other architectures: nwarps == max_nwarps (constexpr, preserves #pragma unroll).
409+
#if defined(RDNA3_0) || defined(RDNA4)
410+
const int nwarps = max_nwarps > 1 ? static_cast<int>(blockDim.y) : 1;
411+
#else
412+
constexpr int nwarps = max_nwarps;
413+
#endif
414+
407415
constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
408416

409417
const int tid = warp_size*threadIdx.y + threadIdx.x;
410418
const int row0 = rows_per_cuda_block*blockIdx.x;
411419
const int blocks_per_row_x = ncols_x / qk;
412-
constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
420+
const int blocks_per_iter = vdr * nwarps*warp_size / qi;
413421

414422
const uint32_t channel_dst = blockIdx.y;
415423

@@ -500,8 +508,8 @@ static __global__ void mul_mat_vec_q(
500508
}
501509
}
502510

503-
__shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size];
504-
__shared__ float tmp_shared_gate[(has_fusion && (nwarps-1 > 0)) ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size];
511+
__shared__ float tmp_shared[max_nwarps-1 > 0 ? max_nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size];
512+
__shared__ float tmp_shared_gate[(has_fusion && (max_nwarps-1 > 0)) ? max_nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size];
505513
if constexpr (!has_fusion) {
506514
(void) tmp_shared_gate;
507515
} else if (!use_gate) {
@@ -653,9 +661,26 @@ static __global__ void mul_mat_vec_q_moe(
653661

654662
template<ggml_type type>
655663
static std::pair<dim3, dim3> calc_launch_params(
656-
const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens,
664+
const int ncols_dst, const int nrows_x, const int ncols_x, const int nchannels_dst, const int nsamples_or_ntokens,
657665
const int warp_size, const mmvq_parameter_table_id table_id, const bool small_k = false) {
658-
const int nwarps = calc_nwarps(type, ncols_dst, table_id);
666+
int nwarps = calc_max_nwarps(type, ncols_dst, table_id);
667+
668+
// Dynamically reduce nwarps when the matrix is too narrow for full utilization.
669+
// Only applied for RDNA3/RDNA4 where MoE expert narrow matrices cause significant regression.
670+
if (nwarps > 1 && (table_id == MMVQ_PARAMETERS_RDNA3_0 || table_id == MMVQ_PARAMETERS_RDNA4)) {
671+
constexpr int qk = ggml_cuda_type_traits<type>::qk;
672+
constexpr int qi = ggml_cuda_type_traits<type>::qi;
673+
constexpr int vdr = get_vdr_mmvq(type);
674+
const int blocks_per_row = ncols_x / qk;
675+
const int max_useful_nwarps = (blocks_per_row * qi) / (vdr * warp_size);
676+
if (max_useful_nwarps < nwarps) {
677+
nwarps = 1;
678+
for (int w = 2; w <= max_useful_nwarps; w *= 2) {
679+
nwarps = w;
680+
}
681+
}
682+
}
683+
659684
const int rpb = calc_rows_per_block(ncols_dst, table_id, small_k, nwarps);
660685
const int64_t nblocks = (nrows_x + rpb - 1) / rpb;
661686
const dim3 block_nums(nblocks, nchannels_dst, nsamples_or_ntokens);
@@ -746,7 +771,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
746771
constexpr int vdr = get_vdr_mmvq(type);
747772
const int blocks_per_row_x = ncols_x / qk;
748773
const int blocks_per_iter_1warp = vdr * warp_size / qi;
749-
const int nwarps = calc_nwarps(type, c_ncols_dst, table_id);
774+
const int nwarps = calc_max_nwarps(type, c_ncols_dst, table_id);
750775
bool use = nwarps > 1 && blocks_per_row_x < nwarps * blocks_per_iter_1warp;
751776

752777
constexpr std::array<ggml_type, 2> iq_slow_turing = {
@@ -797,16 +822,16 @@ static void mul_mat_vec_q_switch_ncols_dst(
797822
bool use_small_k = should_use_small_k(c_ncols_dst);
798823

799824
if (use_small_k) {
800-
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst,
825+
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst,
801826
nsamples_dst, warp_size, table_id, true);
802827
mul_mat_vec_q_switch_fusion<type, c_ncols_dst, true>(
803828
vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
804829
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd,
805830
stride_sample_x, stride_sample_y, stride_sample_dst, dims.first, dims.second, 0, ids_stride,
806831
stream);
807832
} else {
808-
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst,
809-
nsamples_dst, warp_size, table_id);
833+
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst,
834+
warp_size, table_id);
810835
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(
811836
vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
812837
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd,
@@ -816,55 +841,55 @@ static void mul_mat_vec_q_switch_ncols_dst(
816841
} break;
817842
case 2: {
818843
constexpr int c_ncols_dst = 2;
819-
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
844+
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id);
820845
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
821846
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
822847
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
823848
dims.first, dims.second, 0, ids_stride, stream);
824849
} break;
825850
case 3: {
826851
constexpr int c_ncols_dst = 3;
827-
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
852+
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id);
828853
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
829854
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
830855
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
831856
dims.first, dims.second, 0, ids_stride, stream);
832857
} break;
833858
case 4: {
834859
constexpr int c_ncols_dst = 4;
835-
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
860+
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id);
836861
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
837862
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
838863
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
839864
dims.first, dims.second, 0, ids_stride, stream);
840865
} break;
841866
case 5: {
842867
constexpr int c_ncols_dst = 5;
843-
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
868+
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id);
844869
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
845870
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
846871
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
847872
dims.first, dims.second, 0, ids_stride, stream);
848873
} break;
849874
case 6: {
850875
constexpr int c_ncols_dst = 6;
851-
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
876+
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id);
852877
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
853878
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
854879
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
855880
dims.first, dims.second, 0, ids_stride, stream);
856881
} break;
857882
case 7: {
858883
constexpr int c_ncols_dst = 7;
859-
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
884+
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id);
860885
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
861886
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
862887
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
863888
dims.first, dims.second, 0, ids_stride, stream);
864889
} break;
865890
case 8: {
866891
constexpr int c_ncols_dst = 8;
867-
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
892+
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id);
868893
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
869894
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
870895
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,

0 commit comments

Comments
 (0)