diff --git a/ggml/src/ggml-cuda/topk-moe.cu b/ggml/src/ggml-cuda/topk-moe.cu
index afe4aee2403b2..c588da2bb9e93 100644
--- a/ggml/src/ggml-cuda/topk-moe.cu
+++ b/ggml/src/ggml-cuda/topk-moe.cu
@@ -73,8 +73,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
 
     float wt_sum = 0.f;
 
-    extern __shared__ float data_topk_shared[];
-    float *                 wt_shared_ptr = data_topk_shared + threadIdx.y * n_expert_used;
+    float output_weights[experts_per_thread];
 
     for (int k = 0; k < n_expert_used; k++) {
         float max_val = wt[0];
@@ -99,11 +98,14 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
             }
         }
 
+        if ((k & (WARP_SIZE - 1)) == threadIdx.x) {
+            output_weights[k / WARP_SIZE] = max_val;
+        }
+
         if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
             wt[max_expert / WARP_SIZE] = -INFINITY;
 
-            wt_shared_ptr[k] = max_val;
-            ids[k]           = max_expert;
+            ids[k] = max_expert;
             if constexpr (with_norm) {
                 wt_sum += max_val;
             }
@@ -115,12 +117,16 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
         const float inv_sum = 1.0f / wt_sum;
 
         for (int i = threadIdx.x; i < n_expert_used; i += WARP_SIZE) {
-            wt_shared_ptr[i] = wt_shared_ptr[i] * inv_sum;
+            output_weights[i / WARP_SIZE] *= inv_sum;
         }
     }
 
-    for (int i = threadIdx.x; i < n_expert_used; i += WARP_SIZE) {
-        weights[i] = wt_shared_ptr[i];
+#pragma unroll
+    for (int i = 0; i < experts_per_thread; i++) {
+        const int idx = i * WARP_SIZE + threadIdx.x;
+        if (idx < n_expert_used) {
+            weights[idx] = output_weights[i];
+        }
     }
 }
 
@@ -137,48 +143,46 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
     dim3         block_dims(WARP_SIZE, rows_per_block, 1);
     cudaStream_t stream = ctx.stream();
 
-    const int nbytes_shared = n_expert_used * rows_per_block * sizeof(float);
-
     switch (n_expert) {
         case 1:
             topk_moe_cuda<1, with_norm>
-                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 2:
            topk_moe_cuda<2, with_norm>
-                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 4:
            topk_moe_cuda<4, with_norm>
-                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 8:
            topk_moe_cuda<8, with_norm>
-                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 16:
            topk_moe_cuda<16, with_norm>
-                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 32:
            topk_moe_cuda<32, with_norm>
-                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 64:
            topk_moe_cuda<64, with_norm>
-                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 128:
            topk_moe_cuda<128, with_norm>
-                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 256:
            topk_moe_cuda<256, with_norm>
-                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 512:
            topk_moe_cuda<512, with_norm>
-                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        default:
            GGML_ASSERT(false && "fatal error");
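
For context, the patch replaces the per-row shared-memory buffer with a per-thread register array: expert `k` is owned by lane `k % WARP_SIZE` (written as `k & (WARP_SIZE - 1)`, valid because `WARP_SIZE` is a power of two) and kept in register slot `k / WARP_SIZE`. The write-out loop reconstructs the expert index as `idx = i * WARP_SIZE + threadIdx.x`, and the normalization loop recovers the slot from a warp-strided index as `i / WARP_SIZE`. A minimal host-side sketch that checks this mapping (not part of the patch; the constants and loop bounds are illustrative):

```cpp
#include <cassert>

// Verifies the register-ownership mapping the patch relies on: expert k is held
// by lane (k % WARP_SIZE) in register slot (k / WARP_SIZE), the slot always fits
// in the per-thread array, and the write-out loop recovers k exactly.
int main() {
    const int WARP_SIZE = 32;
    for (int n_expert = 1; n_expert <= 512; n_expert *= 2) {  // matches the switch cases
        const int experts_per_thread = n_expert > WARP_SIZE ? n_expert / WARP_SIZE : 1;
        for (int k = 0; k < n_expert; k++) {
            const int lane = k % WARP_SIZE;        // (k & (WARP_SIZE - 1)) in the kernel
            const int slot = k / WARP_SIZE;
            assert(slot < experts_per_thread);     // fits in output_weights[experts_per_thread]
            assert(slot * WARP_SIZE + lane == k);  // write-out recovers the expert index
        }
    }
    return 0;
}
```

Because each lane only ever touches its own slots, no cross-thread exchange is needed, which is why the dynamic shared-memory allocation (`nbytes_shared`) can be dropped from every kernel launch.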