Skip to content

Commit a3a886c

Browse files
committed
CUDA: add fp kernel for larger batch size MoE -> PR 16512
* CUDA: kernel for larger batch sizes for MoE * WIP * WIP * WIP * WIP * WIP * WIP * fixup * tests * Move mmq_ids_helper to mmid * cleanup * Remove redundant checks Author: Aman Gupta
1 parent 8596bde commit a3a886c

File tree

5 files changed

+525
-203
lines changed

5 files changed

+525
-203
lines changed

ggml/src/ggml-cuda/mmf.cu

Lines changed: 40 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,7 @@
11
#include "ggml.h"
22
#include "mmf.cuh"
3+
#include "mmid.cuh"
4+
35

46
void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
57
GGML_ASSERT( src1->type == GGML_TYPE_F32);
@@ -37,6 +39,12 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
3739
const int64_t ids_s0 = ids ? ids->nb[0] / ggml_type_size(ids->type) : 0;
3840
const int64_t ids_s1 = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0;
3941

42+
mmf_ids_data ids_info{};
43+
mmf_ids_data * ids_info_ptr = nullptr;
44+
ggml_cuda_pool_alloc<int32_t> ids_src_compact_dev;
45+
ggml_cuda_pool_alloc<int32_t> ids_dst_compact_dev;
46+
ggml_cuda_pool_alloc<int32_t> expert_bounds_dev;
47+
4048
// For MUL_MAT_ID the memory layout is different than for MUL_MAT:
4149
const int64_t ncols_dst = ids ? ne2 : ne1;
4250
const int64_t nchannels_dst = ids ? ne1 : ne2;
@@ -54,30 +62,57 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
5462
nchannels_y = ids->ne[0];
5563
}
5664

65+
if (ids && ncols_dst > 16) {
66+
const int64_t n_expert_used = ids->ne[0];
67+
const int64_t n_experts = ne02;
68+
const int64_t n_tokens = ne12;
69+
const int64_t ne_get_rows = n_tokens * n_expert_used;
70+
71+
ids_src_compact_dev.alloc(ctx.pool(), ne_get_rows);
72+
ids_dst_compact_dev.alloc(ctx.pool(), ne_get_rows);
73+
expert_bounds_dev.alloc(ctx.pool(), n_experts + 1);
74+
75+
const int si1 = static_cast<int>(ids_s1);
76+
const int sis1 = static_cast<int>(src1->nb[2] / src1->nb[1]);
77+
78+
GGML_ASSERT(sis1 > 0);
79+
80+
ggml_cuda_launch_mm_ids_helper(ids_d, ids_src_compact_dev.get(), ids_dst_compact_dev.get(), expert_bounds_dev.get(),
81+
static_cast<int>(n_experts), static_cast<int>(n_tokens), static_cast<int>(n_expert_used), static_cast<int>(ne11), si1, sis1, ctx.stream());
82+
CUDA_CHECK(cudaGetLastError());
83+
84+
ids_info.ids_src_compact = ids_src_compact_dev.get();
85+
ids_info.ids_dst_compact = ids_dst_compact_dev.get();
86+
ids_info.expert_bounds_dev = expert_bounds_dev.get();
87+
ids_info.n_experts = static_cast<int>(n_experts);
88+
ids_info.sis1 = sis1;
89+
ids_info_ptr = &ids_info;
90+
}
91+
5792
switch (src0->type) {
5893
case GGML_TYPE_F32: {
5994
const float * src0_d = (const float *) src0->data;
6095
constexpr int vals_per_T = 1;
6196
mul_mat_f_switch_cols_per_block(
6297
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
6398
ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
64-
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
99+
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
65100
} break;
66101
case GGML_TYPE_F16: {
67102
const half2 * src0_d = (const half2 *) src0->data;
68103
constexpr int vals_per_T = 2;
69104
mul_mat_f_switch_cols_per_block(
70105
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
71106
ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
72-
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
107+
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
73108
} break;
74109
case GGML_TYPE_BF16: {
75110
const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data;
76111
constexpr int vals_per_T = 2;
77112
mul_mat_f_switch_cols_per_block(
78113
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
79114
ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
80-
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
115+
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
81116
} break;
82117
default:
83118
GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
@@ -98,10 +133,9 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
98133
}
99134

100135
if (mul_mat_id) {
101-
if (type == GGML_TYPE_F32 && src1_ncols > 32) {
136+
if (src0_ne[1] <= 1024 && src1_ncols > 512) {
102137
return false;
103-
}
104-
if ((type == GGML_TYPE_F16 || type == GGML_TYPE_BF16) && src1_ncols > 64) {
138+
} else if(src0_ne[1] > 1024 && src1_ncols > 128) {
105139
return false;
106140
}
107141
} else {

0 commit comments

Comments (0)