Skip to content

Commit 3c7939c

Browse files
jeffbolznv
authored and ggerganov committed
vulkan: Split large mul_mat_id to fit in shared memory (llama/14451)
1 parent 6fc80e8 commit 3c7939c

File tree

1 file changed

+33
-4
lines changed

1 file changed

+33
-4
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6003,7 +6003,30 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
60036003
if (src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
60046004
ggml_vk_mul_mat_vec_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun);
60056005
} else {
6006-
ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun);
6006+
// Split based on number of ids, to fit in shared memory
6007+
const uint32_t nei0 = (uint32_t)src2->ne[0];
6008+
const uint32_t nei1 = (uint32_t)src2->ne[1];
6009+
6010+
GGML_ASSERT(nei0 <= 4096);
6011+
const uint32_t split_size = std::min(nei1, 4096u / nei0);
6012+
6013+
ggml_tensor src1_copy = *src1;
6014+
ggml_tensor src2_copy = *src2;
6015+
ggml_tensor dst_copy = *dst;
6016+
6017+
for (uint32_t token_start = 0; token_start < nei1; token_start += split_size) {
6018+
const uint32_t n_tokens = std::min(split_size, nei1 - token_start);
6019+
6020+
src1_copy.view_offs = src1->view_offs + token_start * src1_copy.nb[2];
6021+
src2_copy.view_offs = src2->view_offs + token_start * src2_copy.nb[1];
6022+
dst_copy.view_offs = dst->view_offs + token_start * dst_copy.nb[2];
6023+
6024+
src1_copy.ne[2] = n_tokens;
6025+
src2_copy.ne[1] = n_tokens;
6026+
dst_copy.ne[2] = n_tokens;
6027+
6028+
ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, &src1_copy, &src2_copy, &dst_copy, dryrun);
6029+
}
60076030
}
60086031
}
60096032

@@ -10136,9 +10159,15 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1013610159
ggml_type src0_type = op->src[0]->type;
1013710160
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
1013810161
const vk_device& device = ggml_vk_get_device(ctx->device);
10139-
if (op->op == GGML_OP_MUL_MAT_ID && !device->mul_mat_id_s[src0_type] && !device->mul_mat_id_m[src0_type] && !device->mul_mat_id_l[src0_type]) {
10140-
// If there's not enough shared memory for row_ids and the result tile, fallback to CPU
10141-
return false;
10162+
if (op->op == GGML_OP_MUL_MAT_ID) {
10163+
if (!device->mul_mat_id_s[src0_type] && !device->mul_mat_id_m[src0_type] && !device->mul_mat_id_l[src0_type]) {
10164+
// If there's not enough shared memory for row_ids and the result tile, fallback to CPU
10165+
return false;
10166+
}
10167+
// Check against size of shared memory variable
10168+
if (op->src[2]->ne[0] > 4096) {
10169+
return false;
10170+
}
1014210171
}
1014310172
switch (src0_type) {
1014410173
case GGML_TYPE_F32:

0 commit comments

Comments
 (0)