Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 45 additions & 20 deletions ggml/src/ggml-cuda/mmvq.cu
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ static constexpr __device__ int get_mmvq_mmid_max_batch_for_device() {
#endif
}

static constexpr __host__ __device__ int calc_nwarps(ggml_type type, int ncols_dst, mmvq_parameter_table_id table_id) {
static constexpr __host__ __device__ int calc_max_nwarps(ggml_type type, int ncols_dst, mmvq_parameter_table_id table_id) {
if (table_id == MMVQ_PARAMETERS_GENERIC) {
switch (ncols_dst) {
case 1:
Expand Down Expand Up @@ -387,7 +387,7 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int
}

template <ggml_type type, int ncols_dst, bool has_fusion, bool small_k = false>
__launch_bounds__(calc_nwarps(type, ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
__launch_bounds__(calc_max_nwarps(type, ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
static __global__ void mul_mat_vec_q(
const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
Expand All @@ -400,16 +400,24 @@ static __global__ void mul_mat_vec_q(
constexpr int qi = ggml_cuda_type_traits<type>::qi;
constexpr int vdr = get_vdr_mmvq(type);
constexpr mmvq_parameter_table_id table_id = get_device_table_id();
constexpr int nwarps = calc_nwarps(type, ncols_dst, table_id);
constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id, small_k, nwarps);
constexpr int max_nwarps = calc_max_nwarps(type, ncols_dst, table_id);
constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id, small_k, max_nwarps);
constexpr int warp_size = ggml_cuda_get_physical_warp_size();

// RDNA3/RDNA4: actual nwarps is set by the host based on ncols_x; may be < max_nwarps for narrow matrices (e.g. MoE experts).
// Other architectures: nwarps == max_nwarps (constexpr, preserves #pragma unroll).
#if defined(RDNA3_0) || defined(RDNA4)
const int nwarps = max_nwarps > 1 ? static_cast<int>(blockDim.y) : 1;
#else
constexpr int nwarps = max_nwarps;
#endif

constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);

const int tid = warp_size*threadIdx.y + threadIdx.x;
const int row0 = rows_per_cuda_block*blockIdx.x;
const int blocks_per_row_x = ncols_x / qk;
constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
const int blocks_per_iter = vdr * nwarps*warp_size / qi;

const uint32_t channel_dst = blockIdx.y;

Expand Down Expand Up @@ -500,8 +508,8 @@ static __global__ void mul_mat_vec_q(
}
}

__shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size];
__shared__ float tmp_shared_gate[(has_fusion && (nwarps-1 > 0)) ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size];
__shared__ float tmp_shared[max_nwarps-1 > 0 ? max_nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size];
__shared__ float tmp_shared_gate[(has_fusion && (max_nwarps-1 > 0)) ? max_nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size];
if constexpr (!has_fusion) {
(void) tmp_shared_gate;
} else if (!use_gate) {
Expand Down Expand Up @@ -653,9 +661,26 @@ static __global__ void mul_mat_vec_q_moe(

template<ggml_type type>
static std::pair<dim3, dim3> calc_launch_params(
const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens,
const int ncols_dst, const int nrows_x, const int ncols_x, const int nchannels_dst, const int nsamples_or_ntokens,
const int warp_size, const mmvq_parameter_table_id table_id, const bool small_k = false) {
const int nwarps = calc_nwarps(type, ncols_dst, table_id);
int nwarps = calc_max_nwarps(type, ncols_dst, table_id);

// Dynamically reduce nwarps when the matrix is too narrow for full utilization.
// Only applied for RDNA3/RDNA4 where MoE expert narrow matrices cause significant regression.
if (nwarps > 1 && (table_id == MMVQ_PARAMETERS_RDNA3_0 || table_id == MMVQ_PARAMETERS_RDNA4)) {
constexpr int qk = ggml_cuda_type_traits<type>::qk;
constexpr int qi = ggml_cuda_type_traits<type>::qi;
constexpr int vdr = get_vdr_mmvq(type);
const int blocks_per_row = ncols_x / qk;
const int max_useful_nwarps = (blocks_per_row * qi) / (vdr * warp_size);
if (max_useful_nwarps < nwarps) {
nwarps = 1;
for (int w = 2; w <= max_useful_nwarps; w *= 2) {
nwarps = w;
}
}
}

const int rpb = calc_rows_per_block(ncols_dst, table_id, small_k, nwarps);
const int64_t nblocks = (nrows_x + rpb - 1) / rpb;
const dim3 block_nums(nblocks, nchannels_dst, nsamples_or_ntokens);
Expand Down Expand Up @@ -746,7 +771,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
constexpr int vdr = get_vdr_mmvq(type);
const int blocks_per_row_x = ncols_x / qk;
const int blocks_per_iter_1warp = vdr * warp_size / qi;
const int nwarps = calc_nwarps(type, c_ncols_dst, table_id);
const int nwarps = calc_max_nwarps(type, c_ncols_dst, table_id);
bool use = nwarps > 1 && blocks_per_row_x < nwarps * blocks_per_iter_1warp;

constexpr std::array<ggml_type, 2> iq_slow_turing = {
Expand Down Expand Up @@ -797,16 +822,16 @@ static void mul_mat_vec_q_switch_ncols_dst(
bool use_small_k = should_use_small_k(c_ncols_dst);

if (use_small_k) {
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst,
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst,
nsamples_dst, warp_size, table_id, true);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst, true>(
vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd,
stride_sample_x, stride_sample_y, stride_sample_dst, dims.first, dims.second, 0, ids_stride,
stream);
} else {
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst,
nsamples_dst, warp_size, table_id);
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst,
warp_size, table_id);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(
vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd,
Expand All @@ -816,55 +841,55 @@ static void mul_mat_vec_q_switch_ncols_dst(
} break;
case 2: {
constexpr int c_ncols_dst = 2;
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, ids_stride, stream);
} break;
case 3: {
constexpr int c_ncols_dst = 3;
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, ids_stride, stream);
} break;
case 4: {
constexpr int c_ncols_dst = 4;
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, ids_stride, stream);
} break;
case 5: {
constexpr int c_ncols_dst = 5;
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, ids_stride, stream);
} break;
case 6: {
constexpr int c_ncols_dst = 6;
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, ids_stride, stream);
} break;
case 7: {
constexpr int c_ncols_dst = 7;
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, ids_stride, stream);
} break;
case 8: {
constexpr int c_ncols_dst = 8;
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, ncols_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
Expand Down