@@ -5520,6 +5520,9 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
 
     const uint64_t nei0 = ids->ne[0];
     const uint64_t nei1 = ids->ne[1];
+    if (nei0*nei1 > 4096) {
+        fprintf(stderr, "%s: nei0 = %d, nei1 = %d\n", __func__, (int)nei0, (int)nei1);
+    }
     GGML_ASSERT(nei0 * nei1 <= 4096);
 
     const uint32_t nbi1 = ids->nb[1];
@@ -5915,7 +5918,30 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
     if (src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
         ggml_vk_mul_mat_vec_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun);
     } else {
-        ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun);
+        // Split based on number of ids, to fit in shared memory
+        const uint32_t nei0 = (uint32_t)src2->ne[0];
+        const uint32_t nei1 = (uint32_t)src2->ne[1];
+
+        GGML_ASSERT(nei0 <= 4096);
+        const uint32_t split_size = std::min(nei1, 4096u / nei0);
+
+        ggml_tensor src1_copy = *src1;
+        ggml_tensor src2_copy = *src2;
+        ggml_tensor dst_copy = *dst;
+
+        for (uint32_t token_start = 0; token_start < nei1; token_start += split_size) {
+            const uint32_t n_tokens = std::min(split_size, nei1 - token_start);
+
+            src1_copy.view_offs = src1->view_offs + token_start * src1_copy.nb[2];
+            src2_copy.view_offs = src2->view_offs + token_start * src2_copy.nb[1];
+            dst_copy.view_offs = dst->view_offs + token_start * dst_copy.nb[2];
+
+            src1_copy.ne[2] = n_tokens;
+            src2_copy.ne[1] = n_tokens;
+            dst_copy.ne[2] = n_tokens;
+
+            ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, &src1_copy, &src2_copy, &dst_copy, dryrun);
+        }
     }
 }
 
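A minimal standalone sketch of the chunking above (illustrative code, not part of the patch; the 4096 figure and the nei0/nei1 names are taken from the diff): the id tensor provides nei0 expert ids per token for nei1 tokens, and each call to ggml_vk_mul_mat_id_q_f16 is assumed to handle at most 4096 ids, so the token dimension is walked in chunks of at most 4096 / nei0.

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>
    #include <utility>
    #include <vector>

    // Returns (token_start, n_tokens) pairs covering nei1 tokens in chunks of at
    // most 4096 / nei0 tokens, mirroring the loop added in ggml_vk_mul_mat_id.
    // nei0 is assumed to be in [1, 4096], as enforced by the GGML_ASSERT and the
    // supports_op check elsewhere in the patch.
    static std::vector<std::pair<uint32_t, uint32_t>> split_id_chunks(uint32_t nei0, uint32_t nei1) {
        const uint32_t split_size = std::min(nei1, 4096u / nei0);
        std::vector<std::pair<uint32_t, uint32_t>> chunks;
        for (uint32_t token_start = 0; token_start < nei1; token_start += split_size) {
            chunks.emplace_back(token_start, std::min(split_size, nei1 - token_start));
        }
        return chunks;
    }

    int main() {
        // e.g. 8 experts used per token, 1000 tokens -> chunks (0, 512) and (512, 488)
        for (const auto & c : split_id_chunks(8, 1000)) {
            printf("token_start = %u, n_tokens = %u\n", c.first, c.second);
        }
        return 0;
    }

Each chunk then reuses shallow copies of src1, src2 and dst with adjusted ne and view_offs, as the loop above does, so no tensor data is copied or moved.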
@@ -9510,9 +9536,15 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const
                 ggml_type src0_type = op->src[0]->type;
                 ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
                 const vk_device& device = ctx->device;
-                if (op->op == GGML_OP_MUL_MAT_ID && !device->mul_mat_id_s[src0_type] && !device->mul_mat_id_m[src0_type] && !device->mul_mat_id_l[src0_type]) {
-                    // If there's not enough shared memory for row_ids and the result tile, fallback to CPU
-                    return false;
+                if (op->op == GGML_OP_MUL_MAT_ID) {
+                    if (!device->mul_mat_id_s[src0_type] && !device->mul_mat_id_m[src0_type] && !device->mul_mat_id_l[src0_type]) {
+                        // If there's not enough shared memory for row_ids and the result tile, fallback to CPU
+                        return false;
+                    }
+                    // Check against size of shared memory variable
+                    if (op->src[2]->ne[0] > 4096) {
+                        return false;
+                    }
                 }
                 switch (src0_type) {
                     case GGML_TYPE_F32:
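The new supports_op check only needs to bound src[2]->ne[0] (the number of ids per token, nei0 above), because the split added in ggml_vk_mul_mat_id caps the tokens per dispatch at 4096 / nei0; the product that actually has to fit in the shared-memory row-id table is then asserted in ggml_vk_mul_mat_id_q_f16.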
@@ -9580,6 +9612,10 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const
                     default:
                         return false;
                 }
+                if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
+                    // different head sizes of K and V are not supported yet
+                    return false;
+                }
                 if (op->src[0]->type != GGML_TYPE_F32) {
                     return false;
                 }
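The added check compares the head sizes of K and V, which in ggml's flash-attention op are src[1] and src[2] with ne[0] holding the per-head size; when they differ, supports_op returns false so the operation falls back to another backend until the Vulkan shaders handle mismatched K/V head sizes.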