diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 75fd6db14c514..24a57bcd2faf4 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2818,8 +2818,12 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
 #endif
 
     //TODO: remove special case once ggml_can_fuse can handle empty nodes
-    std::initializer_list<enum ggml_op> topk_moe_ops           = ggml_cuda_topk_moe_ops(false);
-    std::initializer_list<enum ggml_op> topk_moe_ops_with_norm = ggml_cuda_topk_moe_ops(true);
+    std::initializer_list<enum ggml_op> topk_moe_ops =
+        ggml_cuda_topk_moe_ops(/*with_norm=*/false, /*delayed_softmax=*/false);
+    std::initializer_list<enum ggml_op> topk_moe_ops_with_norm =
+        ggml_cuda_topk_moe_ops(/*with_norm=*/true, /*delayed_softmax=*/false);
+    std::initializer_list<enum ggml_op> topk_moe_ops_delayed_softmax =
+        ggml_cuda_topk_moe_ops(/*with_norm=*/false, /*delayed_softmax=*/true);
 
     if (ops.size() == topk_moe_ops_with_norm.size() &&
         std::equal(ops.begin(), ops.end(), topk_moe_ops_with_norm.begin())) {
@@ -2855,6 +2859,25 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
         }
     }
 
+    if (ops.size() == topk_moe_ops_delayed_softmax.size() &&
+        std::equal(ops.begin(), ops.end(), topk_moe_ops_delayed_softmax.begin())) {
+        if (node_idx + topk_moe_ops_delayed_softmax.size() > (size_t) cgraph->n_nodes) {
+            return false;
+        }
+        for (size_t i = 0; i < topk_moe_ops_delayed_softmax.size(); i++) {
+            if (cgraph->nodes[node_idx + i]->op != topk_moe_ops_delayed_softmax.begin()[i]) {
+                return false;
+            }
+        }
+
+        ggml_tensor * softmax = cgraph->nodes[node_idx + 4];
+        ggml_tensor * weights = cgraph->nodes[node_idx + 5];
+
+        if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
+            return true;
+        }
+    }
+
     if (!ggml_can_fuse(cgraph, node_idx, ops)) {
         return false;
     }
@@ -2948,7 +2971,8 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
             if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ true), {})) {
                 ggml_tensor * weights          = cgraph->nodes[i+8];
                 ggml_tensor * selected_experts = cgraph->nodes[i+3];
-                ggml_cuda_op_topk_moe(*cuda_ctx, node, weights, selected_experts, /*with norm*/ true);
+                ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ true,
+                                      /*delayed softmax*/ false);
                 i += 8;
                 continue;
             }
@@ -2956,11 +2980,23 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
             if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ false), {})) {
                 ggml_tensor * weights          = cgraph->nodes[i+4];
                 ggml_tensor * selected_experts = cgraph->nodes[i+3];
-                ggml_cuda_op_topk_moe(*cuda_ctx, node, weights, selected_experts, /*with norm*/ false);
+                ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ false,
+                                      /*delayed softmax*/ false);
                 i += 4;
                 continue;
             }
 
+            if (ggml_cuda_can_fuse(cgraph, i,
+                                   ggml_cuda_topk_moe_ops(/*with norm*/ false, /*delayed softmax*/ true), {})) {
+                ggml_tensor * weights = cgraph->nodes[i + 5];
+                ggml_tensor * ids     = cgraph->nodes[i + 1];
+
+                ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, ids, /*with norm*/ false,
+                                      /*delayed softmax*/ true);
+                i += 5;
+                continue;
+            }
+
             if (node->op == GGML_OP_ADD) {
                 int n_fuse = 0;
                 ggml_op ops[8];
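Note for reviewers: the delayed-softmax special case hard-codes offsets into the candidate window (`node_idx + 4` for the softmax, `node_idx + 5` for the final weights, and `i + 1` for the ids view in the evaluate loop). A hedged standalone sketch of the check with the expected op chain written out; `matches_topk_moe_delayed_softmax` is an illustrative helper, not part of the patch:

```cpp
// Sketch only: the shape of the delayed-softmax pattern check, mirroring
// ggml_cuda_can_fuse above. Offsets 1 (ids), 4 (softmax) and 5 (weights)
// are the nodes the fused kernel replaces.
static bool matches_topk_moe_delayed_softmax(const ggml_cgraph * cgraph, int node_idx) {
    const ggml_op expected[] = { GGML_OP_ARGSORT, GGML_OP_VIEW,     GGML_OP_GET_ROWS,
                                 GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
    const int n = (int) (sizeof(expected) / sizeof(expected[0]));

    if (node_idx + n > cgraph->n_nodes) {
        return false; // window would run past the end of the graph
    }
    for (int i = 0; i < n; i++) {
        if (cgraph->nodes[node_idx + i]->op != expected[i]) {
            return false;
        }
    }
    // same shape/type gate as the other two fusion paths
    return ggml_cuda_should_use_topk_moe(cgraph->nodes[node_idx + 4],
                                         cgraph->nodes[node_idx + 5]);
}
```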
diff --git a/ggml/src/ggml-cuda/topk-moe.cu b/ggml/src/ggml-cuda/topk-moe.cu
index c588da2bb9e93..d782ad948d254 100644
--- a/ggml/src/ggml-cuda/topk-moe.cu
+++ b/ggml/src/ggml-cuda/topk-moe.cu
@@ -4,16 +4,61 @@
 
 #include <initializer_list>
 
+// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
+template <int experts_per_thread, bool use_limit>
+__device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) {
+    float max_val = -INFINITY;
+
+#pragma unroll
+    for (int i = 0; i < experts_per_thread; i++) {
+        const int  idx    = lane + i * WARP_SIZE;
+        const bool active = !use_limit || (idx < limit);
+        if (active) {
+            max_val = max(max_val, vals[i]);
+        }
+    }
+
+    max_val = warp_reduce_max(max_val);
+
+    float sum = 0.f;
+
+#pragma unroll
+    for (int i = 0; i < experts_per_thread; i++) {
+        const int  idx    = lane + i * WARP_SIZE;
+        const bool active = !use_limit || (idx < limit);
+        if (active) {
+            const float val = expf(vals[i] - max_val);
+            vals[i]         = val;
+            sum += val;
+        } else {
+            vals[i] = 0.f;
+        }
+    }
+
+    sum = warp_reduce_sum(sum);
+
+    const float inv_sum = 1.0f / sum;
+
+#pragma unroll
+    for (int i = 0; i < experts_per_thread; i++) {
+        const int  idx    = lane + i * WARP_SIZE;
+        const bool active = !use_limit || (idx < limit);
+        if (active) {
+            vals[i] *= inv_sum;
+        }
+    }
+}
+
 /*
     This kernel does the following:
-    1. softmax over the logits per token [n_experts, n_tokens]
+    1. optionally softmax over the logits per token [n_experts, n_tokens]
     2. argmax reduce over the top-k (n_experts_used) logits
     3. write weights + ids to global memory
-    4. optionally normalize the weights
+    4. optionally normalize the weights or apply softmax over the selected logits
 
     It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models
 */
-template <size_t n_experts, bool with_norm>
+template <size_t n_experts, bool with_norm, bool delayed_softmax = false>
 __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
                                                                   float *       weights,
                                                                   int32_t *     ids,
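`WARP_SIZE`, `warp_reduce_max`, and `warp_reduce_sum` come from ggml's `common.cuh`. A minimal shuffle-based sketch of what such reductions typically look like (an assumption about that header, shown only to make the new softmax self-explanatory; `*_sketch` names are hypothetical):

```cuda
// Minimal sketch of butterfly warp reductions, assuming WARP_SIZE == 32 and a
// fully active warp; ggml's own helpers in common.cuh are the real implementation.
__device__ float warp_reduce_max_sketch(float x) {
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset, 32));
    }
    return x; // every lane ends up with the warp-wide max
}

__device__ float warp_reduce_sum_sketch(float x) {
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, offset, 32);
    }
    return x; // every lane ends up with the warp-wide sum
}
```

The `use_limit` flag exists because the two call sites pad differently: the pre-top-k path loads `-INFINITY` into out-of-range lanes, which contributes `expf(-inf) == 0` to the sum on its own, while the delayed path runs over `output_weights`, whose slots beyond `n_expert_used` hold real zeros and must be excluded explicitly.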
@@ -30,51 +75,31 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
 
     constexpr int experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;
 
-    float logits_r[experts_per_thread];
+    float wt[experts_per_thread];
 
 #pragma unroll
     for (int i = 0; i < n_experts; i += WARP_SIZE) {
-        const int expert        = i + threadIdx.x;
-        logits_r[i / WARP_SIZE] = n_experts % WARP_SIZE == 0 || expert < n_experts ? logits[expert] : -INFINITY;
+        const int expert  = i + threadIdx.x;
+        wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[expert] : -INFINITY;
     }
 
-    float max_val = logits_r[0];
-
-#pragma unroll
-    for (int i = 1; i < experts_per_thread; i++) {
-        const float val = logits_r[i];
-        max_val         = max(val, max_val);
+    if constexpr (!delayed_softmax) {
+        softmax_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x);
     }
 
-    max_val = warp_reduce_max(max_val);
-
-    float wt[experts_per_thread];
-    float tmp = 0.f;
-
-#pragma unroll
-    for (int i = 0; i < experts_per_thread; i++) {
-        const float val = logits_r[i];
-        wt[i]           = expf(val - max_val);
-        tmp += wt[i];
-    }
+    //at this point, each thread holds either a portion of the softmax distribution
+    //or the raw logits. We do the argmax reduce over n_expert_used, each time marking
+    //the expert weight as -inf to exclude it from the next iteration
 
-    tmp = warp_reduce_sum(tmp);
+    float wt_sum = 0.f;
 
-    const float inv_sum = 1.0f / tmp;
+    float output_weights[experts_per_thread];
 
 #pragma unroll
     for (int i = 0; i < experts_per_thread; i++) {
-        wt[i] = wt[i] * inv_sum;
+        output_weights[i] = 0.f;
     }
 
-    //at this point, each thread holds a portion of softmax,
-    //we do the argmax reduce over n_expert_used, each time marking
-    //the expert weight as -inf to exclude from the next iteration
-
-    float wt_sum = 0.f;
-
-    float output_weights[experts_per_thread];
-
     for (int k = 0; k < n_expert_used; k++) {
         float max_val    = wt[0];
         int   max_expert = threadIdx.x;
@@ -121,6 +146,10 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
         }
     }
 
+    if constexpr (delayed_softmax) {
+        softmax_warp_inplace<experts_per_thread, true>(output_weights, n_expert_used, threadIdx.x);
+    }
+
 #pragma unroll
     for (int i = 0; i < experts_per_thread; i++) {
         const int idx = i * WARP_SIZE + threadIdx.x;
@@ -130,7 +159,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
     }
 }
 
-template <bool with_norm>
+template <bool with_norm, bool delayed_softmax = false>
 static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
                                  const float *               logits,
                                  float *                     weights,
@@ -138,6 +167,8 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
                                  const int                   n_rows,
                                  const int                   n_expert,
                                  const int                   n_expert_used) {
+    static_assert(!(with_norm && delayed_softmax), "delayed softmax is not supported with weight normalization");
+
     const int rows_per_block = 4;
     dim3      grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
     dim3      block_dims(WARP_SIZE, rows_per_block, 1);
@@ -145,43 +176,43 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
 
     switch (n_expert) {
         case 1:
-            topk_moe_cuda<1, with_norm>
+            topk_moe_cuda<1, with_norm, delayed_softmax>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
        case 2:
-            topk_moe_cuda<2, with_norm>
+            topk_moe_cuda<2, with_norm, delayed_softmax>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 4:
-            topk_moe_cuda<4, with_norm>
+            topk_moe_cuda<4, with_norm, delayed_softmax>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 8:
-            topk_moe_cuda<8, with_norm>
+            topk_moe_cuda<8, with_norm, delayed_softmax>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 16:
-            topk_moe_cuda<16, with_norm>
+            topk_moe_cuda<16, with_norm, delayed_softmax>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 32:
-            topk_moe_cuda<32, with_norm>
+            topk_moe_cuda<32, with_norm, delayed_softmax>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 64:
-            topk_moe_cuda<64, with_norm>
+            topk_moe_cuda<64, with_norm, delayed_softmax>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 128:
-            topk_moe_cuda<128, with_norm>
+            topk_moe_cuda<128, with_norm, delayed_softmax>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 256:
-            topk_moe_cuda<256, with_norm>
+            topk_moe_cuda<256, with_norm, delayed_softmax>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 512:
-            topk_moe_cuda<512, with_norm>
+            topk_moe_cuda<512, with_norm, delayed_softmax>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        default:
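Semantically, both template paths reduce to the same scalar procedure: optionally softmax, then repeat `n_expert_used` times (argmax, record, mask with `-INFINITY`), then optionally softmax or normalize the selected weights. A plain-C++ reference of one row, a sketch for intuition and host-side checking rather than the warp-parallel code above (`with_norm` and `delayed_softmax` remain mutually exclusive, matching the `static_assert`):

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Host-side reference of the fused op on one row (illustrative sketch).
static void topk_moe_ref(std::vector<float> vals, // logits, by value: masked in place
                         int                n_expert_used,
                         bool               with_norm,
                         bool               delayed_softmax,
                         std::vector<float> & weights,
                         std::vector<int> &   ids) {
    auto softmax_inplace = [](std::vector<float> & v) {
        float m = -INFINITY;
        for (float x : v) m = std::max(m, x);
        float sum = 0.f;
        for (float & x : v) { x = std::exp(x - m); sum += x; }
        for (float & x : v) x /= sum;
    };

    if (!delayed_softmax) {
        softmax_inplace(vals); // softmax first, then select
    }

    weights.clear();
    ids.clear();
    for (int k = 0; k < n_expert_used; k++) {
        int best = 0;
        for (int i = 1; i < (int) vals.size(); i++) {
            if (vals[i] > vals[best]) best = i;
        }
        ids.push_back(best);
        weights.push_back(vals[best]);
        vals[best] = -INFINITY; // exclude from the next round
    }

    if (delayed_softmax) {
        softmax_inplace(weights); // softmax over the selected logits only
    }
    if (with_norm) {
        float sum = 0.f;
        for (float w : weights) sum += w;
        for (float & w : weights) w /= sum;
    }
}
```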
@@ -194,7 +225,8 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
                            const ggml_tensor *         logits,
                            ggml_tensor *               weights,
                            ggml_tensor *               ids,
-                           const bool                  with_norm) {
+                           const bool                  with_norm,
+                           const bool                  delayed_softmax) {
     GGML_ASSERT(logits->type == GGML_TYPE_F32);
     GGML_ASSERT(weights->type == GGML_TYPE_F32);
     GGML_ASSERT(ids->type == GGML_TYPE_I32);
@@ -202,7 +234,7 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
     const int n_experts = logits->ne[0];
     const int n_rows    = logits->ne[1];
 
-    const float * logits_d  = (const float *) logits->src[0]->data;
+    const float * logits_d  = (const float *) logits->data;
     float *       weights_d = (float *) weights->data;
     int32_t *     ids_d     = (int32_t *) ids->data;
@@ -213,7 +245,11 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
     if (with_norm) {
         launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
     } else {
-        launch_topk_moe_cuda<false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+        if (delayed_softmax) {
+            launch_topk_moe_cuda<false, true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+        } else {
+            launch_topk_moe_cuda<false, false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+        }
     }
 }
@@ -246,7 +282,7 @@ bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights) {
     return true;
 }
 
-std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm) {
+std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm, bool delayed_softmax) {
     static std::initializer_list<enum ggml_op> norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE,  GGML_OP_ARGSORT,
                                                             GGML_OP_VIEW,     GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
                                                             GGML_OP_SUM_ROWS, GGML_OP_DIV,      GGML_OP_RESHAPE };
@@ -254,8 +290,19 @@ std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm) {
     static std::initializer_list<enum ggml_op> no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
                                                                GGML_OP_VIEW, GGML_OP_GET_ROWS };
 
+    static std::initializer_list<enum ggml_op> delayed_softmax_ops = { GGML_OP_ARGSORT,  GGML_OP_VIEW,
+                                                                       GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
+                                                                       GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
+
+    GGML_ASSERT(!norm || !delayed_softmax);
+
+    if (delayed_softmax) {
+        return delayed_softmax_ops;
+    }
+
     if (norm) {
         return norm_ops;
     }
+
     return no_norm_ops;
 }
diff --git a/ggml/src/ggml-cuda/topk-moe.cuh b/ggml/src/ggml-cuda/topk-moe.cuh
index 6613fb56507ea..cc2fbfe9e6649 100644
--- a/ggml/src/ggml-cuda/topk-moe.cuh
+++ b/ggml/src/ggml-cuda/topk-moe.cuh
@@ -6,9 +6,10 @@
 void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
                            const ggml_tensor *         logits,
                            ggml_tensor *               weights,
-                           ggml_tensor *               top_k,
-                           const bool                  with_norm);
+                           ggml_tensor *               ids,
+                           const bool                  with_norm,
+                           const bool                  delayed_softmax = false);
 
 bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights);
 
-std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm);
+std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm, bool delayed_softmax = false);
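One bookkeeping invariant worth calling out: in `evaluate_and_capture_cuda_graph` the fused paths skip `pattern length - 1` nodes (`i += 8`, `i += 4`, `i += 5`), the remaining node presumably being consumed by the enclosing loop's own increment. A trivial consistency check against the three op lists returned by `ggml_cuda_topk_moe_ops`:

```cpp
#include <cassert>
#include <cstddef>

int main() {
    // lengths of the op chains defined in ggml_cuda_topk_moe_ops
    const size_t norm_len    = 9; // SOFT_MAX, RESHAPE, ARGSORT, VIEW, GET_ROWS, RESHAPE, SUM_ROWS, DIV, RESHAPE
    const size_t no_norm_len = 5; // SOFT_MAX, RESHAPE, ARGSORT, VIEW, GET_ROWS
    const size_t delayed_len = 6; // ARGSORT, VIEW, GET_ROWS, RESHAPE, SOFT_MAX, RESHAPE

    // each fused path skips one fewer node than the pattern contains
    assert(norm_len    - 1 == 8); // matches i += 8
    assert(no_norm_len - 1 == 4); // matches i += 4
    assert(delayed_len - 1 == 5); // matches i += 5
    return 0;
}
```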
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 82bb55ea0e184..6298244c7604f 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -4545,14 +4545,21 @@ struct test_topk_moe: public test_case {
     const std::array<int64_t, 4> ne;
     const int  n_expert_used;
     const bool with_norm;
-    test_topk_moe(std::array<int64_t, 4> ne = {10, 5, 1, 1}, int n_expert_used = 1, bool with_norm = false)
-        : ne(ne), n_expert_used(n_expert_used), with_norm(with_norm) {
+    const bool delayed_softmax;
+
+    test_topk_moe(std::array<int64_t, 4> ne = { 10, 5, 1, 1 },
+                  int                    n_expert_used   = 1,
+                  bool                   with_norm       = false,
+                  bool                   delayed_softmax = false) :
+        ne(ne),
+        n_expert_used(n_expert_used),
+        with_norm(with_norm),
+        delayed_softmax(delayed_softmax) {
         GGML_ASSERT(n_expert_used <= ne[0]);
+        GGML_ASSERT(!(with_norm && delayed_softmax));
     }
 
-    std::string vars() override {
-        return VARS_TO_STR3(ne, n_expert_used, with_norm);
-    }
+    std::string vars() override { return VARS_TO_STR4(ne, n_expert_used, with_norm, delayed_softmax); }
 
     std::string op_desc(ggml_tensor * t) override {
         GGML_UNUSED(t);
@@ -4566,11 +4573,17 @@ struct test_topk_moe: public test_case {
         const int n_tokens = ne[1];
 
         ggml_tensor * logits = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data());
-        ggml_tensor * probs  = ggml_soft_max(ctx, logits);
+        ggml_tensor * probs  = delayed_softmax ? logits : ggml_soft_max(ctx, logits);
 
         ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_expert_used); // [n_expert_used, n_tokens]
         ggml_tensor * out = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
 
+        if (delayed_softmax) {
+            out = ggml_reshape_2d(ctx, out, n_expert_used, n_tokens);
+            out = ggml_soft_max(ctx, out); // [n_expert_used, n_tokens]
+            out = ggml_reshape_3d(ctx, out, 1, n_expert_used, n_tokens);
+        }
+
         if (with_norm) {
             out = ggml_reshape_2d(ctx, out, n_expert_used, n_tokens);
             ggml_tensor * weights_sum = ggml_sum_rows(ctx, out); // [1, n_tokens]
@@ -6843,6 +6856,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_topk_moe({128, 1, 1, 1}, 128, with_norm));
     }
 
+    test_cases.emplace_back(new test_topk_moe({ 8, 22, 1, 1 }, 4, /*with_norm*/ false, /*delayed_softmax*/ true));
+    test_cases.emplace_back(new test_topk_moe({ 32, 22, 1, 1 }, 8, /*with_norm*/ false, /*delayed_softmax*/ true));
+
 #if 0
     // these tests are disabled to save execution time, but they can be handy for debugging
     test_cases.emplace_back(new test_llama(2, true));
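A closing observation on the test reference: for a fixed selected set, softmax over just the selected logits equals selecting from the full softmax and renormalizing (and since softmax is monotone per row, both orderings pick the same top-k up to ties), so the delayed path changes graph shape and numerics rather than the mathematical result. A small self-contained check of that identity, with values chosen arbitrarily:

```cpp
#include <cassert>
#include <cmath>
#include <cstdio>

// 4-expert row, top-2 selection: delayed softmax over the selected logits
// vs. renormalized full-softmax probabilities (the with_norm pipeline).
int main() {
    const float logits[4] = { 1.0f, 3.0f, 2.0f, 0.5f };
    const int   sel[2]    = { 1, 2 }; // top-2 experts by logit

    // Path A: softmax over the selected logits only (delayed softmax).
    const float a0 = std::exp(logits[sel[0]]), a1 = std::exp(logits[sel[1]]);
    const float wa0 = a0 / (a0 + a1), wa1 = a1 / (a0 + a1);

    // Path B: full softmax, select, renormalize.
    float denom = 0.f;
    for (float x : logits) denom += std::exp(x);
    const float b0 = std::exp(logits[sel[0]]) / denom, b1 = std::exp(logits[sel[1]]) / denom;
    const float wb0 = b0 / (b0 + b1), wb1 = b1 / (b0 + b1);

    assert(std::fabs(wa0 - wb0) < 1e-6f && std::fabs(wa1 - wb1) < 1e-6f);
    std::printf("delayed: %f %f  renormalized: %f %f\n", wa0, wa1, wb0, wb1);
    return 0;
}
```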