
Commit 5632159

CUDA: topk-moe: add optional parameter for gpt-oss
1 parent 4926419 commit 5632159

4 files changed (+158, -61 lines)

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 40 additions & 4 deletions
@@ -2818,8 +2818,12 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
 #endif
 
     //TODO: remove special case once ggml_can_fuse can handle empty nodes
-    std::initializer_list<enum ggml_op> topk_moe_ops = ggml_cuda_topk_moe_ops(false);
-    std::initializer_list<enum ggml_op> topk_moe_ops_with_norm = ggml_cuda_topk_moe_ops(true);
+    std::initializer_list<enum ggml_op> topk_moe_ops =
+        ggml_cuda_topk_moe_ops(/*with_norm*/ false, /*delayed_softmax=*/false);
+    std::initializer_list<enum ggml_op> topk_moe_ops_with_norm =
+        ggml_cuda_topk_moe_ops(/*with_norm=*/true, /*delayed_softmax=*/false);
+    std::initializer_list<enum ggml_op> topk_moe_ops_delayed_softmax =
+        ggml_cuda_topk_moe_ops(/*with_norm=*/false, /*delayed_softmax=*/true);
 
     if (ops.size() == topk_moe_ops_with_norm.size() &&
         ggml_can_fuse_subgraph(cgraph, node_idx, topk_moe_ops_with_norm, { node_idx + 3, node_idx + 8 })) {
@@ -2840,6 +2844,25 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
         }
     }
 
+    if (ops.size() == topk_moe_ops_delayed_softmax.size() &&
+        std::equal(ops.begin(), ops.end(), topk_moe_ops_delayed_softmax.begin())) {
+        if (node_idx + topk_moe_ops_delayed_softmax.size() > (size_t) cgraph->n_nodes) {
+            return false;
+        }
+        for (size_t i = 0; i < topk_moe_ops_delayed_softmax.size(); i++) {
+            if (cgraph->nodes[node_idx + i]->op != topk_moe_ops_delayed_softmax.begin()[i]) {
+                return false;
+            }
+        }
+
+        ggml_tensor * softmax = cgraph->nodes[node_idx + 4];
+        ggml_tensor * weights = cgraph->nodes[node_idx + 5];
+
+        if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
+            return true;
+        }
+    }
+
     if (!ggml_can_fuse(cgraph, node_idx, ops)) {
         return false;
     }
@@ -2933,19 +2956,32 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
         if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ true), {})) {
             ggml_tensor * weights = cgraph->nodes[i+8];
             ggml_tensor * selected_experts = cgraph->nodes[i+3];
-            ggml_cuda_op_topk_moe(*cuda_ctx, node, weights, selected_experts, /*with norm*/ true);
+            ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ true,
+                                  /*delayed softmax*/ false);
             i += 8;
             continue;
         }
 
         if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ false), {})) {
             ggml_tensor * weights = cgraph->nodes[i+4];
             ggml_tensor * selected_experts = cgraph->nodes[i+3];
-            ggml_cuda_op_topk_moe(*cuda_ctx, node, weights, selected_experts, /*with norm*/ false);
+            ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ false,
+                                  /*delayed softmax*/ false);
            i += 4;
            continue;
        }
 
+        if (ggml_cuda_can_fuse(cgraph, i,
+                               ggml_cuda_topk_moe_ops(/*with norm*/ false, /*delayed softmax*/ true), {})) {
+            ggml_tensor * weights = cgraph->nodes[i + 5];
+            ggml_tensor * ids = cgraph->nodes[i + 1];
+
+            ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, ids, /*with norm*/ false,
+                                  /*delayed_softmax*/ true);
+            i += 5;
+            continue;
+        }
+
         if (node->op == GGML_OP_ADD) {
             int n_fuse = 0;
             ggml_op ops[8];
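For orientation before the kernel changes below, here is a sketch (not part of the commit) of the six-op subgraph that the new delayed-softmax branch matches. The op order is the delayed_softmax_ops list from topk-moe.cu, and the offsets are the ones used above: the ids are read from node i + 1, the SOFT_MAX is checked at node_idx + 4, and the weights are taken from node i + 5. The per-node notes are an interpretation of how the gpt-oss router graph is expected to look, not text from the commit.

#include "ggml.h"

// Assumed layout of the fused delayed-softmax subgraph, relative to node_idx (or i).
static constexpr ggml_op topk_moe_delayed_softmax_layout[6] = {
    GGML_OP_ARGSORT,   // +0: sort the raw router logits per token
    GGML_OP_VIEW,      // +1: keep the first n_expert_used entries -> selected expert ids
    GGML_OP_GET_ROWS,  // +2: gather the logits of the selected experts
    GGML_OP_RESHAPE,   // +3
    GGML_OP_SOFT_MAX,  // +4: softmax over only the selected logits (the "delayed" softmax)
    GGML_OP_RESHAPE,   // +5: final expert weights
};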

ggml/src/ggml-cuda/topk-moe.cu

Lines changed: 92 additions & 48 deletions
@@ -4,16 +4,58 @@
 
 #include <initializer_list>
 
+// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
+template <int experts_per_thread>
+__device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) {
+    float max_val = -INFINITY;
+
+#pragma unroll
+    for (int i = 0; i < experts_per_thread; i++) {
+        const int idx = lane + i * WARP_SIZE;
+        if (idx < limit) {
+            max_val = max(max_val, vals[i]);
+        }
+    }
+
+    max_val = warp_reduce_max(max_val);
+
+    float sum = 0.f;
+
+#pragma unroll
+    for (int i = 0; i < experts_per_thread; i++) {
+        const int idx = lane + i * WARP_SIZE;
+        if (idx < limit) {
+            const float val = expf(vals[i] - max_val);
+            vals[i] = val;
+            sum += val;
+        } else {
+            vals[i] = 0.f;
+        }
+    }
+
+    sum = warp_reduce_sum(sum);
+
+    const float inv_sum = 1.0f / sum;
+
+#pragma unroll
+    for (int i = 0; i < experts_per_thread; i++) {
+        const int idx = lane + i * WARP_SIZE;
+        if (idx < limit) {
+            vals[i] *= inv_sum;
+        }
+    }
+}
+
 /*
 This kernel does the following:
-1. softmax over the logits per token [n_experts, n_tokens]
+1. optionally softmax over the logits per token [n_experts, n_tokens]
 2. argmax reduce over the top-k (n_experts_used) logits
 3. write weights + ids to global memory
-4. optionally normalize the weights
+4. optionally normalize the weights or apply softmax over the selected logits
 
 It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models
 */
-template <int n_experts, bool with_norm>
+template <int n_experts, bool with_norm, bool delayed_softmax = false>
 __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
                                                                   float * weights,
                                                                   int32_t * ids,
@@ -30,51 +72,31 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
 
     constexpr int experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;
 
-    float logits_r[experts_per_thread];
+    float wt[experts_per_thread];
 
 #pragma unroll
     for (int i = 0; i < n_experts; i += WARP_SIZE) {
         const int expert = i + threadIdx.x;
-        logits_r[i / WARP_SIZE] = n_experts % WARP_SIZE == 0 || expert < n_experts ? logits[expert] : -INFINITY;
+        wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[expert] : -INFINITY;
     }
 
-    float max_val = logits_r[0];
-
-#pragma unroll
-    for (int i = 1; i < experts_per_thread; i++) {
-        const float val = logits_r[i];
-        max_val = max(val, max_val);
+    if constexpr (!delayed_softmax) {
+        softmax_warp_inplace<experts_per_thread>(wt, n_experts, threadIdx.x);
     }
 
-    max_val = warp_reduce_max(max_val);
-
-    float wt[experts_per_thread];
-    float tmp = 0.f;
-
-#pragma unroll
-    for (int i = 0; i < experts_per_thread; i++) {
-        const float val = logits_r[i];
-        wt[i] = expf(val - max_val);
-        tmp += wt[i];
-    }
+    //at this point, each thread holds either a portion of the softmax distribution
+    //or the raw logits. We do the argmax reduce over n_expert_used, each time marking
+    //the expert weight as -inf to exclude from the next iteration
 
-    tmp = warp_reduce_sum(tmp);
+    float wt_sum = 0.f;
 
-    const float inv_sum = 1.0f / tmp;
+    float output_weights[experts_per_thread];
 
 #pragma unroll
     for (int i = 0; i < experts_per_thread; i++) {
-        wt[i] = wt[i] * inv_sum;
+        output_weights[i] = 0.f;
     }
 
-    //at this point, each thread holds a portion of softmax,
-    //we do the argmax reduce over n_expert_used, each time marking
-    //the expert weight as -inf to exclude from the next iteration
-
-    float wt_sum = 0.f;
-
-    float output_weights[experts_per_thread];
-
     for (int k = 0; k < n_expert_used; k++) {
         float max_val = wt[0];
         int max_expert = threadIdx.x;
@@ -121,6 +143,10 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
         }
     }
 
+    if constexpr (delayed_softmax) {
+        softmax_warp_inplace<experts_per_thread>(output_weights, n_expert_used, threadIdx.x);
+    }
+
 #pragma unroll
     for (int i = 0; i < experts_per_thread; i++) {
         const int idx = i * WARP_SIZE + threadIdx.x;
@@ -130,58 +156,60 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
     }
 }
 
-template <bool with_norm>
+template <bool with_norm, bool delayed_softmax = false>
 static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
                                  const float * logits,
                                  float * weights,
                                  int32_t * ids,
                                  const int n_rows,
                                  const int n_expert,
                                  const int n_expert_used) {
+    static_assert(!(with_norm && delayed_softmax), "delayed softmax is not supported with weight normalization");
+
     const int rows_per_block = 4;
     dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
     dim3 block_dims(WARP_SIZE, rows_per_block, 1);
     cudaStream_t stream = ctx.stream();
 
     switch (n_expert) {
        case 1:
-           topk_moe_cuda<1, with_norm>
+           topk_moe_cuda<1, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 2:
-           topk_moe_cuda<2, with_norm>
+           topk_moe_cuda<2, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 4:
-           topk_moe_cuda<4, with_norm>
+           topk_moe_cuda<4, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 8:
-           topk_moe_cuda<8, with_norm>
+           topk_moe_cuda<8, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 16:
-           topk_moe_cuda<16, with_norm>
+           topk_moe_cuda<16, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 32:
-           topk_moe_cuda<32, with_norm>
+           topk_moe_cuda<32, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 64:
-           topk_moe_cuda<64, with_norm>
+           topk_moe_cuda<64, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 128:
-           topk_moe_cuda<128, with_norm>
+           topk_moe_cuda<128, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 256:
-           topk_moe_cuda<256, with_norm>
+           topk_moe_cuda<256, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 512:
-           topk_moe_cuda<512, with_norm>
+           topk_moe_cuda<512, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        default:
@@ -194,15 +222,16 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
                            const ggml_tensor * logits,
                            ggml_tensor * weights,
                            ggml_tensor * ids,
-                           const bool with_norm) {
+                           const bool with_norm,
+                           const bool delayed_softmax) {
     GGML_ASSERT(logits->type == GGML_TYPE_F32);
     GGML_ASSERT(weights->type == GGML_TYPE_F32);
     GGML_ASSERT(ids->type == GGML_TYPE_I32);
 
     const int n_experts = logits->ne[0];
     const int n_rows = logits->ne[1];
 
-    const float * logits_d = (const float *) logits->src[0]->data;
+    const float * logits_d = (const float *) logits->data;
     float * weights_d = (float *) weights->data;
     int32_t * ids_d = (int32_t *) ids->data;
 
@@ -213,7 +242,11 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
     if (with_norm) {
         launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
     } else {
-        launch_topk_moe_cuda<false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+        if (delayed_softmax) {
+            launch_topk_moe_cuda<false, true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+        } else {
+            launch_topk_moe_cuda<false, false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+        }
     }
 }
 
@@ -246,16 +279,27 @@ bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tenso
     return true;
 }
 
-std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm) {
+std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm, bool delayed_softmax) {
     static std::initializer_list<enum ggml_op> norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
                                                             GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
                                                             GGML_OP_SUM_ROWS, GGML_OP_DIV, GGML_OP_RESHAPE };
 
     static std::initializer_list<enum ggml_op> no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
                                                                GGML_OP_VIEW, GGML_OP_GET_ROWS };
 
+    static std::initializer_list<enum ggml_op> delayed_softmax_ops = { GGML_OP_ARGSORT, GGML_OP_VIEW,
+                                                                       GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
+                                                                       GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
+
+    GGML_ASSERT(!norm || !delayed_softmax);
+
+    if (delayed_softmax) {
+        return delayed_softmax_ops;
+    }
+
     if (norm) {
         return norm_ops;
     }
+
     return no_norm_ops;
 }
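To make the semantics concrete, here is a minimal CPU reference of what the delayed-softmax path computes for one token (a sketch for illustration only; topk_then_softmax is a made-up name and not part of ggml): the existing kernel paths run softmax over all n_experts logits before the top-k selection, while the gpt-oss path selects the top n_expert_used raw logits first and applies softmax over just those selected values.

#include <algorithm>
#include <cmath>
#include <numeric>
#include <vector>

// Sketch of the delayed-softmax routing for a single token: pick the top-k raw
// logits first, then softmax only over the k selected values. Reference
// semantics only, not the CUDA implementation from this commit.
static void topk_then_softmax(const std::vector<float> & logits, int k,
                              std::vector<int> & ids, std::vector<float> & weights) {
    ids.resize(logits.size());
    std::iota(ids.begin(), ids.end(), 0);
    // top-k expert ids by raw logit (plays the role of ARGSORT + VIEW in the fused graph)
    std::partial_sort(ids.begin(), ids.begin() + k, ids.end(),
                      [&](int a, int b) { return logits[a] > logits[b]; });
    ids.resize(k);

    // softmax over only the k selected logits (the "delayed" softmax)
    float max_val = -INFINITY;
    for (int id : ids) {
        max_val = std::max(max_val, logits[id]);
    }
    weights.resize(k);
    float sum = 0.0f;
    for (int i = 0; i < k; i++) {
        weights[i] = std::exp(logits[ids[i]] - max_val);
        sum += weights[i];
    }
    for (float & w : weights) {
        w /= sum;
    }
}

This is the same ordering that the fused ARGSORT -> VIEW -> GET_ROWS -> SOFT_MAX subgraph expresses, just written sequentially.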

ggml/src/ggml-cuda/topk-moe.cuh

Lines changed: 4 additions & 3 deletions
@@ -6,9 +6,10 @@
 void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
                            const ggml_tensor * logits,
                            ggml_tensor * weights,
-                           ggml_tensor * top_k,
-                           const bool with_norm);
+                           ggml_tensor * ids,
+                           const bool with_norm,
+                           const bool delayed_softmax = false);
 
 bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights);
 
-std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm);
+std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm, bool delayed_softmax = false);
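With the header change above, callers opt into the new behaviour through the extra flag; with_norm and delayed_softmax are mutually exclusive (the GGML_ASSERT in ggml_cuda_topk_moe_ops and the static_assert in launch_topk_moe_cuda both reject the combination). A hypothetical call-site sketch, with the CUDA context and tensors assumed to be prepared elsewhere:

#include "topk-moe.cuh"

// Hypothetical wrapper for illustration only: route_experts and is_gpt_oss are not part of the commit.
static void route_experts(ggml_backend_cuda_context & ctx, const ggml_tensor * logits,
                          ggml_tensor * weights, ggml_tensor * ids,
                          bool with_norm, bool is_gpt_oss) {
    if (is_gpt_oss) {
        // top-k over the raw logits, then softmax over only the selected logits
        ggml_cuda_op_topk_moe(ctx, logits, weights, ids, /*with_norm=*/false, /*delayed_softmax=*/true);
    } else {
        // existing behaviour: softmax before top-k; delayed_softmax defaults to false
        ggml_cuda_op_topk_moe(ctx, logits, weights, ids, with_norm);
    }
}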
