CUDA: add a fused top-K MoE kernel #16130
Changes from 8 commits: a208f5c, 9fc0396, 8b780cc, ce867aa, 2930668, 240b2c1, e772b28, 53acfe6, 33856e1
@@ -0,0 +1,263 @@
#include "ggml-cuda/common.cuh"
#include "ggml.h"
#include "topk-moe.cuh"

#include <initializer_list>

/*
    This kernel does the following:
    1. softmax over the logits per token [n_experts, n_tokens]
    2. argmax reduce over the top-k (n_experts_used) logits
    3. write weights + ids to global memory
    4. optionally normalize the weights
    It is intended as a fusion of the softmax->top-k->get_rows pipeline for MoE models
*/
template <size_t n_experts, bool with_norm>
__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
                                                                  float *       weights,
                                                                  int32_t *     ids,
                                                                  const int     n_rows,
                                                                  const int     n_expert_used) {
    const int row = blockIdx.x * blockDim.y + threadIdx.y;
    if (row >= n_rows) {
        return;
    }

    logits += n_experts * row;
    weights += n_expert_used * row;
    ids += n_experts * row;

    constexpr int experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;

    float logits_r[experts_per_thread];

#pragma unroll
    for (int i = 0; i < n_experts; i += WARP_SIZE) {
        const int expert = i + threadIdx.x;
        logits_r[i / WARP_SIZE] = n_experts % WARP_SIZE == 0 || expert < n_experts ? logits[expert] : -INFINITY;
    }

    float max_val = logits_r[0];

#pragma unroll
    for (int i = 1; i < experts_per_thread; i++) {
        const float val = logits_r[i];
        max_val = max(val, max_val);
    }

    max_val = warp_reduce_max(max_val);

    float wt[experts_per_thread];
    float tmp = 0.f;

#pragma unroll
    for (int i = 0; i < experts_per_thread; i++) {
        const float val = logits_r[i];
        wt[i] = expf(val - max_val);
        tmp += wt[i];
    }

    tmp = warp_reduce_sum(tmp);

    const float inv_sum = 1.0f / tmp;

#pragma unroll
    for (int i = 0; i < experts_per_thread; i++) {
        wt[i] = wt[i] * inv_sum;
    }

    // at this point, each thread holds a portion of the softmax output;
    // we do the argmax reduce over n_expert_used iterations, each time marking
    // the selected expert's weight as -inf to exclude it from the next iteration

    float wt_sum = 0.f;

    extern __shared__ float data_topk_shared[];
    float * wt_shared_ptr = data_topk_shared + threadIdx.y * n_expert_used;

    for (int k = 0; k < n_expert_used; k++) {
        float max_val = wt[0];
        int max_expert = threadIdx.x;

#pragma unroll
        for (int i = 1; i < experts_per_thread; i++) {
            const int expert = threadIdx.x + i * WARP_SIZE;
            if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
                max_val = wt[i];
                max_expert = expert;
            }
        }

#pragma unroll
        for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
            const float val = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE);
            const int expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);
            if (val > max_val || (val == max_val && expert < max_expert)) {
                max_val = val;
                max_expert = expert;
            }
        }

        if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
            wt[max_expert / WARP_SIZE] = -INFINITY;

            wt_shared_ptr[k] = max_val;
            ids[k] = max_expert;
            if constexpr (with_norm) {
                wt_sum += max_val;
            }
        }
    }

    if constexpr (with_norm) {
        wt_sum = warp_reduce_sum(wt_sum);
        const float inv_sum = 1.0f / wt_sum;

        if (threadIdx.x == 0) {
            for (int i = 0; i < n_expert_used; i++) {
                wt_shared_ptr[i] = wt_shared_ptr[i] * inv_sum;
            }
        }

    }

    if (threadIdx.x == 0) {
        for (int i = 0; i < n_expert_used; i++) {
            weights[i] = wt_shared_ptr[i];
        }
    }
}

template <bool with_norm>
static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
                                 const float * logits,
                                 float * weights,
                                 int32_t * ids,
                                 const int n_rows,
                                 const int n_expert,
                                 const int n_expert_used) {
    const int rows_per_block = 4;
    dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
    dim3 block_dims(WARP_SIZE, rows_per_block, 1);
    cudaStream_t stream = ctx.stream();

    const int nbytes_shared = n_expert_used * rows_per_block * sizeof(float);

    switch (n_expert) {
        case 1:
            topk_moe_cuda<1, with_norm>
                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 2:
            topk_moe_cuda<2, with_norm>
                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 4:
            topk_moe_cuda<4, with_norm>
                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 8:
            topk_moe_cuda<8, with_norm>
                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 16:
            topk_moe_cuda<16, with_norm>
                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 32:
            topk_moe_cuda<32, with_norm>
                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 64:
            topk_moe_cuda<64, with_norm>
                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 128:
            topk_moe_cuda<128, with_norm>
                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 256:
            topk_moe_cuda<256, with_norm>
                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 512:
            topk_moe_cuda<512, with_norm>
                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        default:
            GGML_ASSERT(false && "fatal error");
            break;
    }
}

void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
                           const ggml_tensor * logits,
                           ggml_tensor * weights,
                           ggml_tensor * ids,
                           const bool with_norm) {
    GGML_ASSERT(logits->type == GGML_TYPE_F32);
    GGML_ASSERT(weights->type == GGML_TYPE_F32);
    GGML_ASSERT(ids->type == GGML_TYPE_I32);

    const int n_experts = logits->ne[0];
    const int n_rows = logits->ne[1];

    const float * logits_d = (const float *) logits->src[0]->data;
    float * weights_d = (float *) weights->data;
    int32_t * ids_d = (int32_t *) ids->data;

    GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts);

    cudaStream_t stream = ctx.stream();

    const int n_expert_used = weights->ne[1];

    if (with_norm) {
        launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
    } else {
        launch_topk_moe_cuda<false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
    }
}

bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights) {
Review thread on this function:
Reviewer: To be robust, this ought to check that there aren't any unanticipated uses of tensors whose writes are elided.
am17an: Sorry, what do you mean by "elided"?
Reviewer: Writes to intermediate tensors that are removed. This is why we have the ggml_node_has_n_uses checks in ggml_can_fuse - to make sure that nothing outside of this fusion region is using any of the intermediate tensors. But this fusion is complex enough that ggml_can_fuse can't handle it.
am17an: Yes agree, I'm working on
(A sketch of the kind of use-count check being discussed here follows this function.)
    float scale = 1.0f;
    float max_bias = 0.0f;

    memcpy(&scale, (const float *) softmax->op_params + 0, sizeof(float));
    memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float));

    if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) {
        return false;
    }

    if (scale != 1.0f || max_bias != 0.0f) {
        return false;
    }

    // don't fuse when masks or sinks are present
    if (softmax->src[1] || softmax->src[2]) {
        return false;
    }

    const int n_expert = softmax->ne[0];
    // n_expert must be a power of 2
    if ((n_expert & (n_expert - 1)) != 0 || n_expert > 512) {
        return false;
    }

    return true;
}
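As discussed in the review thread above, the fused kernel elides writes to intermediate tensors (the softmax/argsort outputs), so a more robust check would verify that nothing outside the fused region still consumes them. The helper below is only a sketch of that idea; count_uses is a hypothetical function introduced for illustration, not an existing ggml or CUDA-backend API, and the integration point is assumed.

// Hypothetical sketch (not part of this PR): count how many nodes in a graph
// read tensor `t` as a source. An intermediate tensor whose write the fused
// kernel elides is only safe if every one of its uses lies inside the fused region.
static int count_uses(const ggml_cgraph * cgraph, const ggml_tensor * t) {
    int n_uses = 0;
    for (int i = 0; i < cgraph->n_nodes; ++i) {
        for (int j = 0; j < GGML_MAX_SRC; ++j) {
            if (cgraph->nodes[i]->src[j] == t) {
                n_uses++;
            }
        }
    }
    return n_uses;
}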

std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm) {
    static std::initializer_list<enum ggml_op> norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
                                                            GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
                                                            GGML_OP_SUM_ROWS, GGML_OP_DIV, GGML_OP_RESHAPE };

    static std::initializer_list<enum ggml_op> no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
                                                               GGML_OP_VIEW, GGML_OP_GET_ROWS };

    if (norm) {
        return norm_ops;
    }
    return no_norm_ops;
}
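For context, the op sequences returned by ggml_cuda_topk_moe_ops above mirror the host-side routing graph that the fused kernel replaces. The sketch below reconstructs that pattern with the public ggml API; the function name, argument names, and exact reshapes are illustrative assumptions rather than code taken from this PR.

// Illustrative only: the unfused top-k MoE routing pattern that topk_moe_cuda fuses.
// Ops produced: SOFT_MAX, RESHAPE, ARGSORT + VIEW (via ggml_top_k), GET_ROWS,
// and for the normalized variant additionally RESHAPE, SUM_ROWS, DIV, RESHAPE.
static ggml_tensor * build_topk_moe_reference(ggml_context * ctx,
                                              ggml_tensor  * logits,          // [n_expert, n_tokens]
                                              int64_t        n_expert,
                                              int            n_expert_used,
                                              int64_t        n_tokens,
                                              bool           norm_topk_prob,
                                              ggml_tensor ** out_ids) {
    ggml_tensor * probs            = ggml_soft_max(ctx, logits);              // [n_expert, n_tokens]
    ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_expert_used);   // argsort + view
    ggml_tensor * weights          = ggml_get_rows(ctx,
            ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]

    if (norm_topk_prob) {
        weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens);
        ggml_tensor * weights_sum = ggml_sum_rows(ctx, weights);              // [1, n_tokens]
        weights = ggml_div(ctx, weights, weights_sum);
        weights = ggml_reshape_3d(ctx, weights, 1, n_expert_used, n_tokens);
    }

    *out_ids = selected_experts;
    return weights;
}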
@@ -0,0 +1,14 @@
#include "common.cuh" | ||
#include "ggml.h" | ||
|
||
#include <initializer_list> | ||
|
||
void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx, | ||
const ggml_tensor * logits, | ||
ggml_tensor * weights, | ||
ggml_tensor * top_k, | ||
const bool with_norm); | ||
|
||
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights); | ||
|
||
std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm); |
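The header above declares the fusion entry points used by the CUDA backend. A minimal, hypothetical call-site sketch follows; the node variable names and the with_norm choice are assumptions for illustration, not code from this PR.

// Hypothetical call site: `softmax_node` is the GGML_OP_SOFT_MAX node over the expert
// logits, `weights_node` and `ids_node` are the final tensors of the fused pattern.
if (ggml_cuda_should_use_topk_moe(softmax_node, weights_node)) {
    // the op reads the raw logits through softmax_node->src[0]
    ggml_cuda_op_topk_moe(ctx, softmax_node, weights_node, ids_node, /*with_norm=*/true);
}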