Commit 3e024de

Author: Iwan Kawrakow (ikawrakow)
Vulkan: Disable multi-add for now (#581)
Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 8a0c38f · commit 3e024de

File tree: 2 files changed, +63 −29 lines


ggml/src/ggml-vulkan.cpp

Lines changed: 40 additions & 4 deletions
@@ -5520,6 +5520,9 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
 
     const uint64_t nei0 = ids->ne[0];
     const uint64_t nei1 = ids->ne[1];
+    if (nei0*nei1 > 4096) {
+        fprintf(stderr, "%s: nei0 = %d, nei1 = %d\n", __func__, (int)nei0, (int)nei1);
+    }
     GGML_ASSERT(nei0 * nei1 <= 4096);
 
     const uint32_t nbi1 = ids->nb[1];
@@ -5915,7 +5918,30 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
     if (src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
         ggml_vk_mul_mat_vec_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun);
     } else {
-        ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun);
+        // Split based on number of ids, to fit in shared memory
+        const uint32_t nei0 = (uint32_t)src2->ne[0];
+        const uint32_t nei1 = (uint32_t)src2->ne[1];
+
+        GGML_ASSERT(nei0 <= 4096);
+        const uint32_t split_size = std::min(nei1, 4096u / nei0);
+
+        ggml_tensor src1_copy = *src1;
+        ggml_tensor src2_copy = *src2;
+        ggml_tensor dst_copy = *dst;
+
+        for (uint32_t token_start = 0; token_start < nei1; token_start += split_size) {
+            const uint32_t n_tokens = std::min(split_size, nei1 - token_start);
+
+            src1_copy.view_offs = src1->view_offs + token_start * src1_copy.nb[2];
+            src2_copy.view_offs = src2->view_offs + token_start * src2_copy.nb[1];
+            dst_copy.view_offs = dst->view_offs + token_start * dst_copy.nb[2];
+
+            src1_copy.ne[2] = n_tokens;
+            src2_copy.ne[1] = n_tokens;
+            dst_copy.ne[2] = n_tokens;
+
+            ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, &src1_copy, &src2_copy, &dst_copy, dryrun);
+        }
     }
 }
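
The hard 4096 limit comes from the fixed-size shared-memory array the Vulkan mul_mat_id shader uses for the row ids (see the row_ids comment in supports_op below): nei0 ids per token times nei1 tokens must fit into it. The new loop therefore slices the token dimension so every call stays under the limit. A minimal standalone sketch of just the chunking arithmetic; the sizes are illustrative assumptions, not values from the patch:

// Sketch only: demonstrates how the token dimension is partitioned so that
// every slice satisfies nei0 * n_tokens <= 4096 (the shader's id capacity).
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t nei0 = 8;     // ids per token -- assumed example value
    const uint32_t nei1 = 1000;  // number of tokens -- assumed example value

    const uint32_t split_size = std::min<uint32_t>(nei1, 4096u / nei0);  // 512 here

    for (uint32_t token_start = 0; token_start < nei1; token_start += split_size) {
        const uint32_t n_tokens = std::min<uint32_t>(split_size, nei1 - token_start);
        // 8 * 512 = 4096 and 8 * 488 = 3904, both within the limit
        printf("slice [%u, %u): %u ids\n", token_start, token_start + n_tokens, nei0 * n_tokens);
    }
    return 0;
}

With these numbers the op is issued twice, for 512 and then 488 tokens, instead of once with 8000 ids.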

@@ -9510,9 +9536,15 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const
     ggml_type src0_type = op->src[0]->type;
     ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
     const vk_device& device = ctx->device;
-    if (op->op == GGML_OP_MUL_MAT_ID && !device->mul_mat_id_s[src0_type] && !device->mul_mat_id_m[src0_type] && !device->mul_mat_id_l[src0_type]) {
-        // If there's not enough shared memory for row_ids and the result tile, fallback to CPU
-        return false;
+    if (op->op == GGML_OP_MUL_MAT_ID) {
+        if (!device->mul_mat_id_s[src0_type] && !device->mul_mat_id_m[src0_type] && !device->mul_mat_id_l[src0_type]) {
+            // If there's not enough shared memory for row_ids and the result tile, fallback to CPU
+            return false;
+        }
+        // Check against size of shared memory variable
+        if (op->src[2]->ne[0] > 4096) {
+            return false;
+        }
     }
     switch (src0_type) {
         case GGML_TYPE_F32:
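
This guard mirrors the runtime path: the split loop in ggml_vk_mul_mat_id can shrink nei1 (the token count) but not nei0 (ids per token, op->src[2]->ne[0]), so an op whose per-token id count alone exceeds the 4096-entry shared-memory array is rejected up front and left to the CPU backend.
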
@@ -9580,6 +9612,10 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const
         default:
             return false;
     }
+    if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
+        // different head sizes of K and V are not supported yet
+        return false;
+    }
     if (op->src[0]->type != GGML_TYPE_F32) {
         return false;
     }
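
This second addition sits in a different branch of supports_op; judging by the "head sizes of K and V" comment, op->src[1] and op->src[2] here are presumably the K and V tensors of a flash-attention op, with ne[0] the head size, so models where K and V heads differ in size also fall back to CPU.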

src/llama.cpp

Lines changed: 23 additions & 25 deletions
@@ -9870,6 +9870,28 @@ llm_expert_gating_func_type gating_op,
         cb(cur, "ffn_moe_weighted", il);
     }
 
+#ifdef GGML_USE_VULKAN
+    // aggregate experts
+    ggml_tensor * moe_out = nullptr;
+    //ggml_tensor * first_expert = nullptr;
+    for (int i = 0; i < n_expert_used; ++i) {
+        ggml_tensor * cur_expert = ggml_view_2d(ctx, experts, n_embd, n_tokens,
+                experts->nb[2], i*experts->nb[1]);
+
+        if (i == 0) {
+            moe_out = cur_expert;
+        } else {
+            moe_out = ggml_add(ctx, moe_out, cur_expert);
+        }
+    }
+
+    if (n_expert_used == 1) {
+        // avoid returning a non-contiguous tensor
+        moe_out = ggml_cont(ctx, moe_out);
+    }
+
+    return moe_out;
+#else
     if (n_expert_used == 1) {
         return ggml_cont(ctx, ggml_view_2d(ctx, experts, n_embd, n_tokens, experts->nb[2], 0));
     }
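
Each ggml_view_2d above carves expert i out of the experts tensor: the byte offset i*experts->nb[1] selects the expert, and the row stride experts->nb[2] steps between tokens, so the resulting view is non-contiguous (hence the ggml_cont for the single-expert case). A plain-array sketch of the same indexing, using element strides and toy sizes that are assumptions rather than values from the patch:

// Sketch only: sum per-expert slices of a contiguous
// [n_embd, n_expert_used, n_tokens] buffer, mirroring the strided
// ggml_view_2d + chained ggml_add of the Vulkan branch.
#include <cstdio>
#include <vector>

int main() {
    const int n_embd = 4, n_expert_used = 2, n_tokens = 3;  // assumed toy sizes
    std::vector<float> experts(n_embd * n_expert_used * n_tokens, 1.0f);

    const int nb1 = n_embd;                  // stride between experts (elements)
    const int nb2 = n_embd * n_expert_used;  // stride between tokens (elements)

    // moe_out[t][e] = sum over experts i of experts[t][i][e]
    std::vector<float> moe_out(n_embd * n_tokens, 0.0f);
    for (int i = 0; i < n_expert_used; ++i) {
        for (int t = 0; t < n_tokens; ++t) {
            for (int e = 0; e < n_embd; ++e) {
                moe_out[t*n_embd + e] += experts[i*nb1 + t*nb2 + e];
            }
        }
    }
    printf("moe_out[0] = %g (expected %d)\n", moe_out[0], n_expert_used);
    return 0;
}

The chained ggml_add accumulation over these views is what stands in for the fused ggml_multi_add on Vulkan builds.
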
@@ -9878,32 +9900,8 @@ llm_expert_gating_func_type gating_op,
             ggml_view_2d(ctx, experts, n_embd, n_tokens, experts->nb[2], experts->nb[1]));
     }
     return ggml_multi_add(ctx, ggml_view_2d(ctx, experts, n_embd, n_tokens, experts->nb[2], 0), n_expert_used);
+#endif
 
-    //// aggregate experts
-    //ggml_tensor * moe_out = nullptr;
-    ////ggml_tensor * first_expert = nullptr;
-    //for (int i = 0; i < n_expert_used; ++i) {
-    //    ggml_tensor * cur_expert = ggml_view_2d(ctx, experts, n_embd, n_tokens,
-    //        experts->nb[2], i*experts->nb[1]);
-
-    //    if (i == 0) {
-    //        moe_out = cur_expert;
-    //        //first_expert = cur_expert;
-    //        //printf("%s: %d: %d x %d x %d x %d | %d x %d x %d x %d\n", __func__, ggml_is_contiguous(first_expert),
-    //        //    (int)cur_expert->ne[0], (int)cur_expert->ne[1], (int)cur_expert->ne[2], (int)cur_expert->ne[3],
-    //        //    (int)cur_expert->nb[0], (int)cur_expert->nb[1], (int)cur_expert->nb[2], (int)cur_expert->nb[3]);
-    //    } else {
-    //        moe_out = ggml_add(ctx, moe_out, cur_expert);
-    //        //printf("%s: %d %d\n", __func__, ggml_is_contiguous(cur_expert), ggml_are_same_shape(cur_expert, first_expert));
-    //    }
-    //}
-
-    //if (n_expert_used == 1) {
-    //    // avoid returning a non-contiguous tensor
-    //    moe_out = ggml_cont(ctx, moe_out);
-    //}
-
-    //return moe_out;
 }
 
 static struct ggml_tensor * llm_build_kqv(
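
This is the change the commit title refers to: on GGML_USE_VULKAN builds the fused ggml_multi_add path is bypassed in favor of the pairwise ggml_add loop (an uncommented version of the block deleted here), while other backends keep using ggml_multi_add.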
