44 changes: 40 additions & 4 deletions ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2818,8 +2818,12 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
#endif

//TODO: remove special case once ggml_can_fuse can handle empty nodes
-std::initializer_list<enum ggml_op> topk_moe_ops = ggml_cuda_topk_moe_ops(false);
-std::initializer_list<enum ggml_op> topk_moe_ops_with_norm = ggml_cuda_topk_moe_ops(true);
+std::initializer_list<enum ggml_op> topk_moe_ops =
+    ggml_cuda_topk_moe_ops(/*with_norm=*/false, /*delayed_softmax=*/false);
+std::initializer_list<enum ggml_op> topk_moe_ops_with_norm =
+    ggml_cuda_topk_moe_ops(/*with_norm=*/true, /*delayed_softmax=*/false);
+std::initializer_list<enum ggml_op> topk_moe_ops_delayed_softmax =
+    ggml_cuda_topk_moe_ops(/*with_norm=*/false, /*delayed_softmax=*/true);

if (ops.size() == topk_moe_ops_with_norm.size() && std::equal(ops.begin(), ops.end(), topk_moe_ops_with_norm.begin())) {

@@ -2855,6 +2859,25 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
}
}

if (ops.size() == topk_moe_ops_delayed_softmax.size() &&
std::equal(ops.begin(), ops.end(), topk_moe_ops_delayed_softmax.begin())) {
if (node_idx + topk_moe_ops_delayed_softmax.size() > (size_t) cgraph->n_nodes) {
return false;
}
for (size_t i = 0; i < topk_moe_ops_delayed_softmax.size(); i++) {
if (cgraph->nodes[node_idx + i]->op != topk_moe_ops_delayed_softmax.begin()[i]) {
return false;
}
}

ggml_tensor * softmax = cgraph->nodes[node_idx + 4];
ggml_tensor * weights = cgraph->nodes[node_idx + 5];

if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
return true;
}
}

if (!ggml_can_fuse(cgraph, node_idx, ops)) {
return false;
}
@@ -2948,19 +2971,32 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ true), {})) {
ggml_tensor * weights = cgraph->nodes[i+8];
ggml_tensor * selected_experts = cgraph->nodes[i+3];
-ggml_cuda_op_topk_moe(*cuda_ctx, node, weights, selected_experts, /*with norm*/ true);
+ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ true,
+                      /*delayed softmax*/ false);
i += 8;
continue;
}

if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ false), {})) {
ggml_tensor * weights = cgraph->nodes[i+4];
ggml_tensor * selected_experts = cgraph->nodes[i+3];
-ggml_cuda_op_topk_moe(*cuda_ctx, node, weights, selected_experts, /*with norm*/ false);
+ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ false,
+                      /*delayed softmax*/ false);
i += 4;
continue;
}

if (ggml_cuda_can_fuse(cgraph, i,
ggml_cuda_topk_moe_ops(/*with norm*/ false, /*delayed softmax*/ true), {})) {
ggml_tensor * weights = cgraph->nodes[i + 5];
ggml_tensor * ids = cgraph->nodes[i + 1];

ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, ids, /*with norm*/ false,
/*delayed_softmax*/ true);
i += 5;
continue;
}

if (node->op == GGML_OP_ADD) {
int n_fuse = 0;
ggml_op ops[8];
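
Editor's note: the fusion check above is purely positional before any tensor-level checks run: the candidate op list must fit inside the graph and match node-by-node. A minimal host-side sketch of that matching logic, using a plain enum array in place of real ggml_cgraph nodes (the toy_op values and matches_at helper are illustrative, not ggml API):

#include <cstddef>
#include <initializer_list>

enum toy_op { OP_ARGSORT, OP_VIEW, OP_GET_ROWS, OP_RESHAPE, OP_SOFT_MAX };

// Sketch: does the op pattern match the graph starting at node_idx?
static bool matches_at(const toy_op * nodes, size_t n_nodes, size_t node_idx,
                       std::initializer_list<toy_op> ops) {
    if (node_idx + ops.size() > n_nodes) {
        return false;  // fused run would read past the end of the graph
    }
    for (size_t i = 0; i < ops.size(); i++) {
        if (nodes[node_idx + i] != ops.begin()[i]) {
            return false;  // any positional mismatch rejects the fusion
        }
    }
    return true;  // the real code then also calls ggml_cuda_should_use_topk_moe
}

int main() {
    // delayed-softmax pattern: argsort -> view -> get_rows -> reshape -> soft_max -> reshape
    const toy_op graph[] = { OP_ARGSORT, OP_VIEW, OP_GET_ROWS, OP_RESHAPE, OP_SOFT_MAX, OP_RESHAPE };
    return matches_at(graph, 6, 0,
                      { OP_ARGSORT, OP_VIEW, OP_GET_ROWS, OP_RESHAPE, OP_SOFT_MAX, OP_RESHAPE }) ? 0 : 1;
}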
145 changes: 96 additions & 49 deletions ggml/src/ggml-cuda/topk-moe.cu
@@ -4,16 +4,61 @@

#include <initializer_list>

// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
template <int experts_per_thread, bool use_limit>
__device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) {
float max_val = -INFINITY;

#pragma unroll
for (int i = 0; i < experts_per_thread; i++) {
const int idx = lane + i * WARP_SIZE;
const bool active = !use_limit || (idx < limit);
if (active) {
max_val = max(max_val, vals[i]);
}
}

max_val = warp_reduce_max(max_val);

float sum = 0.f;

#pragma unroll
for (int i = 0; i < experts_per_thread; i++) {
const int idx = lane + i * WARP_SIZE;
const bool active = !use_limit || (idx < limit);
if (active) {
const float val = expf(vals[i] - max_val);
vals[i] = val;
sum += val;
} else {
vals[i] = 0.f;
}
}

sum = warp_reduce_sum(sum);

const float inv_sum = 1.0f / sum;

#pragma unroll
for (int i = 0; i < experts_per_thread; i++) {
const int idx = lane + i * WARP_SIZE;
const bool active = !use_limit || (idx < limit);
if (active) {
vals[i] *= inv_sum;
}
}
}
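
Editor's note: for reference, a sequential host-side equivalent of softmax_warp_inplace with use_limit enabled (an illustrative sketch, not part of the patch): softmax over the first limit entries with max-subtraction for numerical stability, everything past limit forced to zero, mirroring what the masked lanes do in the kernel.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Sequential reference for softmax_warp_inplace<..., /*use_limit=*/true>.
static void softmax_limit_ref(std::vector<float> & vals, int limit) {
    float max_val = -INFINITY;
    for (int i = 0; i < limit; i++) {
        max_val = std::max(max_val, vals[i]);   // warp_reduce_max equivalent
    }
    float sum = 0.f;
    for (int i = 0; i < limit; i++) {
        vals[i] = std::exp(vals[i] - max_val);  // numerically stable exp
        sum += vals[i];                         // warp_reduce_sum equivalent
    }
    for (int i = 0; i < (int) vals.size(); i++) {
        vals[i] = i < limit ? vals[i] / sum : 0.f;  // inactive slots -> 0
    }
}

int main() {
    std::vector<float> v = { 0.5f, 2.0f, 1.0f, -1.0f };
    softmax_limit_ref(v, /*limit=*/3);  // softmax over first 3, last forced to 0
    for (float x : v) { std::printf("%.3f ", x); }
    std::printf("\n");
    return 0;
}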

/*
This kernel does the following:
-1. softmax over the logits per token [n_experts, n_tokens]
+1. optionally softmax over the logits per token [n_experts, n_tokens]
2. argmax reduce over the top-k (n_experts_used) logits
3. write weights + ids to global memory
-4. optionally normalize the weights
+4. optionally normalize the weights or apply softmax over the selected logits

It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models
*/
-template <int n_experts, bool with_norm>
+template <int n_experts, bool with_norm, bool delayed_softmax = false>
__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
float * weights,
int32_t * ids,
@@ -30,51 +75,31 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *

constexpr int experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;

-float logits_r[experts_per_thread];
+float wt[experts_per_thread];

#pragma unroll
for (int i = 0; i < n_experts; i += WARP_SIZE) {
-const int expert = i + threadIdx.x;
-logits_r[i / WARP_SIZE] = n_experts % WARP_SIZE == 0 || expert < n_experts ? logits[expert] : -INFINITY;
+const int expert = i + threadIdx.x;
+wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[expert] : -INFINITY;
}

-float max_val = logits_r[0];
-
-#pragma unroll
-for (int i = 1; i < experts_per_thread; i++) {
-const float val = logits_r[i];
-max_val = max(val, max_val);
-}
-
-max_val = warp_reduce_max(max_val);
-
-float wt[experts_per_thread];
-float tmp = 0.f;
-
-#pragma unroll
-for (int i = 0; i < experts_per_thread; i++) {
-const float val = logits_r[i];
-wt[i] = expf(val - max_val);
-tmp += wt[i];
-}
-
-tmp = warp_reduce_sum(tmp);
-
-const float inv_sum = 1.0f / tmp;
-
-#pragma unroll
-for (int i = 0; i < experts_per_thread; i++) {
-wt[i] = wt[i] * inv_sum;
-}
-
-//at this point, each thread holds a portion of softmax,
-//we do the argmax reduce over n_expert_used, each time marking
-//the expert weight as -inf to exclude from the next iteration
+if constexpr (!delayed_softmax) {
+softmax_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x);
+}
+
+//at this point, each thread holds either a portion of the softmax distribution
+//or the raw logits. We do the argmax reduce over n_expert_used, each time marking
+//the expert weight as -inf to exclude from the next iteration

float wt_sum = 0.f;

float output_weights[experts_per_thread];

+#pragma unroll
+for (int i = 0; i < experts_per_thread; i++) {
+output_weights[i] = 0.f;
+}

for (int k = 0; k < n_expert_used; k++) {
float max_val = wt[0];
int max_expert = threadIdx.x;
@@ -121,6 +146,10 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
}
}

if constexpr (delayed_softmax) {
softmax_warp_inplace<experts_per_thread, true>(output_weights, n_expert_used, threadIdx.x);
}

#pragma unroll
for (int i = 0; i < experts_per_thread; i++) {
const int idx = i * WARP_SIZE + threadIdx.x;
Expand All @@ -130,58 +159,60 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
}
}

-template <bool with_norm>
+template <bool with_norm, bool delayed_softmax = false>
static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
const float * logits,
float * weights,
int32_t * ids,
const int n_rows,
const int n_expert,
const int n_expert_used) {
static_assert(!(with_norm && delayed_softmax), "delayed softmax is not supported with weight normalization");

const int rows_per_block = 4;
dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
dim3 block_dims(WARP_SIZE, rows_per_block, 1);
cudaStream_t stream = ctx.stream();

switch (n_expert) {
case 1:
-topk_moe_cuda<1, with_norm>
+topk_moe_cuda<1, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
case 2:
-topk_moe_cuda<2, with_norm>
+topk_moe_cuda<2, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
case 4:
-topk_moe_cuda<4, with_norm>
+topk_moe_cuda<4, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
case 8:
-topk_moe_cuda<8, with_norm>
+topk_moe_cuda<8, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
case 16:
-topk_moe_cuda<16, with_norm>
+topk_moe_cuda<16, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
case 32:
-topk_moe_cuda<32, with_norm>
+topk_moe_cuda<32, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
case 64:
-topk_moe_cuda<64, with_norm>
+topk_moe_cuda<64, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
case 128:
-topk_moe_cuda<128, with_norm>
+topk_moe_cuda<128, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
case 256:
-topk_moe_cuda<256, with_norm>
+topk_moe_cuda<256, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
case 512:
-topk_moe_cuda<512, with_norm>
+topk_moe_cuda<512, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
default:
@@ -194,15 +225,16 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
const ggml_tensor * logits,
ggml_tensor * weights,
ggml_tensor * ids,
-const bool with_norm) {
+const bool with_norm,
+const bool delayed_softmax) {
GGML_ASSERT(logits->type == GGML_TYPE_F32);
GGML_ASSERT(weights->type == GGML_TYPE_F32);
GGML_ASSERT(ids->type == GGML_TYPE_I32);

const int n_experts = logits->ne[0];
const int n_rows = logits->ne[1];

-const float * logits_d = (const float *) logits->src[0]->data;
+const float * logits_d = (const float *) logits->data;
float * weights_d = (float *) weights->data;
int32_t * ids_d = (int32_t *) ids->data;

@@ -213,7 +245,11 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
if (with_norm) {
launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
} else {
-launch_topk_moe_cuda<false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+if (delayed_softmax) {
+launch_topk_moe_cuda<false, true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+} else {
+launch_topk_moe_cuda<false, false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+}
}
}

@@ -246,16 +282,27 @@ bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tenso
return true;
}

-std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm) {
+std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm, bool delayed_softmax) {
static std::initializer_list<enum ggml_op> norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
GGML_OP_SUM_ROWS, GGML_OP_DIV, GGML_OP_RESHAPE };

static std::initializer_list<enum ggml_op> no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
GGML_OP_VIEW, GGML_OP_GET_ROWS };

static std::initializer_list<enum ggml_op> delayed_softmax_ops = { GGML_OP_ARGSORT, GGML_OP_VIEW,
GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };

GGML_ASSERT(!norm || !delayed_softmax);

if (delayed_softmax) {
return delayed_softmax_ops;
}

if (norm) {
return norm_ops;
}

return no_norm_ops;
}
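
Editor's note: the selection loop in topk_moe_cuda (the "marking the expert weight as -inf" comment above) is an iterative argmax: each of the n_expert_used rounds finds the current maximum, records its weight and index, and masks it out so the next round finds the runner-up. A sequential host-side sketch of the same idea (illustrative only; the kernel does this cooperatively with warp shuffles):

#include <cmath>
#include <cstdio>

// Sequential sketch of k rounds of argmax with -inf masking.
static void topk_by_masking(float * wt, int n_experts, int k,
                            float * out_w, int * out_id) {
    for (int round = 0; round < k; round++) {
        int best = 0;
        for (int e = 1; e < n_experts; e++) {
            if (wt[e] > wt[best]) {
                best = e;
            }
        }
        out_w[round]  = wt[best];
        out_id[round] = best;
        wt[best] = -INFINITY;  // exclude this expert from the next round
    }
}

int main() {
    float wt[4] = { 0.1f, 0.6f, 0.2f, 0.1f };  // e.g. one softmax row
    float w[2];
    int   id[2];
    topk_by_masking(wt, 4, 2, w, id);
    std::printf("top-2: expert %d (%.2f), expert %d (%.2f)\n", id[0], w[0], id[1], w[1]);
    return 0;
}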
7 changes: 4 additions & 3 deletions ggml/src/ggml-cuda/topk-moe.cuh
@@ -6,9 +6,10 @@
void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
const ggml_tensor * logits,
ggml_tensor * weights,
-ggml_tensor * top_k,
-const bool with_norm);
+ggml_tensor * ids,
+const bool with_norm,
+const bool delayed_softmax = false);

bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights);

-std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm);
+std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm, bool delayed_softmax = false);
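
Editor's note: with the new default parameter, callers have three valid flag combinations; (with_norm=true, delayed_softmax=true) is rejected at compile time by the static_assert in launch_topk_moe_cuda. A compile-time sketch of that dispatch (hypothetical launch_variant wrapper, not ggml API):

#include <cstdio>

// Illustration of the (with_norm, delayed_softmax) dispatch; mirrors the
// runtime branching in ggml_cuda_op_topk_moe and the kernel template.
template <bool with_norm, bool delayed_softmax = false>
static void launch_variant() {
    static_assert(!(with_norm && delayed_softmax),
                  "delayed softmax is not supported with weight normalization");
    std::printf("with_norm=%d delayed_softmax=%d\n", with_norm, delayed_softmax);
}

int main() {
    launch_variant<true>();          // softmax -> top-k -> normalize
    launch_variant<false>();         // softmax -> top-k
    launch_variant<false, true>();   // top-k -> softmax over the selected logits
    // launch_variant<true, true>(); // would fail to compile, as in the patch
    return 0;
}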