
Commit 18f0435

ikawrakow and Iwan Kawrakow authored
cuda: fused top_k+softmax as used in most MoE models (#789)
Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 64e357c commit 18f0435

File tree: 4 files changed (+244, -19 lines)


ggml/src/ggml-cuda.cu

Lines changed: 17 additions & 7 deletions
@@ -41,6 +41,7 @@
 #include "ggml-cuda/graph.cuh"
 #include "ggml-cuda/mmq_id.cuh"
 #include "ggml-cuda/quantize_id.cuh"
+#include "ggml-cuda/topk-moe.cuh"
 
 #include <algorithm>
 #include <array>
@@ -3030,7 +3031,8 @@ static void ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
 
 }
 
-static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst, struct ggml_tensor * next, bool& skip_next) {
+static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst, struct ggml_tensor * next,
+        const ggml_cgraph * cgraph, int & i) {
     // why is this here instead of mul_mat?
     if (dst->src[0] != nullptr && ggml_backend_buffer_is_cuda_split(dst->src[0]->buffer)) {
         ggml_cuda_set_peer_access(dst->src[1]->ne[1], ctx.device);
@@ -3152,10 +3154,10 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
            }
            break;
        case GGML_OP_MUL_MAT_ID:
-           skip_next = ggml_cuda_mul_mat_id(ctx, dst, next);
+           if (ggml_cuda_mul_mat_id(ctx, dst, next)) ++i;
            break;
        case GGML_OP_MOE_FUSED_UP_GATE:
-           skip_next = ggml_cuda_moe_up_gate_unary(ctx, dst, next);
+           if (ggml_cuda_moe_up_gate_unary(ctx, dst, next)) ++i;
            break;
        case GGML_OP_FUSED_UP_GATE:
            ggml_cuda_up_gate_unary(ctx, dst);
@@ -3185,7 +3187,17 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
            ggml_cuda_op_diag_mask_inf(ctx, dst);
            break;
        case GGML_OP_SOFT_MAX:
-           ggml_cuda_op_soft_max(ctx, dst);
+           if (i + 4 < cgraph->n_nodes &&
+               cgraph->nodes[i+1]->op == GGML_OP_RESHAPE &&
+               cgraph->nodes[i+2]->op == GGML_OP_ARGSORT &&
+               cgraph->nodes[i+3]->op == GGML_OP_VIEW &&
+               cgraph->nodes[i+4]->op == GGML_OP_GET_ROWS &&
+               ggml_cuda_should_use_topk_moe(cgraph->nodes[i], cgraph->nodes[i+4])) {
+               ggml_cuda_op_topk_moe(ctx, cgraph->nodes[i], cgraph->nodes[i+4], cgraph->nodes[i+3]);
+               i += 4;
+           } else {
+               ggml_cuda_op_soft_max(ctx, dst);
+           }
            break;
        case GGML_OP_SOFT_CAP_MAX:
            ggml_cuda_op_soft_cap_max(ctx, dst);
@@ -3592,13 +3604,11 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
            GGML_UNUSED(integrated);
 #endif // NDEBUG
 
-           bool skip_next = false;
-           bool ok = ggml_cuda_compute_forward(*cuda_ctx, node, next, skip_next);
+           bool ok = ggml_cuda_compute_forward(*cuda_ctx, node, next, cgraph, i);
            if (!ok) {
                GGML_CUDA_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
            }
            GGML_ASSERT(ok);
-           if (skip_next) ++i;
        }
    }
 #ifdef USE_CUDA_GRAPH
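
For orientation, the five-node sequence matched in the GGML_OP_SOFT_MAX case above is the routing pattern a softmax-gated MoE layer normally emits. A minimal sketch of the unfused graph construction it stands in for, using standard ggml calls and illustrative tensor names (not copied verbatim from this repository):

// Sketch only: the unfused expert-routing chain the adjacency check above looks for.
// 'logits', 'probs', 'selected_experts' and 'weights' are illustrative names.
ggml_tensor * probs            = ggml_soft_max(ctx, logits);                   // GGML_OP_SOFT_MAX, [n_expert, n_tokens]
ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_expert_used);        // expands to GGML_OP_ARGSORT + GGML_OP_VIEW
ggml_tensor * weights          = ggml_get_rows(ctx,
        ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // GGML_OP_RESHAPE + GGML_OP_GET_ROWS

When ggml_cuda_should_use_topk_moe accepts the pattern, ggml_cuda_op_topk_moe produces the GET_ROWS output (the weights) and the VIEW output (the expert ids) in a single kernel, and the four follower nodes are skipped by advancing i += 4.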

ggml/src/ggml-cuda/topk-moe.cu

Lines changed: 203 additions & 0 deletions
@@ -0,0 +1,203 @@
+#include "ggml-cuda/common.cuh"
+#include "ggml.h"
+#include "topk-moe.cuh"
+
+/*
+    This kernel does the following:
+    1. softmax over the logits per token [n_experts, n_tokens]
+    2. argmax reduce over the top-k (n_experts_used) logits
+    3. write weights + ids to global memory
+
+    It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models
+*/
+template <size_t n_experts>
+__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
+                                                                  float *       weights,
+                                                                  int32_t *     ids,
+                                                                  const int     n_rows,
+                                                                  const int     n_expert_used) {
+    const int row = blockIdx.x * blockDim.y + threadIdx.y;
+    if (row >= n_rows) {
+        return;
+    }
+
+    logits  += n_experts * row;
+    weights += n_expert_used * row;
+    ids     += n_experts * row;
+
+    constexpr int experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;
+
+    float logits_r[experts_per_thread];
+
+#pragma unroll
+    for (int i = 0; i < n_experts; i += WARP_SIZE) {
+        const int expert        = i + threadIdx.x;
+        logits_r[i / WARP_SIZE] = expert < n_experts ? logits[expert] : -INFINITY;
+    }
+
+    float max_val = logits_r[0];
+
+#pragma unroll
+    for (int i = 1; i < experts_per_thread; i++) {
+        const float val = logits_r[i];
+        max_val         = max(val, max_val);
+    }
+
+    max_val = warp_reduce_max(max_val);
+
+    float wt[experts_per_thread];
+    float tmp = 0.f;
+
+#pragma unroll
+    for (int i = 0; i < experts_per_thread; i++) {
+        const float val = logits_r[i];
+        wt[i]           = expf(val - max_val);
+        tmp += wt[i];
+    }
+
+    tmp = warp_reduce_sum(tmp);
+
+    const float inv_sum = 1.0f / tmp;
+
+#pragma unroll
+    for (int i = 0; i < experts_per_thread; i++) {
+        wt[i] = wt[i] * inv_sum;
+    }
+
+    //at this point, each thread holds a portion of softmax,
+    //we do the argmax reduce over n_expert_used, each time marking
+    //the expert weight as -inf to exclude from the next iteration
+
+    for (int k = 0; k < n_expert_used; k++) {
+        float max_val    = wt[0];
+        int   max_expert = threadIdx.x;
+
+#pragma unroll
+        for (int i = 1; i < experts_per_thread; i++) {
+            const int expert = threadIdx.x + i * WARP_SIZE;
+            if (expert < n_experts && wt[i] > max_val) {
+                max_val    = wt[i];
+                max_expert = expert;
+            }
+        }
+
+#pragma unroll
+        for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
+            const float val    = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE);
+            const int   expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);
+            if (val > max_val) {
+                max_val    = val;
+                max_expert = expert;
+            }
+        }
+
+        if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
+            wt[max_expert / WARP_SIZE] = -INFINITY;
+
+            weights[k] = max_val;
+            ids[k]     = max_expert;
+        }
+    }
+}
+
+static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
+                                 const float *               logits,
+                                 float *                     weights,
+                                 int32_t *                   ids,
+                                 const int                   n_rows,
+                                 const int                   n_expert,
+                                 const int                   n_expert_used) {
+    const int    rows_per_block = 4;
+    dim3         grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
+    dim3         block_dims(WARP_SIZE, rows_per_block, 1);
+    cudaStream_t stream = ctx.stream();
+
+    switch (n_expert) {
+        case 1:
+            topk_moe_cuda<1><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        case 2:
+            topk_moe_cuda<2><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        case 4:
+            topk_moe_cuda<4><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        case 8:
+            topk_moe_cuda<8><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        case 16:
+            topk_moe_cuda<16><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        case 32:
+            topk_moe_cuda<32><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        case 64:
+            topk_moe_cuda<64><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        case 128:
+            topk_moe_cuda<128><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        case 256:
+            topk_moe_cuda<256><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        case 512:
+            topk_moe_cuda<512><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        default:
+            GGML_ASSERT(false && "fatal error");
+            break;
+    }
+}
+
+void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
+                           const ggml_tensor *         logits,
+                           ggml_tensor *               weights,
+                           ggml_tensor *               ids) {
+    GGML_ASSERT(logits->type == GGML_TYPE_F32);
+    GGML_ASSERT(weights->type == GGML_TYPE_F32);
+    GGML_ASSERT(ids->type == GGML_TYPE_I32);
+
+    const int n_experts = logits->ne[0];
+    const int n_rows    = logits->ne[1];
+
+    const float * logits_d  = (const float *) logits->src[0]->data;
+    float *       weights_d = (float *) weights->data;
+    int32_t *     ids_d     = (int32_t *) ids->data;
+
+    GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts);
+
+    cudaStream_t stream = ctx.stream();
+
+    const int n_expert_used = weights->ne[1];
+
+    launch_topk_moe_cuda(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+}
+
+bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights) {
+    float scale    = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale,    (const float *) softmax->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float));
+
+    if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) {
+        return false;
+    }
+
+    if (scale != 1.0f || max_bias != 0.0f) {
+        return false;
+    }
+
+    // don't fuse when masks or sinks are present
+    if (softmax->src[1] || softmax->src[2]) {
+        return false;
+    }
+
+    const int n_expert = softmax->ne[0];
+    // n_expert must be a power of 2
+    if ((n_expert & (n_expert - 1)) != 0 || n_expert > 512) {
+        return false;
+    }
+
+    return true;
+}
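
To make the kernel's contract easy to cross-check, the following is a plain CPU reference of the same per-row computation: a numerically stable softmax over n_experts logits, then n_expert_used rounds of argmax with exclusion, writing weights and ids in descending order of probability. This is an illustrative re-implementation for comparison (tie-breaking may differ), not code from the commit:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Reference (unfused, single-threaded) version of what topk_moe_cuda computes per row.
static void topk_moe_reference(const float * logits, float * weights, int32_t * ids,
                               int n_experts, int n_expert_used) {
    // 1. softmax over the logits of one token
    std::vector<float> p(n_experts);
    float max_val = logits[0];
    for (int e = 1; e < n_experts; ++e) {
        max_val = std::max(max_val, logits[e]);
    }
    float sum = 0.0f;
    for (int e = 0; e < n_experts; ++e) {
        p[e] = std::exp(logits[e] - max_val);
        sum += p[e];
    }
    for (int e = 0; e < n_experts; ++e) {
        p[e] /= sum;
    }
    // 2.+3. pick the n_expert_used largest probabilities and their expert indices
    for (int k = 0; k < n_expert_used; ++k) {
        int best = 0;
        for (int e = 1; e < n_experts; ++e) {
            if (p[e] > p[best]) {
                best = e;
            }
        }
        weights[k] = p[best];
        ids[k]     = best;
        p[best]    = -INFINITY;  // exclude from the next round, mirroring the kernel
    }
}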

ggml/src/ggml-cuda/topk-moe.cuh

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+#include "common.cuh"
+
+void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
+                           const ggml_tensor *         logits,
+                           ggml_tensor *               weights,
+                           ggml_tensor *               top_k);
+
+bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights);

src/llama.cpp

Lines changed: 16 additions & 12 deletions
@@ -7950,6 +7950,10 @@ llm_expert_gating_func_type gating_op,
             ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
         cb(weights, "ffn_moe_weights", il);
 
+    if (graph) {
+        ggml_build_forward_expand(graph, weights);
+    }
+
     if (gating_op == LLM_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT) {
         weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens);
         weights = ggml_soft_max(ctx, weights); // [n_expert_used, n_tokens]
@@ -8960,7 +8964,7 @@ struct llm_build_context {
                 LLM_FFN_SILU, false,
                 false, 0.0,
                 LLM_EXPERT_GATING_FUNC_SIGMOID,
-                cb, il);
+                cb, il, gf);
 
         // Shared experts
         ggml_tensor * shexp_out = llm_build_ffn(ctx0, lctx, ffn_inp_normed,
@@ -8991,7 +8995,7 @@ struct llm_build_context {
                 LLM_FFN_SILU, true,
                 false, 0.0,
                 LLM_EXPERT_GATING_FUNC_SOFTMAX,
-                cb, il);
+                cb, il, gf);
             cb(cur, "ffn_moe_out", il);
         }
 
@@ -9648,7 +9652,7 @@ struct llm_build_context {
                 LLM_FFN_GELU, true,
                 false, 0.0,
                 LLM_EXPERT_GATING_FUNC_SOFTMAX,
-                cb, il);
+                cb, il, gf);
             cb(cur, "ffn_moe_out", il);
 
             // Grok
@@ -9791,7 +9795,7 @@ struct llm_build_context {
                 LLM_FFN_SILU, true,
                 false, 0.0,
                 LLM_EXPERT_GATING_FUNC_SOFTMAX,
-                cb, il);
+                cb, il, gf);
             cb(cur, "ffn_moe_out", il);
 
             cur = ggml_add(ctx0, cur, ffn_inp);
@@ -10923,7 +10927,7 @@ struct llm_build_context {
                 LLM_FFN_SILU, false,
                 false, 0.0,
                 LLM_EXPERT_GATING_FUNC_SOFTMAX,
-                cb, il);
+                cb, il, gf);
             cb(cur, "ffn_moe_out", il);
 
             // FFN shared expert
@@ -11188,7 +11192,7 @@ struct llm_build_context {
                 LLM_FFN_SILU, true,
                 false, 0.0,
                 LLM_EXPERT_GATING_FUNC_SOFTMAX,
-                cb, il);
+                cb, il, gf);
             cb(cur, "ffn_moe_out", il);
 
             cur = ggml_add(ctx0, cur, ffn_inp);
@@ -13451,7 +13455,7 @@ struct llm_build_context {
                 LLM_FFN_SILU, true,
                 false, 0.0,
                 LLM_EXPERT_GATING_FUNC_SOFTMAX,
-                cb, il);
+                cb, il, gf);
             cb(cur, "ffn_moe_out", il);
 
             cur = ggml_add(ctx0, cur, ffn_out);
@@ -13940,7 +13944,7 @@ struct llm_build_context {
                 LLM_FFN_SILU, hparams.expert_weights_norm,
                 true, hparams.expert_weights_scale,
                 (enum llm_expert_gating_func_type) hparams.expert_gating_func,
-                cb, il);
+                cb, il, gf);
             cb(moe_out, "ffn_moe_out", il);
 
             // FFN shared expert
@@ -14116,7 +14120,7 @@ struct llm_build_context {
                 LLM_FFN_SILU, hparams.expert_weights_norm,
                 true, hparams.expert_weights_scale,
                 (enum llm_expert_gating_func_type) hparams.expert_gating_func,
-                cb, il);
+                cb, il, gf);
             cb(routed_out, "routed_out", il);
 
             {
@@ -15377,7 +15381,7 @@ struct llm_build_context {
                 LLM_FFN_SILU, hparams.expert_weights_norm,
                 true, hparams.expert_weights_scale,
                 (enum llm_expert_gating_func_type) hparams.expert_gating_func,
-                cb, il);
+                cb, il, gf);
             cb(moe_out, "ffn_moe_out", il);
 
             {
@@ -15670,7 +15674,7 @@ struct llm_build_context {
                 LLM_FFN_SILU, true,
                 false, 0.0,
                 LLM_EXPERT_GATING_FUNC_SOFTMAX,
-                cb, il);
+                cb, il, gf);
             cb(moe_out, "ffn_moe_out", il);
 
             // Shared expert (if present)
@@ -15835,7 +15839,7 @@ struct llm_build_context {
                 0.0,
                 LLM_EXPERT_GATING_FUNC_SOFTMAX,
                 cb,
-                il);
+                il, gf);
             cb(cur_moe, "ffn_moe_out", il);
 
             ggml_tensor * ffn_out = ggml_add(ctx0, cur_moe, cur_mlp);
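
The graph-side change is small but load-bearing: expanding weights as soon as it is built pins the soft_max, reshape, argsort, view and get_rows nodes as consecutive entries in gf, which is exactly the adjacency the CUDA backend's GGML_OP_SOFT_MAX case checks before fusing. A sketch of the intent, with illustrative names only (the surrounding routing code is abbreviated):

// Sketch: why the early expand helps. Without it, the routing nodes are only
// pulled into the graph when some later output is expanded, and ops appended in
// between could break the i+1..i+4 adjacency that ggml-cuda.cu tests for.
ggml_tensor * weights = ggml_get_rows(ctx,
        ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts);
if (graph) {
    ggml_build_forward_expand(graph, weights);  // nodes: soft_max, reshape, argsort, view, get_rows
}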
