Skip to content

Commit 2e8e339

Browse files
committed
CUDA: add expert reduce kernel
1 parent e3af556 commit 2e8e339

File tree

4 files changed

+268
-0
lines changed

4 files changed

+268
-0
lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "ggml-cuda/mmq.cuh"
2828
#include "ggml-cuda/mmvf.cuh"
2929
#include "ggml-cuda/mmvq.cuh"
30+
#include "ggml-cuda/moe-expert-reduce.cuh"
3031
#include "ggml-cuda/norm.cuh"
3132
#include "ggml-cuda/opt-step-adamw.cuh"
3233
#include "ggml-cuda/opt-step-sgd.cuh"
@@ -3169,6 +3170,31 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
31693170
continue;
31703171
}
31713172

3173+
if (node->op == GGML_OP_MUL) {
3174+
int current_node = i + 1;
3175+
int num_views = 0;
3176+
int num_adds = 0;
3177+
while (current_node < cgraph->n_nodes && cgraph->nodes[current_node]->op == GGML_OP_VIEW) {
3178+
num_views++;
3179+
current_node++;
3180+
}
3181+
3182+
while (current_node < cgraph->n_nodes && cgraph->nodes[current_node]->op == GGML_OP_ADD &&
3183+
num_adds < num_views - 1) {
3184+
num_adds++;
3185+
current_node++;
3186+
}
3187+
3188+
if (num_adds == num_views - 1 && num_views > 0) {
3189+
ggml_tensor * dst_node = cgraph->nodes[current_node - 1];
3190+
if (ggml_cuda_should_use_moe_expert_reduce(cgraph, i, current_node)) {
3191+
ggml_cuda_op_moe_expert_reduce(*cuda_ctx, node->src[0], node->src[1], dst_node);
3192+
i += num_views + num_adds;
3193+
continue;
3194+
}
3195+
}
3196+
}
3197+
31723198
if (node->op == GGML_OP_ADD) {
31733199
int n_fuse = 0;
31743200
ggml_op ops[8];
Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
#include "moe-expert-reduce.cuh"
2+
3+
// This kernel is fusion of the expert weight reduce, common in MoE models
4+
5+
// Fused MoE expert-weight reduction: for every row,
//   dst[col] = sum over experts e of weights[e] * experts[e][col]
// Grid layout: blockIdx.x = row, blockIdx.y tiles the columns; one thread
// per output element. n_expert_used_template == 0 selects the runtime
// expert-count path; any non-zero value fixes the count at compile time so
// the reduction loop can be fully unrolled.
template <int n_expert_used_template>
__global__ void moe_expert_reduce_cuda(const float * __restrict__ experts,
                                       const float * __restrict__ weights,
                                       float * __restrict__ dst,
                                       const int n_expert_used,
                                       const int n_cols) {
    const int row = blockIdx.x;
    const int col = blockIdx.y * blockDim.x + threadIdx.x;
    if (col >= n_cols) {
        return;
    }

    // advance all pointers to the current row
    const float * experts_row = experts + row * n_cols * n_expert_used;
    const float * weights_row = weights + row * n_expert_used;
    float *       dst_row     = dst + row * n_cols;

    float sum = 0.0f;
    if constexpr (n_expert_used_template == 0) {
        // expert count only known at runtime
        for (int expert = 0; expert < n_expert_used; ++expert) {
            ggml_cuda_mad(sum, experts_row[col], weights_row[expert]);
            experts_row += n_cols;
        }
    } else {
        // compile-time expert count: loop is fully unrolled
#pragma unroll
        for (int expert = 0; expert < n_expert_used_template; ++expert) {
            ggml_cuda_mad(sum, experts_row[col], weights_row[expert]);
            experts_row += n_cols;
        }
    }
    dst_row[col] = sum;
}
39+
40+
// Host-side launcher for moe_expert_reduce_cuda. Instantiates a
// compile-time-unrolled kernel variant for common expert counts and falls
// back to the runtime-count variant (<0>) for everything else.
static void launch_moe_expert_reduce(ggml_backend_cuda_context & ctx,
                                     const float * experts,
                                     const float * weights,
                                     float * dst,
                                     const int n_expert_used,
                                     const int n_cols,
                                     const int n_rows) {
    constexpr int block_size = 32;

    // one block row per matrix row; blocks along y tile the columns
    const dim3 block_dims(block_size);
    const dim3 grid_dims(n_rows, (n_cols + block_size - 1) / block_size);

    cudaStream_t stream = ctx.stream();

#define MOE_EXPERT_REDUCE_CASE(ne)                                                                    \
    case ne:                                                                                          \
        moe_expert_reduce_cuda<ne>                                                                    \
            <<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);     \
        break;

    switch (n_expert_used) {
        MOE_EXPERT_REDUCE_CASE(1)
        MOE_EXPERT_REDUCE_CASE(2)
        MOE_EXPERT_REDUCE_CASE(4)
        MOE_EXPERT_REDUCE_CASE(6)
        MOE_EXPERT_REDUCE_CASE(8)
        MOE_EXPERT_REDUCE_CASE(16)
        MOE_EXPERT_REDUCE_CASE(32)
        MOE_EXPERT_REDUCE_CASE(64)
        MOE_EXPERT_REDUCE_CASE(128)
        default:
            // template argument 0 == "expert count not fixed at compile time"
            moe_expert_reduce_cuda<0>
                <<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
            break;
    }

#undef MOE_EXPERT_REDUCE_CASE
}
99+
100+
// Detects the unfused MoE expert-reduce pattern in the graph:
//   MUL (experts * weights), followed by up to 31 VIEWs of the MUL result at
//   non-decreasing offsets, followed by a left-to-right chain of ADDs
//   (((v0 + v1) + v2) + ...) that folds the views together.
// Returns true iff the nodes in [start_index, end_index) match this shape.
bool ggml_cuda_should_use_moe_expert_reduce(const ggml_cgraph * cgraph, int start_index, int end_index) {
    const ggml_tensor * experts = cgraph->nodes[start_index];
    if (experts->op != GGML_OP_MUL) {
        return false;
    }

    int current_node = start_index + 1;
    size_t current_offset = 0;

    const ggml_tensor * view_nodes[32];
    int num_views = 0;
    // check that all views reference the MUL result, at non-decreasing offsets
    while (current_node < end_index && cgraph->nodes[current_node]->op == GGML_OP_VIEW) {
        const ggml_tensor * node = cgraph->nodes[current_node];
        if (node->view_src != experts) {
            return false;
        }
        if (node->view_offs < current_offset) {
            return false;
        }
        current_offset = node->view_offs;
        current_node++;
        view_nodes[num_views++] = node;

        if (num_views >= 32) {
            return false;
        }
    }

    // bail out early with no views: there is nothing to reduce, and
    // view_nodes[0] below would otherwise be read uninitialized
    if (num_views == 0) {
        return false;
    }

    // check that the adds fold the views together in order
    const ggml_tensor * prev_add_src = view_nodes[0];
    int num_adds = 0;
    while (current_node < end_index && cgraph->nodes[current_node]->op == GGML_OP_ADD) {
        const ggml_tensor * add_node = cgraph->nodes[current_node];

        // num_adds + 1 must stay within the stored views; the previous guard
        // (num_views > num_adds) allowed indexing one past the last view when
        // extra ADD nodes appear before end_index
        const bool in_range        = num_adds + 1 < num_views;
        const bool is_first_op_ok  = in_range && add_node->src[0] == prev_add_src;
        const bool is_second_op_ok = in_range && add_node->src[1] == view_nodes[num_adds + 1];

        if (!is_first_op_ok || !is_second_op_ok) {
            return false;
        }
        prev_add_src = add_node;

        num_adds++;
        current_node++;
    }

    // exactly one add fewer than views completes the reduction chain
    if (num_views != num_adds + 1) {
        return false;
    }

    return true;
}
153+
154+
// Entry point for the fused expert reduce. experts has shape
// (n_cols, n_expert_used, n_rows); the launched kernel writes
// dst[row][col] = sum_e weights[row][e] * experts[row][e][col].
// Both inputs must be contiguous F32.
void ggml_cuda_op_moe_expert_reduce(ggml_backend_cuda_context & ctx,
                                    const ggml_tensor * experts,
                                    const ggml_tensor * weights,
                                    ggml_tensor * dst) {
    const int n_cols        = experts->ne[0];
    const int n_expert_used = experts->ne[1];
    const int n_rows        = experts->ne[2];

    GGML_ASSERT(experts->type == GGML_TYPE_F32);
    GGML_ASSERT(weights->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type     == GGML_TYPE_F32);
    GGML_ASSERT(ggml_is_contiguous(experts));
    GGML_ASSERT(ggml_is_contiguous(weights));

    launch_moe_expert_reduce(ctx,
                             (const float *) experts->data,
                             (const float *) weights->data,
                             (float *) dst->data,
                             n_expert_used, n_cols, n_rows);
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#include "common.cuh"
2+
#include "ggml.h"
3+
4+
#include <initializer_list>
5+
6+
void ggml_cuda_op_moe_expert_reduce(ggml_backend_cuda_context & ctx,
7+
const ggml_tensor * experts,
8+
const ggml_tensor * weights,
9+
ggml_tensor * dst);
10+
11+
bool ggml_cuda_should_use_moe_expert_reduce(const ggml_cgraph * cgraph, int start_index, int end_index);

tests/test-backend-ops.cpp

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4742,6 +4742,60 @@ struct test_topk_moe: public test_case {
47424742
}
47434743
};
47444744

