From a208f5c983713fdea74174eb6923e053f49460b0 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Thu, 11 Sep 2025 22:57:13 +0800 Subject: [PATCH 1/9] CUDA: add a fused top-K MoE kernel This kernel does the following: 1. softmax over the logits per token [n_experts, n_tokens] 2. argmax reduce over the top-k (n_experts_used) logits 3. write weights + ids to global memory It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models --- ggml/src/ggml-cuda/ggml-cuda.cu | 47 +++++++++ ggml/src/ggml-cuda/topk-moe.cu | 165 ++++++++++++++++++++++++++++++++ ggml/src/ggml-cuda/topk-moe.cuh | 3 + src/llama-graph.cpp | 3 + tests/test-backend-ops.cpp | 41 ++++++++ 5 files changed, 259 insertions(+) create mode 100644 ggml/src/ggml-cuda/topk-moe.cu create mode 100644 ggml/src/ggml-cuda/topk-moe.cuh diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 4d85c5dc083d1..94245125245b6 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -45,6 +45,7 @@ #include "ggml-cuda/sumrows.cuh" #include "ggml-cuda/mean.cuh" #include "ggml-cuda/tsembd.cuh" +#include "ggml-cuda/topk-moe.cuh" #include "ggml-cuda/unary.cuh" #include "ggml-cuda/upscale.cuh" #include "ggml-cuda/wkv.cuh" @@ -2825,6 +2826,40 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, GGML_ASSERT(unary_ops.size() == num_unary); #endif + //special case for topk-moe + if (ops.size() == 5 && ops.begin()[0] == GGML_OP_SOFT_MAX && ops.begin()[1] == GGML_OP_RESHAPE && ops.begin()[2] == GGML_OP_ARGSORT + && ops.begin()[3] == GGML_OP_VIEW && ops.begin()[4] == GGML_OP_GET_ROWS) { + + for (int i = 0; i < 5; i++) { + if (cgraph->nodes[node_idx + i]->op != ops.begin()[i]) return false; + } + + ggml_tensor * softmax = cgraph->nodes[node_idx]; + + float scale = 1.0f; + float max_bias = 0.0f; + + memcpy(&scale, (const float *) softmax->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float)); + + if (scale != 1.0f || max_bias != 0.0f) { + return false; + } + + // don't fuse when masks or sinks are present + if (softmax->src[1] || softmax->src[2]) { + return false; + } + + const int n_expert = softmax->ne[0]; + // n_expert must be a power of 2 + if (n_expert & (n_expert - 1) != 0 || n_expert > 512) { + return false; + } + + return true; + } + if (!ggml_can_fuse(cgraph, node_idx, ops)) { return false; } @@ -2892,6 +2927,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, return true; } + + return false; } @@ -2915,6 +2952,15 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr); if (!disable_fusion) { + if (ggml_cuda_can_fuse(cgraph, i, {GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS}, {})) { + + ggml_tensor * weights = cgraph->nodes[i+4]; + ggml_tensor * selected_experts = cgraph->nodes[i+3]; + ggml_cuda_op_topk_moe(*cuda_ctx, node, weights, selected_experts); + i += 4; + continue; + } + if (node->op == GGML_OP_ADD) { int n_fuse = 0; ggml_op ops[8]; @@ -2964,6 +3010,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node); continue; } + } #ifndef NDEBUG assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device)); diff --git a/ggml/src/ggml-cuda/topk-moe.cu b/ggml/src/ggml-cuda/topk-moe.cu new file mode 100644 index 
0000000000000..2950f82b2c234 --- /dev/null +++ b/ggml/src/ggml-cuda/topk-moe.cu @@ -0,0 +1,165 @@ +#include "topk-moe.cuh" + +/* + This kernel does the following: + 1. softmax over the logits per token [n_experts, n_tokens] + 2. argmax reduce over the top-k (n_experts_used) logits + 3. write weights + ids to global memory + + It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models +*/ +template +__global__ void topk_moe_cuda(const float * logits, + float * weights, + int32_t * ids, + const int n_rows, + const int n_expert_used) { + const int row = blockIdx.x * blockDim.y + threadIdx.y; + if (row >= n_rows) { + return; + } + logits += n_experts * row; + ids += n_experts * row; + weights += n_expert_used * row; + + constexpr int experts_per_thread = (n_experts > 32) ? n_experts / 32 : 1; + + const int start_expert = threadIdx.x * experts_per_thread; + const int end_expert = (threadIdx.x + 1) * experts_per_thread; + float max_val = -INFINITY; + +#pragma unroll + for (int i = 0; i < experts_per_thread; i++) { + const int expert = start_expert + i; + const float val = (expert < n_experts) ? logits[expert] : -INFINITY; + max_val = max(val, max_val); + } + + max_val = warp_reduce_max(max_val); + + float wt[experts_per_thread]; + float tmp = 0.f; + +#pragma unroll + for (int i = 0; i < experts_per_thread; i++) { + const int expert = start_expert + i; + const float val = (expert < n_experts) ? logits[expert] : -INFINITY; + wt[i] = expf(val - max_val); + tmp += wt[i]; + } + + tmp = warp_reduce_sum(tmp); + + const float inv_sum = 1.0f / tmp; + +#pragma unroll + for (int i = 0; i < experts_per_thread; i++) { + wt[i] = wt[i] * inv_sum; + } + + //at this point, each thread holds a portion of softmax, + //we do the argmax reduce over n_expert_used, each time marking + //the expert weight as -inf to exclude from the next iteration + + for (int k = 0; k < n_expert_used; k++) { + float max_val = wt[0]; + int max_expert = start_expert; + +#pragma unroll + for (int i = 1; i < experts_per_thread; i++) { + const int expert = start_expert + i; + if (wt[i] > max_val) { + max_val = wt[i]; + max_expert = expert; + } + } + +#pragma unroll + for (int mask = warpSize / 2; mask > 0; mask /= 2) { + const float val = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, warpSize); + const int expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, warpSize); + if (val > max_val) { + max_val = val; + max_expert = expert; + } + } + + if (max_expert >= start_expert && max_expert < end_expert) { + wt[max_expert - start_expert] = -INFINITY; + + weights[k] = max_val; + ids[k] = max_expert; + } + } +} + +static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx, + const float * logits, + float * weights, + int32_t * ids, + const int n_rows, + const int n_expert, + const int n_expert_used) { + const int rows_per_block = 4; + dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1); + dim3 block_dims(32, rows_per_block, 1); + cudaStream_t stream = ctx.stream(); + + switch (n_expert) { + case 1: + topk_moe_cuda<1><<>>(logits, weights, ids, n_rows, n_expert_used); + break; + case 2: + topk_moe_cuda<2><<>>(logits, weights, ids, n_rows, n_expert_used); + break; + case 4: + topk_moe_cuda<4><<>>(logits, weights, ids, n_rows, n_expert_used); + break; + case 8: + topk_moe_cuda<8><<>>(logits, weights, ids, n_rows, n_expert_used); + break; + case 16: + topk_moe_cuda<16><<>>(logits, weights, ids, n_rows, n_expert_used); + break; + case 32: + topk_moe_cuda<32><<>>(logits, weights, ids, n_rows, n_expert_used); + 
break; + case 64: + topk_moe_cuda<64><<>>(logits, weights, ids, n_rows, n_expert_used); + break; + case 128: + topk_moe_cuda<128><<>>(logits, weights, ids, n_rows, n_expert_used); + break; + case 256: + topk_moe_cuda<256><<>>(logits, weights, ids, n_rows, n_expert_used); + break; + case 512: + topk_moe_cuda<512><<>>(logits, weights, ids, n_rows, n_expert_used); + break; + default: + GGML_ASSERT(false && "fatal error"); + break; + } +} + +void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx, + ggml_tensor * logits, + ggml_tensor * weights, + ggml_tensor * ids) { + GGML_ASSERT(logits->type == GGML_TYPE_F32); + GGML_ASSERT(weights->type == GGML_TYPE_F32); + GGML_ASSERT(ids->type == GGML_TYPE_I32); + + const float * logits_d = (const float *) logits->src[0]->data; + float * weights_d = (float *) weights->data; + int32_t * ids_d = (int32_t *) ids->data; + + const int n_experts = logits->ne[0]; + const int n_rows = logits->ne[1]; + + cudaStream_t stream = ctx.stream(); + + const int n_expert_used = weights->ne[1]; + + launch_topk_moe_cuda(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used); +} diff --git a/ggml/src/ggml-cuda/topk-moe.cuh b/ggml/src/ggml-cuda/topk-moe.cuh new file mode 100644 index 0000000000000..7fecd07dd3605 --- /dev/null +++ b/ggml/src/ggml-cuda/topk-moe.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx, ggml_tensor * logits, ggml_tensor * weights, ggml_tensor * top_k); diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 9f2e417f1ff4b..c8f80934fbfbe 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -929,6 +929,9 @@ ggml_tensor * llm_graph_context::build_moe_ffn( ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens] cb(weights, "ffn_moe_weights", il); + //call early so that softmax->topk->get_rows can be fused + ggml_build_forward_expand(gf, weights); + if (gating_op == LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT) { weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); weights = ggml_soft_max(ctx0, weights); // [n_expert_used, n_tokens] diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 592631f3ed21a..b338b85e90706 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -4418,6 +4418,42 @@ struct test_argsort : public test_case { } }; +struct test_topk_moe: public test_case { + + const std::array ne; + const int n_expert_used; + test_topk_moe(std::array ne = {10, 5, 1, 1}, int n_expert_used = 1) + : ne(ne), + n_expert_used(n_expert_used) { + GGML_ASSERT(n_expert_used <= ne[0]); + } + + std::string vars() override { + return VARS_TO_STR2(ne, n_expert_used); + } + + std::string op_desc(ggml_tensor * t) override { + GGML_UNUSED(t); + return "TOPK_GATED_MOE"; + } + + bool run_whole_graph() override { return true; } + + ggml_tensor * build_graph(ggml_context * ctx) override { + const int n_expert = ne[0]; + const int n_tokens = ne[1]; + + ggml_tensor * logits = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data()); + ggml_tensor * probs = ggml_soft_max(ctx, logits); + ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_expert_used); // [n_expert_used, n_tokens] + + ggml_tensor * out = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens] + + ggml_set_name(out, "out"); + return out; + } +}; + // GGML_OP_SUM struct test_sum : public test_case { const ggml_type type; @@ -6588,6 +6624,11 @@ static std::vector> 
make_test_cases_eval() { test_cases.emplace_back(new test_opt_step_adamw(GGML_TYPE_F32, {10, 5, 4, 3})); test_cases.emplace_back(new test_opt_step_sgd(GGML_TYPE_F32, {10, 5, 4, 3})); + + test_cases.emplace_back(new test_topk_moe({8, 22, 1, 1}, 4)); + test_cases.emplace_back(new test_topk_moe({32, 22, 1, 1}, 8)); + test_cases.emplace_back(new test_topk_moe({128, 19, 1, 1}, 16)); + #if 0 // these tests are disabled to save execution time, sbut they can be handy for debugging test_cases.emplace_back(new test_llama(2, true)); From 9fc0396ee3758532cb1b4a24fdefef8241fd6b25 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Sun, 21 Sep 2025 09:57:10 +0800 Subject: [PATCH 2/9] Refactor into ggml_cuda_should_use_topk_moe --- ggml/src/ggml-cuda/ggml-cuda.cu | 27 ++--------------- ggml/src/ggml-cuda/topk-moe.cu | 52 ++++++++++++++++++++++++--------- ggml/src/ggml-cuda/topk-moe.cuh | 4 ++- tests/test-backend-ops.cpp | 2 +- 4 files changed, 45 insertions(+), 40 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 94245125245b6..91a5162ded232 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2835,29 +2835,9 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, } ggml_tensor * softmax = cgraph->nodes[node_idx]; - - float scale = 1.0f; - float max_bias = 0.0f; - - memcpy(&scale, (const float *) softmax->op_params + 0, sizeof(float)); - memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float)); - - if (scale != 1.0f || max_bias != 0.0f) { - return false; - } - - // don't fuse when masks or sinks are present - if (softmax->src[1] || softmax->src[2]) { - return false; - } - - const int n_expert = softmax->ne[0]; - // n_expert must be a power of 2 - if (n_expert & (n_expert - 1) != 0 || n_expert > 512) { - return false; + if (ggml_cuda_should_use_topk_moe(softmax)) { + return true; } - - return true; } if (!ggml_can_fuse(cgraph, node_idx, ops)) { @@ -2927,8 +2907,6 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, return true; } - - return false; } @@ -3010,7 +2988,6 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node); continue; } - } #ifndef NDEBUG assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device)); diff --git a/ggml/src/ggml-cuda/topk-moe.cu b/ggml/src/ggml-cuda/topk-moe.cu index 2950f82b2c234..d553f1a1ab7a6 100644 --- a/ggml/src/ggml-cuda/topk-moe.cu +++ b/ggml/src/ggml-cuda/topk-moe.cu @@ -1,3 +1,4 @@ +#include "ggml.h" #include "topk-moe.cuh" /* @@ -10,10 +11,10 @@ */ template __global__ void topk_moe_cuda(const float * logits, - float * weights, - int32_t * ids, - const int n_rows, - const int n_expert_used) { + float * weights, + int32_t * ids, + const int n_rows, + const int n_expert_used) { const int row = blockIdx.x * blockDim.y + threadIdx.y; if (row >= n_rows) { return; @@ -94,12 +95,12 @@ __global__ void topk_moe_cuda(const float * logits, } static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx, - const float * logits, - float * weights, - int32_t * ids, - const int n_rows, - const int n_expert, - const int n_expert_used) { + const float * logits, + float * weights, + int32_t * ids, + const int n_rows, + const int n_expert, + const int n_expert_used) { const int rows_per_block = 4; dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1); dim3 block_dims(32, rows_per_block, 1); @@ -143,9 +144,9 @@ static 
void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx, } void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx, - ggml_tensor * logits, - ggml_tensor * weights, - ggml_tensor * ids) { + const ggml_tensor * logits, + ggml_tensor * weights, + ggml_tensor * ids) { GGML_ASSERT(logits->type == GGML_TYPE_F32); GGML_ASSERT(weights->type == GGML_TYPE_F32); GGML_ASSERT(ids->type == GGML_TYPE_I32); @@ -163,3 +164,28 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx, launch_topk_moe_cuda(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used); } + +bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax) { + float scale = 1.0f; + float max_bias = 0.0f; + + memcpy(&scale, (const float *) softmax->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float)); + + if (scale != 1.0f || max_bias != 0.0f) { + return false; + } + + // don't fuse when masks or sinks are present + if (softmax->src[1] || softmax->src[2]) { + return false; + } + + const int n_expert = softmax->ne[0]; + // n_expert must be a power of 2 + if ((n_expert & (n_expert - 1)) != 0 || n_expert > 512) { + return false; + } + + return true; +} diff --git a/ggml/src/ggml-cuda/topk-moe.cuh b/ggml/src/ggml-cuda/topk-moe.cuh index 7fecd07dd3605..0d53127df66fa 100644 --- a/ggml/src/ggml-cuda/topk-moe.cuh +++ b/ggml/src/ggml-cuda/topk-moe.cuh @@ -1,3 +1,5 @@ #include "common.cuh" -void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx, ggml_tensor * logits, ggml_tensor * weights, ggml_tensor * top_k); +void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx, const ggml_tensor * logits, ggml_tensor * weights, ggml_tensor * top_k); + +bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index b338b85e90706..13ac31fc93c9a 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -4434,7 +4434,7 @@ struct test_topk_moe: public test_case { std::string op_desc(ggml_tensor * t) override { GGML_UNUSED(t); - return "TOPK_GATED_MOE"; + return "TOPK_MOE"; } bool run_whole_graph() override { return true; } From 8b780ccef9f6886c59df8f81a756f809a00d04e5 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Mon, 22 Sep 2025 22:56:29 +0800 Subject: [PATCH 3/9] Review: Use better coalescing pattern, use WARP_SIZE, store logits into registers before --- ggml/src/ggml-cuda/ggml-cuda.cu | 3 +- ggml/src/ggml-cuda/topk-moe.cu | 68 +++++++++++++++++++-------------- ggml/src/ggml-cuda/topk-moe.cuh | 7 +++- 3 files changed, 47 insertions(+), 31 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 91a5162ded232..db4375c921695 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2835,7 +2835,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, } ggml_tensor * softmax = cgraph->nodes[node_idx]; - if (ggml_cuda_should_use_topk_moe(softmax)) { + ggml_tensor * weights = cgraph->nodes[node_idx+4]; + if (ggml_cuda_should_use_topk_moe(softmax, weights)) { return true; } } diff --git a/ggml/src/ggml-cuda/topk-moe.cu b/ggml/src/ggml-cuda/topk-moe.cu index d553f1a1ab7a6..9fd87d1ae020a 100644 --- a/ggml/src/ggml-cuda/topk-moe.cu +++ b/ggml/src/ggml-cuda/topk-moe.cu @@ -1,3 +1,4 @@ +#include "ggml-cuda/common.cuh" #include "ggml.h" #include "topk-moe.cuh" @@ -10,30 +11,36 @@ It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models */ template -__global__ void 
topk_moe_cuda(const float * logits, - float * weights, - int32_t * ids, - const int n_rows, - const int n_expert_used) { +__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits, + float * weights, + int32_t * ids, + const int n_rows, + const int n_expert_used) { const int row = blockIdx.x * blockDim.y + threadIdx.y; if (row >= n_rows) { return; } + logits += n_experts * row; - ids += n_experts * row; weights += n_expert_used * row; + ids += n_experts * row; + + constexpr int experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1; - constexpr int experts_per_thread = (n_experts > 32) ? n_experts / 32 : 1; + float logits_r[experts_per_thread]; - const int start_expert = threadIdx.x * experts_per_thread; - const int end_expert = (threadIdx.x + 1) * experts_per_thread; - float max_val = -INFINITY; +#pragma unroll + for (int i = 0; i < n_experts; i += WARP_SIZE) { + const int expert = i + threadIdx.x; + logits_r[i / WARP_SIZE] = expert < n_experts ? logits[expert] : -INFINITY; + } + + float max_val = -INFINITY; #pragma unroll for (int i = 0; i < experts_per_thread; i++) { - const int expert = start_expert + i; - const float val = (expert < n_experts) ? logits[expert] : -INFINITY; - max_val = max(val, max_val); + const float val = logits_r[i]; + max_val = max(val, max_val); } max_val = warp_reduce_max(max_val); @@ -43,9 +50,8 @@ __global__ void topk_moe_cuda(const float * logits, #pragma unroll for (int i = 0; i < experts_per_thread; i++) { - const int expert = start_expert + i; - const float val = (expert < n_experts) ? logits[expert] : -INFINITY; - wt[i] = expf(val - max_val); + const float val = logits_r[i]; + wt[i] = expf(val - max_val); tmp += wt[i]; } @@ -64,29 +70,29 @@ __global__ void topk_moe_cuda(const float * logits, for (int k = 0; k < n_expert_used; k++) { float max_val = wt[0]; - int max_expert = start_expert; + int max_expert = threadIdx.x; #pragma unroll for (int i = 1; i < experts_per_thread; i++) { - const int expert = start_expert + i; - if (wt[i] > max_val) { + const int expert = threadIdx.x + i * WARP_SIZE; + if (expert < n_experts && wt[i] > max_val) { max_val = wt[i]; max_expert = expert; } } #pragma unroll - for (int mask = warpSize / 2; mask > 0; mask /= 2) { - const float val = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, warpSize); - const int expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, warpSize); + for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) { + const float val = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE); + const int expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE); if (val > max_val) { max_val = val; max_expert = expert; } } - if (max_expert >= start_expert && max_expert < end_expert) { - wt[max_expert - start_expert] = -INFINITY; + if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) { + wt[max_expert / WARP_SIZE] = -INFINITY; weights[k] = max_val; ids[k] = max_expert; @@ -103,7 +109,7 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx, const int n_expert_used) { const int rows_per_block = 4; dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1); - dim3 block_dims(32, rows_per_block, 1); + dim3 block_dims(WARP_SIZE, rows_per_block, 1); cudaStream_t stream = ctx.stream(); switch (n_expert) { @@ -151,12 +157,14 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx, GGML_ASSERT(weights->type == GGML_TYPE_F32); GGML_ASSERT(ids->type == GGML_TYPE_I32); + const int n_experts = logits->ne[0]; + const int n_rows = logits->ne[1]; + const float 
* logits_d = (const float *) logits->src[0]->data; float * weights_d = (float *) weights->data; int32_t * ids_d = (int32_t *) ids->data; - const int n_experts = logits->ne[0]; - const int n_rows = logits->ne[1]; + GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts); cudaStream_t stream = ctx.stream(); @@ -165,13 +173,17 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx, launch_topk_moe_cuda(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used); } -bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax) { +bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights) { float scale = 1.0f; float max_bias = 0.0f; memcpy(&scale, (const float *) softmax->op_params + 0, sizeof(float)); memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float)); + if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) { + return false; + } + if (scale != 1.0f || max_bias != 0.0f) { return false; } diff --git a/ggml/src/ggml-cuda/topk-moe.cuh b/ggml/src/ggml-cuda/topk-moe.cuh index 0d53127df66fa..03f4ad564e90d 100644 --- a/ggml/src/ggml-cuda/topk-moe.cuh +++ b/ggml/src/ggml-cuda/topk-moe.cuh @@ -1,5 +1,8 @@ #include "common.cuh" -void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx, const ggml_tensor * logits, ggml_tensor * weights, ggml_tensor * top_k); +void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx, + const ggml_tensor * logits, + ggml_tensor * weights, + ggml_tensor * top_k); -bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax); +bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights); From ce867aa7e08623d19a45ad560d5e12e7af3e1e55 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Tue, 23 Sep 2025 09:51:26 +0800 Subject: [PATCH 4/9] Review: format + micro-optimizations --- ggml/src/ggml-cuda/topk-moe.cu | 6 +++--- tests/test-backend-ops.cpp | 4 +--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/ggml/src/ggml-cuda/topk-moe.cu b/ggml/src/ggml-cuda/topk-moe.cu index 9fd87d1ae020a..1c7591e82f0c0 100644 --- a/ggml/src/ggml-cuda/topk-moe.cu +++ b/ggml/src/ggml-cuda/topk-moe.cu @@ -32,13 +32,13 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * #pragma unroll for (int i = 0; i < n_experts; i += WARP_SIZE) { const int expert = i + threadIdx.x; - logits_r[i / WARP_SIZE] = expert < n_experts ? logits[expert] : -INFINITY; + logits_r[i / WARP_SIZE] = n_experts % WARP_SIZE == 0 || expert < n_experts ? 
logits[expert] : -INFINITY; } - float max_val = -INFINITY; + float max_val = logits_r[0]; #pragma unroll - for (int i = 0; i < experts_per_thread; i++) { + for (int i = 1; i < experts_per_thread; i++) { const float val = logits_r[i]; max_val = max(val, max_val); } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 13ac31fc93c9a..5605b7f618ebf 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -4419,12 +4419,10 @@ struct test_argsort : public test_case { }; struct test_topk_moe: public test_case { - const std::array ne; const int n_expert_used; test_topk_moe(std::array ne = {10, 5, 1, 1}, int n_expert_used = 1) - : ne(ne), - n_expert_used(n_expert_used) { + : ne(ne), n_expert_used(n_expert_used) { GGML_ASSERT(n_expert_used <= ne[0]); } From 2930668c40605d5876cb42c0a1dd588c69436c11 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Wed, 24 Sep 2025 17:35:20 +0800 Subject: [PATCH 5/9] Fix bug: fix tie breakers --- ggml/src/ggml-cuda/topk-moe.cu | 2 +- tests/test-backend-ops.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/topk-moe.cu b/ggml/src/ggml-cuda/topk-moe.cu index 1c7591e82f0c0..26785ef0ce1b7 100644 --- a/ggml/src/ggml-cuda/topk-moe.cu +++ b/ggml/src/ggml-cuda/topk-moe.cu @@ -85,7 +85,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) { const float val = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE); const int expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE); - if (val > max_val) { + if (val > max_val || (val == max_val && expert < max_expert)) { max_val = val; max_expert = expert; } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 5605b7f618ebf..7ab3c06b6bc21 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -6625,7 +6625,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_topk_moe({8, 22, 1, 1}, 4)); test_cases.emplace_back(new test_topk_moe({32, 22, 1, 1}, 8)); - test_cases.emplace_back(new test_topk_moe({128, 19, 1, 1}, 16)); + test_cases.emplace_back(new test_topk_moe({128, 1, 1, 1}, 128)); #if 0 // these tests are disabled to save execution time, sbut they can be handy for debugging From 240b2c1f020b0c4c128469111f971a6fe21316b5 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Wed, 24 Sep 2025 22:39:39 +0800 Subject: [PATCH 6/9] Add optional norm + clean-up code --- ggml/src/ggml-cuda/ggml-cuda.cu | 35 ++++++++++--- ggml/src/ggml-cuda/topk-moe.cu | 90 +++++++++++++++++++++++++++------ ggml/src/ggml-cuda/topk-moe.cuh | 8 ++- tests/test-backend-ops.cpp | 24 ++++++--- 4 files changed, 127 insertions(+), 30 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index db4375c921695..266f00de9a57d 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2826,12 +2826,25 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, GGML_ASSERT(unary_ops.size() == num_unary); #endif - //special case for topk-moe - if (ops.size() == 5 && ops.begin()[0] == GGML_OP_SOFT_MAX && ops.begin()[1] == GGML_OP_RESHAPE && ops.begin()[2] == GGML_OP_ARGSORT - && ops.begin()[3] == GGML_OP_VIEW && ops.begin()[4] == GGML_OP_GET_ROWS) { + //TODO: remove special case once ggml_can_fuse can handle empty nodes + std::initializer_list topk_moe_ops = ggml_cuda_topk_moe_ops(false); + std::initializer_list topk_moe_ops_with_norm = 
ggml_cuda_topk_moe_ops(true); - for (int i = 0; i < 5; i++) { - if (cgraph->nodes[node_idx + i]->op != ops.begin()[i]) return false; + if (ops.size() == topk_moe_ops_with_norm.size() && std::equal(ops.begin(), ops.end(), topk_moe_ops_with_norm.begin())) { + for (size_t i = 0; i < topk_moe_ops_with_norm.size(); i++) { + if (cgraph->nodes[node_idx + i]->op != topk_moe_ops_with_norm.begin()[i]) return false; + } + ggml_tensor * softmax = cgraph->nodes[node_idx]; + ggml_tensor * weights = cgraph->nodes[node_idx+8]; + + if (ggml_cuda_should_use_topk_moe(softmax, weights)) { + return true; + } + } + + if (ops.size() == topk_moe_ops.size() && std::equal(ops.begin(), ops.end(), topk_moe_ops.begin())) { + for (size_t i = 0; i < topk_moe_ops.size(); i++) { + if (cgraph->nodes[node_idx + i]->op != topk_moe_ops.begin()[i]) return false; } ggml_tensor * softmax = cgraph->nodes[node_idx]; @@ -2931,11 +2944,19 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr); if (!disable_fusion) { - if (ggml_cuda_can_fuse(cgraph, i, {GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS}, {})) { + if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ true), {})) { + ggml_tensor * weights = cgraph->nodes[i+8]; + ggml_tensor * selected_experts = cgraph->nodes[i+3]; + ggml_cuda_op_topk_moe(*cuda_ctx, node, weights, selected_experts, /*with norm*/ true); + i += 8; + continue; + } + + if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ false), {})) { ggml_tensor * weights = cgraph->nodes[i+4]; ggml_tensor * selected_experts = cgraph->nodes[i+3]; - ggml_cuda_op_topk_moe(*cuda_ctx, node, weights, selected_experts); + ggml_cuda_op_topk_moe(*cuda_ctx, node, weights, selected_experts, /*with norm*/ false); i += 4; continue; } diff --git a/ggml/src/ggml-cuda/topk-moe.cu b/ggml/src/ggml-cuda/topk-moe.cu index 26785ef0ce1b7..3ec3835b5a245 100644 --- a/ggml/src/ggml-cuda/topk-moe.cu +++ b/ggml/src/ggml-cuda/topk-moe.cu @@ -2,15 +2,18 @@ #include "ggml.h" #include "topk-moe.cuh" +#include + /* This kernel does the following: 1. softmax over the logits per token [n_experts, n_tokens] 2. argmax reduce over the top-k (n_experts_used) logits 3. write weights + ids to global memory + 4. 
optionally normalize the weights It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models */ -template +template __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits, float * weights, int32_t * ids, @@ -68,6 +71,11 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * //we do the argmax reduce over n_expert_used, each time marking //the expert weight as -inf to exclude from the next iteration + float wt_sum = 0.f; + + extern __shared__ float data_topk_shared[]; + float * wt_shared_ptr = data_topk_shared + row * n_expert_used; + for (int k = 0; k < n_expert_used; k++) { float max_val = wt[0]; int max_expert = threadIdx.x; @@ -94,12 +102,33 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) { wt[max_expert / WARP_SIZE] = -INFINITY; - weights[k] = max_val; - ids[k] = max_expert; + wt_shared_ptr[k] = max_val; + ids[k] = max_expert; + if constexpr (with_norm) { + wt_sum += max_val; + } + } + } + + if constexpr (with_norm) { + wt_sum = warp_reduce_sum(wt_sum); + const float inv_sum = 1.0f / wt_sum; + + if (threadIdx.x == 0) { + for (int i = 0; i < n_expert_used; i++) { + wt_shared_ptr[i] = wt_shared_ptr[i] * inv_sum; + } + } + } + + if (threadIdx.x == 0) { + for (int i = 0; i < n_expert_used; i++) { + weights[i] = wt_shared_ptr[i]; } } } +template static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx, const float * logits, float * weights, @@ -112,36 +141,48 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx, dim3 block_dims(WARP_SIZE, rows_per_block, 1); cudaStream_t stream = ctx.stream(); + const int nbytes_shared = n_expert_used * rows_per_block * sizeof(float); + switch (n_expert) { case 1: - topk_moe_cuda<1><<>>(logits, weights, ids, n_rows, n_expert_used); + topk_moe_cuda<1, with_norm> + <<>>(logits, weights, ids, n_rows, n_expert_used); break; case 2: - topk_moe_cuda<2><<>>(logits, weights, ids, n_rows, n_expert_used); + topk_moe_cuda<2, with_norm> + <<>>(logits, weights, ids, n_rows, n_expert_used); break; case 4: - topk_moe_cuda<4><<>>(logits, weights, ids, n_rows, n_expert_used); + topk_moe_cuda<4, with_norm> + <<>>(logits, weights, ids, n_rows, n_expert_used); break; case 8: - topk_moe_cuda<8><<>>(logits, weights, ids, n_rows, n_expert_used); + topk_moe_cuda<8, with_norm> + <<>>(logits, weights, ids, n_rows, n_expert_used); break; case 16: - topk_moe_cuda<16><<>>(logits, weights, ids, n_rows, n_expert_used); + topk_moe_cuda<16, with_norm> + <<>>(logits, weights, ids, n_rows, n_expert_used); break; case 32: - topk_moe_cuda<32><<>>(logits, weights, ids, n_rows, n_expert_used); + topk_moe_cuda<32, with_norm> + <<>>(logits, weights, ids, n_rows, n_expert_used); break; case 64: - topk_moe_cuda<64><<>>(logits, weights, ids, n_rows, n_expert_used); + topk_moe_cuda<64, with_norm> + <<>>(logits, weights, ids, n_rows, n_expert_used); break; case 128: - topk_moe_cuda<128><<>>(logits, weights, ids, n_rows, n_expert_used); + topk_moe_cuda<128, with_norm> + <<>>(logits, weights, ids, n_rows, n_expert_used); break; case 256: - topk_moe_cuda<256><<>>(logits, weights, ids, n_rows, n_expert_used); + topk_moe_cuda<256, with_norm> + <<>>(logits, weights, ids, n_rows, n_expert_used); break; case 512: - topk_moe_cuda<512><<>>(logits, weights, ids, n_rows, n_expert_used); + topk_moe_cuda<512, with_norm> + <<>>(logits, weights, ids, n_rows, n_expert_used); break; default: GGML_ASSERT(false && "fatal error"); @@ -152,7 
+193,8 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx, void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx, const ggml_tensor * logits, ggml_tensor * weights, - ggml_tensor * ids) { + ggml_tensor * ids, + const bool with_norm) { GGML_ASSERT(logits->type == GGML_TYPE_F32); GGML_ASSERT(weights->type == GGML_TYPE_F32); GGML_ASSERT(ids->type == GGML_TYPE_I32); @@ -170,7 +212,11 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx, const int n_expert_used = weights->ne[1]; - launch_topk_moe_cuda(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used); + if (with_norm) { + launch_topk_moe_cuda(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used); + } else { + launch_topk_moe_cuda(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used); + } } bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights) { @@ -201,3 +247,17 @@ bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tenso return true; } + +std::initializer_list ggml_cuda_topk_moe_ops(bool norm) { + static std::initializer_list norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, + GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE, + GGML_OP_SUM_ROWS, GGML_OP_DIV, GGML_OP_RESHAPE }; + + static std::initializer_list no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, + GGML_OP_VIEW, GGML_OP_GET_ROWS }; + + if (norm) { + return norm_ops; + } + return no_norm_ops; +} diff --git a/ggml/src/ggml-cuda/topk-moe.cuh b/ggml/src/ggml-cuda/topk-moe.cuh index 03f4ad564e90d..6613fb56507ea 100644 --- a/ggml/src/ggml-cuda/topk-moe.cuh +++ b/ggml/src/ggml-cuda/topk-moe.cuh @@ -1,8 +1,14 @@ #include "common.cuh" +#include "ggml.h" + +#include void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx, const ggml_tensor * logits, ggml_tensor * weights, - ggml_tensor * top_k); + ggml_tensor * top_k, + const bool with_norm); bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights); + +std::initializer_list ggml_cuda_topk_moe_ops(bool with_norm); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 7ab3c06b6bc21..f87eaa8a3b0c8 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -4421,13 +4421,14 @@ struct test_argsort : public test_case { struct test_topk_moe: public test_case { const std::array ne; const int n_expert_used; - test_topk_moe(std::array ne = {10, 5, 1, 1}, int n_expert_used = 1) - : ne(ne), n_expert_used(n_expert_used) { + const bool with_norm; + test_topk_moe(std::array ne = {10, 5, 1, 1}, int n_expert_used = 1, bool with_norm = false) + : ne(ne), n_expert_used(n_expert_used), with_norm(with_norm) { GGML_ASSERT(n_expert_used <= ne[0]); } std::string vars() override { - return VARS_TO_STR2(ne, n_expert_used); + return VARS_TO_STR3(ne, n_expert_used, with_norm); } std::string op_desc(ggml_tensor * t) override { @@ -4447,6 +4448,14 @@ struct test_topk_moe: public test_case { ggml_tensor * out = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens] + if (with_norm) { + out = ggml_reshape_2d(ctx, out, n_expert_used, n_tokens); + ggml_tensor * weights_sum = ggml_sum_rows(ctx, out); // [1, n_tokens] + + out = ggml_div(ctx, out, weights_sum); // [n_expert_used, n_tokens] + out = ggml_reshape_3d(ctx, out, 1, n_expert_used, n_tokens); + } + ggml_set_name(out, "out"); return out; } @@ -6622,10 +6631,11 @@ static std::vector> make_test_cases_eval() { 
test_cases.emplace_back(new test_opt_step_adamw(GGML_TYPE_F32, {10, 5, 4, 3})); test_cases.emplace_back(new test_opt_step_sgd(GGML_TYPE_F32, {10, 5, 4, 3})); - - test_cases.emplace_back(new test_topk_moe({8, 22, 1, 1}, 4)); - test_cases.emplace_back(new test_topk_moe({32, 22, 1, 1}, 8)); - test_cases.emplace_back(new test_topk_moe({128, 1, 1, 1}, 128)); + for (bool with_norm : {false, true}) { + test_cases.emplace_back(new test_topk_moe({8, 22, 1, 1}, 4, with_norm)); + test_cases.emplace_back(new test_topk_moe({32, 22, 1, 1}, 8, with_norm)); + test_cases.emplace_back(new test_topk_moe({128, 1, 1, 1}, 128, with_norm)); + } #if 0 // these tests are disabled to save execution time, sbut they can be handy for debugging From e772b28fdf74488f8281dd3d69aa36f348528fd4 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Thu, 25 Sep 2025 10:26:06 +0800 Subject: [PATCH 7/9] Use smem for final write --- ggml/src/ggml-cuda/ggml-cuda.cu | 1 - ggml/src/ggml-cuda/topk-moe.cu | 4 ++-- src/llama-graph.cpp | 5 +++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 266f00de9a57d..f27951bb7d5ea 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2953,7 +2953,6 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx } if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ false), {})) { - ggml_tensor * weights = cgraph->nodes[i+4]; ggml_tensor * selected_experts = cgraph->nodes[i+3]; ggml_cuda_op_topk_moe(*cuda_ctx, node, weights, selected_experts, /*with norm*/ false); diff --git a/ggml/src/ggml-cuda/topk-moe.cu b/ggml/src/ggml-cuda/topk-moe.cu index 3ec3835b5a245..63480d08c7077 100644 --- a/ggml/src/ggml-cuda/topk-moe.cu +++ b/ggml/src/ggml-cuda/topk-moe.cu @@ -74,7 +74,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * float wt_sum = 0.f; extern __shared__ float data_topk_shared[]; - float * wt_shared_ptr = data_topk_shared + row * n_expert_used; + float * wt_shared_ptr = data_topk_shared + threadIdx.y * n_expert_used; for (int k = 0; k < n_expert_used; k++) { float max_val = wt[0]; @@ -83,7 +83,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * #pragma unroll for (int i = 1; i < experts_per_thread; i++) { const int expert = threadIdx.x + i * WARP_SIZE; - if (expert < n_experts && wt[i] > max_val) { + if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) { max_val = wt[i]; max_expert = expert; } diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index c8f80934fbfbe..b49a964b6343c 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -929,8 +929,6 @@ ggml_tensor * llm_graph_context::build_moe_ffn( ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens] cb(weights, "ffn_moe_weights", il); - //call early so that softmax->topk->get_rows can be fused - ggml_build_forward_expand(gf, weights); if (gating_op == LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT) { weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); @@ -955,6 +953,9 @@ ggml_tensor * llm_graph_context::build_moe_ffn( cb(weights, "ffn_moe_weights_scaled", il); } + //call early so that topk-moe can be used + ggml_build_forward_expand(gf, weights); + cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens); if (weight_before_ffn) { From 53acfe612665f21cc1eab4e6940a9669c119f749 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Thu, 25 Sep 2025 
11:35:39 +0800 Subject: [PATCH 8/9] Add bounds check --- ggml/src/ggml-cuda/ggml-cuda.cu | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index f27951bb7d5ea..8c8647b147369 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2831,6 +2831,11 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list topk_moe_ops_with_norm = ggml_cuda_topk_moe_ops(true); if (ops.size() == topk_moe_ops_with_norm.size() && std::equal(ops.begin(), ops.end(), topk_moe_ops_with_norm.begin())) { + + if (node_idx + topk_moe_ops_with_norm.size() > (size_t)cgraph->n_nodes) { + return false; + } + for (size_t i = 0; i < topk_moe_ops_with_norm.size(); i++) { if (cgraph->nodes[node_idx + i]->op != topk_moe_ops_with_norm.begin()[i]) return false; } @@ -2843,6 +2848,11 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, } if (ops.size() == topk_moe_ops.size() && std::equal(ops.begin(), ops.end(), topk_moe_ops.begin())) { + + if (node_idx + topk_moe_ops.size() > (size_t)cgraph->n_nodes) { + return false; + } + for (size_t i = 0; i < topk_moe_ops.size(); i++) { if (cgraph->nodes[node_idx + i]->op != topk_moe_ops.begin()[i]) return false; } From 33856e1c6b4d58d2a96647ab0550fad56b114e72 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Thu, 25 Sep 2025 16:58:55 +0800 Subject: [PATCH 9/9] Use better memory pattern for writeback --- ggml/src/ggml-cuda/topk-moe.cu | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/ggml/src/ggml-cuda/topk-moe.cu b/ggml/src/ggml-cuda/topk-moe.cu index 63480d08c7077..039f284719648 100644 --- a/ggml/src/ggml-cuda/topk-moe.cu +++ b/ggml/src/ggml-cuda/topk-moe.cu @@ -74,7 +74,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * float wt_sum = 0.f; extern __shared__ float data_topk_shared[]; - float * wt_shared_ptr = data_topk_shared + threadIdx.y * n_expert_used; + float * wt_shared_ptr = data_topk_shared + threadIdx.y * n_expert_used; for (int k = 0; k < n_expert_used; k++) { float max_val = wt[0]; @@ -114,17 +114,13 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * wt_sum = warp_reduce_sum(wt_sum); const float inv_sum = 1.0f / wt_sum; - if (threadIdx.x == 0) { - for (int i = 0; i < n_expert_used; i++) { - wt_shared_ptr[i] = wt_shared_ptr[i] * inv_sum; - } + for (int i = threadIdx.x; i < n_expert_used; i += WARP_SIZE) { + wt_shared_ptr[i] = wt_shared_ptr[i] * inv_sum; } } - if (threadIdx.x == 0) { - for (int i = 0; i < n_expert_used; i++) { - weights[i] = wt_shared_ptr[i]; - } + for (int i = threadIdx.x; i < n_expert_used; i += WARP_SIZE) { + weights[i] = wt_shared_ptr[i]; } }
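
For reference, the computation that the fused kernel replaces can be written as a plain host-side loop. The sketch below is illustrative only and not part of the patch series (the name topk_moe_reference is made up here): per token it performs a max-subtracted softmax over n_expert logits, keeps the n_expert_used largest probabilities with ties broken toward the lower expert index (matching the tie-break added in patch 5), and optionally renormalizes the kept weights (the with_norm path added in patch 6). It assumes the layouts the kernel asserts: weights rows hold n_expert_used contiguous floats, while ids rows are spaced n_expert elements apart.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Host-side reference of the fused top-k MoE step (illustration, not the CUDA kernel):
// softmax per token, select top n_expert_used experts, optionally normalize the weights.
static void topk_moe_reference(const float * logits,   // [n_expert, n_tokens], contiguous per token
                               float *       weights,  // [n_expert_used, n_tokens]
                               int32_t *     ids,      // rows spaced n_expert apart, as the kernel asserts
                               int n_expert, int n_expert_used, int n_tokens, bool with_norm) {
    std::vector<float> p(n_expert);
    std::vector<int>   order(n_expert);

    for (int t = 0; t < n_tokens; ++t) {
        const float * row = logits + (std::size_t) t * n_expert;

        // softmax with the usual max subtraction for numerical stability
        const float max_val = *std::max_element(row, row + n_expert);
        float sum = 0.0f;
        for (int e = 0; e < n_expert; ++e) {
            p[e] = std::exp(row[e] - max_val);
            sum += p[e];
        }
        for (int e = 0; e < n_expert; ++e) {
            p[e] /= sum;
        }

        // top-k selection: order expert indices by probability, lower index first on ties
        for (int e = 0; e < n_expert; ++e) {
            order[e] = e;
        }
        std::partial_sort(order.begin(), order.begin() + n_expert_used, order.end(),
                          [&](int a, int b) { return p[a] > p[b] || (p[a] == p[b] && a < b); });

        float top_sum = 0.0f;
        for (int k = 0; k < n_expert_used; ++k) {
            top_sum += p[order[k]];
        }

        for (int k = 0; k < n_expert_used; ++k) {
            const float w = with_norm ? p[order[k]] / top_sum : p[order[k]];
            weights[(std::size_t) t * n_expert_used + k] = w;
            ids[(std::size_t) t * n_expert + k]          = order[k];
        }
    }
}

A loop of this shape can serve as a CPU oracle when comparing the fused CUDA output against the unfused softmax->top-k->get_rows graph, in the spirit of the test_topk_moe cases added to test-backend-ops.cpp.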