diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 07b10167bc4..829c2069e01 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -288,7 +288,7 @@ static constexpr __device__ int get_mmvq_mmid_max_batch_for_device() { #endif } -static constexpr __host__ __device__ int calc_nwarps(ggml_type type, int ncols_dst, mmvq_parameter_table_id table_id) { +static constexpr __host__ __device__ int calc_max_nwarps(ggml_type type, int ncols_dst, mmvq_parameter_table_id table_id) { if (table_id == MMVQ_PARAMETERS_GENERIC) { switch (ncols_dst) { case 1: @@ -387,7 +387,7 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int } template -__launch_bounds__(calc_nwarps(type, ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1) +__launch_bounds__(calc_max_nwarps(type, ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1) static __global__ void mul_mat_vec_q( const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst, const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y, @@ -400,16 +400,24 @@ static __global__ void mul_mat_vec_q( constexpr int qi = ggml_cuda_type_traits::qi; constexpr int vdr = get_vdr_mmvq(type); constexpr mmvq_parameter_table_id table_id = get_device_table_id(); - constexpr int nwarps = calc_nwarps(type, ncols_dst, table_id); - constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id, small_k, nwarps); + constexpr int max_nwarps = calc_max_nwarps(type, ncols_dst, table_id); + constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id, small_k, max_nwarps); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + // RDNA3/RDNA4: actual nwarps is set by the host based on ncols_x; may be < max_nwarps for narrow matrices (e.g. MoE experts). + // Other architectures: nwarps == max_nwarps (constexpr, preserves #pragma unroll). +#if defined(RDNA3_0) || defined(RDNA4) + const int nwarps = max_nwarps > 1 ? static_cast(blockDim.y) : 1; +#else + constexpr int nwarps = max_nwarps; +#endif + constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type); const int tid = warp_size*threadIdx.y + threadIdx.x; const int row0 = rows_per_cuda_block*blockIdx.x; const int blocks_per_row_x = ncols_x / qk; - constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi; + const int blocks_per_iter = vdr * nwarps*warp_size / qi; const uint32_t channel_dst = blockIdx.y; @@ -500,8 +508,8 @@ static __global__ void mul_mat_vec_q( } } - __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size]; - __shared__ float tmp_shared_gate[(has_fusion && (nwarps-1 > 0)) ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size]; + __shared__ float tmp_shared[max_nwarps-1 > 0 ? max_nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size]; + __shared__ float tmp_shared_gate[(has_fusion && (max_nwarps-1 > 0)) ? max_nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size]; if constexpr (!has_fusion) { (void) tmp_shared_gate; } else if (!use_gate) { @@ -653,9 +661,26 @@ static __global__ void mul_mat_vec_q_moe( template static std::pair calc_launch_params( - const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens, + const int ncols_dst, const int nrows_x, const int ncols_x, const int nchannels_dst, const int nsamples_or_ntokens, const int warp_size, const mmvq_parameter_table_id table_id, const bool small_k = false) { - const int nwarps = calc_nwarps(type, ncols_dst, table_id); + int nwarps = calc_max_nwarps(type, ncols_dst, table_id); + + // Dynamically reduce nwarps when the matrix is too narrow for full utilization. + // Only applied for RDNA3/RDNA4 where MoE expert narrow matrices cause significant regression. + if (nwarps > 1 && (table_id == MMVQ_PARAMETERS_RDNA3_0 || table_id == MMVQ_PARAMETERS_RDNA4)) { + constexpr int qk = ggml_cuda_type_traits::qk; + constexpr int qi = ggml_cuda_type_traits::qi; + constexpr int vdr = get_vdr_mmvq(type); + const int blocks_per_row = ncols_x / qk; + const int max_useful_nwarps = (blocks_per_row * qi) / (vdr * warp_size); + if (max_useful_nwarps < nwarps) { + nwarps = 1; + for (int w = 2; w <= max_useful_nwarps; w *= 2) { + nwarps = w; + } + } + } + const int rpb = calc_rows_per_block(ncols_dst, table_id, small_k, nwarps); const int64_t nblocks = (nrows_x + rpb - 1) / rpb; const dim3 block_nums(nblocks, nchannels_dst, nsamples_or_ntokens); @@ -746,7 +771,7 @@ static void mul_mat_vec_q_switch_ncols_dst( constexpr int vdr = get_vdr_mmvq(type); const int blocks_per_row_x = ncols_x / qk; const int blocks_per_iter_1warp = vdr * warp_size / qi; - const int nwarps = calc_nwarps(type, c_ncols_dst, table_id); + const int nwarps = calc_max_nwarps(type, c_ncols_dst, table_id); bool use = nwarps > 1 && blocks_per_row_x < nwarps * blocks_per_iter_1warp; constexpr std::array iq_slow_turing = { @@ -797,7 +822,7 @@ static void mul_mat_vec_q_switch_ncols_dst( bool use_small_k = should_use_small_k(c_ncols_dst); if (use_small_k) { - std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id, true); mul_mat_vec_q_switch_fusion( vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, @@ -805,8 +830,8 @@ static void mul_mat_vec_q_switch_ncols_dst( stride_sample_x, stride_sample_y, stride_sample_dst, dims.first, dims.second, 0, ids_stride, stream); } else { - std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, - nsamples_dst, warp_size, table_id); + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, + warp_size, table_id); mul_mat_vec_q_switch_fusion( vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, @@ -816,7 +841,7 @@ static void mul_mat_vec_q_switch_ncols_dst( } break; case 2: { constexpr int c_ncols_dst = 2; - std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id); mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, @@ -824,7 +849,7 @@ static void mul_mat_vec_q_switch_ncols_dst( } break; case 3: { constexpr int c_ncols_dst = 3; - std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id); mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, @@ -832,7 +857,7 @@ static void mul_mat_vec_q_switch_ncols_dst( } break; case 4: { constexpr int c_ncols_dst = 4; - std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id); mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, @@ -840,7 +865,7 @@ static void mul_mat_vec_q_switch_ncols_dst( } break; case 5: { constexpr int c_ncols_dst = 5; - std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id); mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, @@ -848,7 +873,7 @@ static void mul_mat_vec_q_switch_ncols_dst( } break; case 6: { constexpr int c_ncols_dst = 6; - std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id); mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, @@ -856,7 +881,7 @@ static void mul_mat_vec_q_switch_ncols_dst( } break; case 7: { constexpr int c_ncols_dst = 7; - std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id); mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, @@ -864,7 +889,7 @@ static void mul_mat_vec_q_switch_ncols_dst( } break; case 8: { constexpr int c_ncols_dst = 8; - std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id); mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,