4745+
struct test_moe_expert_reduce : public test_case {
4746+
const int64_t n_embd;
4747+
const int64_t n_tokens;
4748+
const int64_t n_expert_used;
4749+
4750+
test_moe_expert_reduce(int64_t n_embd = 64, int64_t n_tokens = 5, int64_t n_expert_used = 4)
4751+
: n_embd(n_embd), n_tokens(n_tokens), n_expert_used(n_expert_used) {
4752+
GGML_ASSERT(n_expert_used > 1);
4753+
}
4754+
4755+
std::string vars() override {
4756+
return VARS_TO_STR3(n_embd, n_tokens, n_expert_used);
4757+
}
4758+
4759+
std::string op_desc(ggml_tensor * t) override {
4760+
GGML_UNUSED(t);
4761+
return "MOE_EXPERT_REDUCE";
4762+
}
4763+
4764+
bool run_whole_graph() override { return true; }
4765+
4766+
ggml_tensor * build_graph(ggml_context * ctx) override {
4767+
ggml_tensor * experts = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, n_expert_used, n_tokens);
4768+
ggml_set_name(experts, "experts");
4769+
4770+
ggml_tensor * weights = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 1, n_expert_used, n_tokens);
4771+
ggml_set_name(weights, "weights");
4772+
4773+
ggml_tensor * weighted = ggml_mul(ctx, experts, weights);
4774+
ggml_set_name(weighted, "weighted_experts");
4775+
4776+
std::vector<ggml_tensor *> expert_views(n_expert_used);
4777+
for (int64_t i = 0; i < n_expert_used; ++i) {
4778+
expert_views[i] = ggml_view_2d(ctx, weighted, n_embd, n_tokens, weighted->nb[2], i * weighted->nb[1]);
4779+
4780+
std::string name = "expert_view_" + std::to_string(i);
4781+
ggml_set_name(expert_views[i], name.c_str());
4782+
ggml_build_forward_expand(gf, expert_views[i]);
4783+
}
4784+
4785+
ggml_tensor * moe_out = expert_views[0];
4786+
for (int64_t i = 1; i < n_expert_used; ++i) {
4787+
moe_out = ggml_add(ctx, moe_out, expert_views[i]);
4788+
4789+
std::string name = "expert_add_" + std::to_string(i - 1);
4790+
ggml_set_name(moe_out, name.c_str());
4791+
}
4792+
4793+
ggml_set_name(moe_out, "moe_out");
4794+
4795+
return moe_out;
4796+
}
4797+
};
4798+
47454799
struct test_mul_mat_vec_fusion : public test_case {
47464800
const ggml_type type;
47474801
const ggml_glu_op glu_op;
@@ -7179,6 +7233,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
71797233
test_cases.emplace_back(new test_topk_moe({ 8, 22, 1, 1 }, 4, /*with_norm*/ false, /*delayed_softmax*/ true));
71807234
test_cases.emplace_back(new test_topk_moe({ 32, 22, 1, 1 }, 8, /*with_norm*/ false, /*delayed_softmax*/ true));
71817235

7236+
test_cases.emplace_back(new test_moe_expert_reduce(1024, 5, 4));
7237+
test_cases.emplace_back(new test_moe_expert_reduce(80, 3, 6));
7238+
test_cases.emplace_back(new test_moe_expert_reduce(80, 3, 7));
7239+
71827240
#if 0
71837241
// these tests are disabled to save execution time, but they can be handy for debugging
71847242
test_cases.emplace_back(new test_llama(2, true));

0 commit comments

Comments
 (0)