
Commit 97cca07

vulkan: use kompute matmul for embedded
Use Kompute MAT_MUL shaders when operating with an embedded GPU.

Signed-off-by: Sergio Lopez <[email protected]>
1 parent 995e7ec commit 97cca07

2 files changed, 227 insertions(+), 7 deletions(-)

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 210 additions & 3 deletions
@@ -216,6 +216,15 @@ struct vk_device_struct {
 
     vk_pipeline pipeline_mul_mat_vec_p021_f16_f32;
     vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
+
+    vk_pipeline pipeline_emb_mul_mat_mat_f32;
+    vk_pipeline pipeline_emb_mul_mat_f16;
+    vk_pipeline pipeline_emb_mul_mat_q4_0;
+    vk_pipeline pipeline_emb_mul_mat_q4_1;
+    vk_pipeline pipeline_emb_mul_mat_q4_k;
+    vk_pipeline pipeline_emb_mul_mat_q6_k;
+    vk_pipeline pipeline_emb_mul_mat_q8_0;
+
     vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
     vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
     vk_pipeline pipeline_acc_f32;
@@ -1965,6 +1974,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32, "mul_mat_vec_p021_f16_f32", mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
 
+    ggml_vk_create_pipeline(device, device->pipeline_emb_mul_mat_mat_f32, "emb_mul_mat_mat_f32", emb_mul_mat_mat_f32_len, emb_mul_mat_mat_f32_data, "main", 3, 14 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_emb_mul_mat_f16, "emb_mul_mat_f16", emb_mul_mat_f16_len, emb_mul_mat_f16_data, "main", 3, 21 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size * 2}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_emb_mul_mat_q4_0, "emb_mul_mat_q4_0", emb_mul_mat_q4_0_len, emb_mul_mat_q4_0_data, "main", 3, 18 * sizeof(uint32_t), {1, 1, 1}, {(device->subgroup_size * 2) / 8}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_emb_mul_mat_q4_1, "emb_mul_mat_q4_1", emb_mul_mat_q4_1_len, emb_mul_mat_q4_1_data, "main", 3, 18 * sizeof(uint32_t), {1, 1, 1}, {(device->subgroup_size * 2) / 8}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_emb_mul_mat_q4_k, "emb_mul_mat_q4_k", emb_mul_mat_q4_k_len, emb_mul_mat_q4_k_data, "main", 3, 18 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_emb_mul_mat_q6_k, "emb_mul_mat_q6_k", emb_mul_mat_q6_k_len, emb_mul_mat_q6_k_data, "main", 3, 18 * sizeof(uint32_t), {1, 1, 1}, {2, device->subgroup_size}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_emb_mul_mat_q8_0, "emb_mul_mat_q8_0", emb_mul_mat_q8_0_len, emb_mul_mat_q8_0_data, "main", 3, 18 * sizeof(uint32_t), {1, 1, 1}, {(device->subgroup_size * 2) / 8}, 1);
+
     ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
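Note: the brace-initialized list near the end of each call supplies the pipeline's specialization constants, derived here from the device subgroup size (the f32 variant takes the subgroup size directly, q6_k takes {2, subgroup_size}). A standalone sketch of the arithmetic for typical subgroup sizes — illustrative only, not part of the patch:

```cpp
#include <cstdint>
#include <cstdio>
#include <initializer_list>

// Illustrative only: the specialization-constant values the emb_* pipelines
// above would receive for common subgroup sizes.
int main() {
    for (uint32_t subgroup_size : {32u, 64u}) {
        const uint32_t f16_sc = subgroup_size * 2;       // emb_mul_mat_f16
        const uint32_t q4_sc  = (subgroup_size * 2) / 8; // emb_mul_mat_q4_0/q4_1/q8_0
        std::printf("subgroup_size=%2u -> f16: %3u, q4_0/q4_1/q8_0: %2u\n",
                    subgroup_size, f16_sc, q4_sc);
    }
    return 0;
}
```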
@@ -4371,6 +4388,167 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c
     }
 }
 
+static void ggml_vkemb_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    VK_LOG_DEBUG("ggml_vkemb_mul_mat(" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
+    std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
+    std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
+    std::cerr << "), " << (dryrun ? "dryrun" : "") << ")");
+
+    const uint64_t ne00 = src0->ne[0];
+    const uint64_t ne01 = src0->ne[1];
+    const uint64_t ne02 = src0->ne[2];
+    const uint64_t ne03 = src0->ne[3];
+
+    const uint64_t ne10 = src1->ne[0];
+    const uint64_t ne11 = src1->ne[1];
+    const uint64_t ne12 = src1->ne[2];
+    const uint64_t ne13 = src1->ne[3];
+
+    const uint64_t ne0 = dst->ne[0];
+    const uint64_t ne1 = dst->ne[1];
+    const uint64_t ne2 = dst->ne[2];
+    const uint64_t ne3 = dst->ne[3];
+
+    const uint64_t nb00 = src0->nb[0];
+    const uint64_t nb01 = src0->nb[1];
+    const uint64_t nb02 = src0->nb[2];
+    const uint64_t nb03 = src0->nb[3];
+
+    const uint64_t nb10 = src1->nb[0];
+    const uint64_t nb11 = src1->nb[1];
+    const uint64_t nb12 = src1->nb[2];
+    const uint64_t nb13 = src1->nb[3];
+
+    const uint64_t nb1 = dst->nb[1];
+    const uint64_t nb2 = dst->nb[2];
+
+    const uint64_t r2 = ne12 / ne02;
+    const uint64_t r3 = ne13 / ne03;
+
+    ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
+    ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
+    ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
+
+    vk_buffer d_Qx = nullptr;
+    size_t qx_buf_offset = 0;
+    vk_buffer d_Qy = nullptr;
+    size_t qy_buf_offset = 0;
+
+    bool src0_uma = false;
+    bool src1_uma = false;
+
+    if (ctx->device->uma) {
+        ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset);
+        ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset);
+        src0_uma = d_Qx != nullptr;
+        src1_uma = d_Qy != nullptr;
+    }
+
+    const uint64_t x_ne = ne00 * ne01 * ne02 * ne03;
+    const uint64_t y_ne = ne10 * ne11 * ne12 * ne13;
+    const uint64_t d_ne = ne0 * ne1 * ne2 * ne3;
+
+    const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
+    const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
+    const uint64_t d_sz = sizeof(float) * d_ne;
+
+    vk_pipeline pipeline;
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            pipeline = ctx->device->pipeline_emb_mul_mat_mat_f32;
+            break;
+        case GGML_TYPE_F16:
+            pipeline = ctx->device->pipeline_emb_mul_mat_f16;
+            break;
+        case GGML_TYPE_Q4_0:
+            pipeline = ctx->device->pipeline_emb_mul_mat_q4_0;
+            break;
+        case GGML_TYPE_Q4_1:
+            pipeline = ctx->device->pipeline_emb_mul_mat_q4_1;
+            break;
+        case GGML_TYPE_Q4_K:
+            pipeline = ctx->device->pipeline_emb_mul_mat_q4_k;
+            break;
+        case GGML_TYPE_Q6_K:
+            pipeline = ctx->device->pipeline_emb_mul_mat_q6_k;
+            break;
+        case GGML_TYPE_Q8_0:
+            pipeline = ctx->device->pipeline_emb_mul_mat_q8_0;
+            break;
+        default:
+            GGML_ABORT("vkemb_mul_mat: unsupported quantization type: %d", src0->type);
+    }
+
+    if (dryrun) {
+        // Request descriptor sets
+        ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
+        return;
+    }
+
+    vk_buffer d_D = dst_buf_ctx->dev_buffer;
+    const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
+    GGML_ASSERT(d_D != nullptr);
+    GGML_ASSERT(d_D->size >= d_buf_offset + d_sz);
+    if (!src0_uma) {
+        d_Qx = src0_buf_ctx->dev_buffer;
+        qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs;
+        GGML_ASSERT(d_Qx != nullptr);
+    }
+    if (!src1_uma) {
+        d_Qy = src1_buf_ctx->dev_buffer;
+        qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs;
+        GGML_ASSERT(d_Qy != nullptr);
+    }
+
+    const uint64_t qx_buffer_offset = (qx_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
+    const uint64_t qx_shader_offset = qx_buf_offset - qx_buffer_offset;
+
+    const uint64_t qy_buffer_offset = (qy_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
+    const uint64_t qy_shader_offset = qy_buf_offset - qy_buffer_offset;
+
+    const uint64_t d_buffer_offset = (d_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
+    const uint64_t d_shader_offset = d_buf_offset - d_buffer_offset;
+
+    // compute
+    ggml_vk_sync_buffers(subctx);
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+        {
+            const std::array<uint32_t, 14> pc = { (uint32_t)qx_shader_offset, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)), (uint32_t)ne00, (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne11, (uint32_t)ne12, (uint32_t)nb01, (uint32_t)nb02, (uint32_t)nb11, (uint32_t)nb12, (uint32_t)nb1, (uint32_t)nb2 };
+            ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Qx, qx_buffer_offset, qx_sz + qx_shader_offset }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 14 * sizeof(uint32_t), &pc, { (uint32_t)ne01, (uint32_t)ne11, (uint32_t)std::max(ne12, ne02) });
+            break;
+        }
+        case GGML_TYPE_F16:
+        {
+            const std::array<uint32_t, 21> pc = { (uint32_t)qx_shader_offset, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)), (uint32_t)ne00, (uint32_t)ne01, (uint32_t)ne02, (uint32_t)nb00, (uint32_t)nb01, (uint32_t)nb02, (uint32_t)nb03, (uint32_t)ne10, (uint32_t)ne11, (uint32_t)ne12, (uint32_t)nb10, (uint32_t)nb11, (uint32_t)nb12, (uint32_t)nb13, (uint32_t)ne0, (uint32_t)ne1, (uint32_t)r2, (uint32_t)r3 };
+            ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Qx, qx_buffer_offset, qx_sz + qx_shader_offset }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 21 * sizeof(uint32_t), &pc, { (uint32_t)ne01, (uint32_t)((ne11 + 3) / 4), (uint32_t)(ne12 * ne13) });
+            break;
+        }
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q8_0:
+        {
+            const std::array<uint32_t, 18> pc = { (uint32_t)qx_shader_offset, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)), (uint32_t)ne00, (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne10, (uint32_t)ne12, (uint32_t)ne0, (uint32_t)ne1, (uint32_t)nb01, (uint32_t)nb02, (uint32_t)nb03, (uint32_t)nb11, (uint32_t)nb12, (uint32_t)nb13, (uint32_t)r2, (uint32_t)r3 };
+            ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Qx, qx_buffer_offset, qx_sz + qx_shader_offset }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 18 * sizeof(uint32_t), &pc, { (uint32_t)((ne01 + 7) / 8), (uint32_t)ne11, (uint32_t)(ne12 * ne13) });
+            break;
+        }
+        case GGML_TYPE_Q4_K:
+        {
+            const std::array<uint32_t, 18> pc = { (uint32_t)qx_shader_offset, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)), (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne0, (uint32_t)ne1, (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne12, (uint32_t)nb01, (uint32_t)nb02, (uint32_t)nb03, (uint32_t)nb11, (uint32_t)nb12, (uint32_t)nb13, (uint32_t)r2, (uint32_t)r3 };
+            ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Qx, qx_buffer_offset, qx_sz + qx_shader_offset }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 18 * sizeof(uint32_t), &pc, { (uint32_t)((ne01 + 3) / 4), (uint32_t)ne11, (uint32_t)(ne12 * ne13) });
            break;
+        }
+        case GGML_TYPE_Q6_K:
+        {
+            const std::array<uint32_t, 18> pc = { (uint32_t)qx_shader_offset, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)), (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne0, (uint32_t)ne1, (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne12, (uint32_t)nb01, (uint32_t)nb02, (uint32_t)nb03, (uint32_t)nb11, (uint32_t)nb12, (uint32_t)nb13, (uint32_t)r2, (uint32_t)r3 };
+            ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Qx, qx_buffer_offset, qx_sz + qx_shader_offset }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 18 * sizeof(uint32_t), &pc, { (uint32_t)((ne01 + 1) / 2), (uint32_t)ne11, (uint32_t)(ne12 * ne13) });
+            break;
+        }
+        default:
+            GGML_ABORT("vkemb_mul_mat: unsupported quantization type: %d", src0->type);
+    }
+}
+
 static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, bool dryrun = false) {
     VK_LOG_DEBUG("ggml_vk_mul_mat_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
     std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
@@ -7086,8 +7264,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
 
         break;
     case GGML_OP_MUL_MAT:
-        ggml_vk_mul_mat(ctx, compute_ctx, src0, src1, node, dryrun);
-
+        if (ctx->device->embedded) {
+            ggml_vkemb_mul_mat(ctx, compute_ctx, src0, src1, node, dryrun);
+        } else {
+            ggml_vk_mul_mat(ctx, compute_ctx, src0, src1, node, dryrun);
+        }
         break;
     case GGML_OP_MUL_MAT_ID:
         ggml_vk_mul_mat_id(ctx, compute_ctx, src0, src1, src2, node, dryrun);
@@ -8108,7 +8289,33 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
 }
 
 static bool ggml_backend_vkemb_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
-    return ggml_backend_vk_device_supports_op(dev, op);
+    switch (op->op) {
+        case GGML_OP_MUL_MAT:
+            if (op->src[1]->type != GGML_TYPE_F32 || ggml_is_transposed(op->src[0]) || ggml_is_transposed(op->src[1]))
+                return false;
+
+            switch (op->src[0]->type) {
+                case GGML_TYPE_F32:
+                    return op->ne[3] == 1;
+                case GGML_TYPE_Q8_0:
+                    // TODO (slp) - Fix Q8_0 with permutations
+                    if (ggml_is_permuted(op->src[0]) || ggml_is_permuted(op->src[1])) {
+                        return false;
+                    }
+                case GGML_TYPE_Q6_K:
+                case GGML_TYPE_F16:
+                case GGML_TYPE_Q4_0:
+                case GGML_TYPE_Q4_1:
+                case GGML_TYPE_Q4_K:
+                    return true;
+                default:
+                    return false;
+            }
+        case GGML_OP_MUL_MAT_ID:
+            return false;
+        default:
+            return ggml_backend_vk_device_supports_op(dev, op);
+    }
 }
 
 static bool ggml_backend_vk_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
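Note that the Q8_0 case above falls through deliberately: when neither operand is permuted, control reaches the shared `return true` of the cases below it. A minimal standalone illustration of the same control flow (hypothetical enum, not the real ggml type ids):

```cpp
enum class Kind { F16, Q4_0, Q8_0, Other };

// Mirrors the inner switch above: Q8_0 is rejected only when permuted;
// otherwise control falls through to the shared `return true`.
static bool inner_supports(Kind k, bool permuted) {
    switch (k) {
        case Kind::Q8_0:
            if (permuted) {
                return false;
            }
            [[fallthrough]]; // intentional, matching the diff above
        case Kind::F16:
        case Kind::Q4_0:
            return true;
        default:
            return false;
    }
}
```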

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 17 additions & 4 deletions
@@ -198,10 +198,15 @@ static uint32_t compile_count = 0;
 static std::mutex compile_count_mutex;
 static std::condition_variable compile_count_cond;
 
-void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) {
+void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false, bool is_embed = false) {
     std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat ? "_coopmat" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32"));
     std::string out_fname = join_paths(output_dir, name + ".spv");
-    std::string in_path = join_paths(input_dir, in_fname);
+    std::string in_path;
+    if (is_embed) {
+        in_path = join_paths(input_dir + "/../../ggml-kompute/kompute-shaders", in_fname);
+    } else {
+        in_path = join_paths(input_dir, in_fname);
+    }
 
     std::string target_env = (name.find("_cm2") != std::string::npos) ? "--target-env=vulkan1.3" : "--target-env=vulkan1.2";
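With is_embed set, the shader source is looked up in the Kompute shader tree rather than the Vulkan one. A standalone sketch of the resulting path, assuming join_paths is plain '/'-separated concatenation and using an illustrative input_dir (both are assumptions, not taken from this patch):

```cpp
#include <iostream>
#include <string>

// Assumed stand-in for the generator's join_paths helper.
static std::string join_paths(const std::string& a, const std::string& b) {
    return a + "/" + b;
}

int main() {
    const std::string input_dir = "ggml/src/ggml-vulkan/vulkan-shaders"; // illustrative
    const std::string in_path =
        join_paths(input_dir + "/../../ggml-kompute/kompute-shaders", "op_mul_mat_q4_0.comp");
    std::cout << in_path << "\n";
    // ggml/src/ggml-vulkan/vulkan-shaders/../../ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp
}
```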

@@ -261,7 +266,7 @@ std::map<std::string, std::string> merge_maps(const std::map<std::string, std::s
 }
 
 static std::vector<std::future<void>> compiles;
-void string_to_spv(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) {
+void string_to_spv(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false, bool is_embed = false) {
     {
         // wait until fewer than N compiles are in progress.
         // 16 is an arbitrary limit, the goal is to avoid "failed to create pipe" errors.

@@ -272,7 +277,7 @@ void string_to_spv(const std::string& _name, const std::string& in_fname, const
         }
         compile_count++;
     }
-    compiles.push_back(std::async(string_to_spv_func, _name, in_fname, defines, fp16, coopmat, coopmat2, f16acc));
+    compiles.push_back(std::async(string_to_spv_func, _name, in_fname, defines, fp16, coopmat, coopmat2, f16acc, is_embed));
 }
 
 void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool f16acc) {
@@ -490,6 +495,14 @@ void process_shaders() {
 
     string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
 
+    string_to_spv("emb_mul_mat_mat_f32", "op_mul_mat_mat_f32.comp", {}, true, false, false, false, true);
+    string_to_spv("emb_mul_mat_f16", "op_mul_mat_f16.comp", {}, true, false, false, false, true);
+    string_to_spv("emb_mul_mat_q4_0", "op_mul_mat_q4_0.comp", {}, true, false, false, false, true);
+    string_to_spv("emb_mul_mat_q4_1", "op_mul_mat_q4_1.comp", {}, true, false, false, false, true);
+    string_to_spv("emb_mul_mat_q4_k", "op_mul_mat_q4_k.comp", {}, true, false, false, false, true);
+    string_to_spv("emb_mul_mat_q6_k", "op_mul_mat_q6_k.comp", {}, true, false, false, false, true);
+    string_to_spv("emb_mul_mat_q8_0", "op_mul_mat_q8_0.comp", {}, true, false, false, false, true);
+
     for (auto &c : compiles) {
         c.wait();
     }
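Reading the trailing positional booleans against the updated signature: each emb_* call compiles with fp16 enabled, the coopmat/coopmat2/f16acc flags off, and is_embed set. A self-contained sketch with a mock of the signature, spelling out one call:

```cpp
#include <map>
#include <string>

// Mock of the generator's signature, just to make the positional booleans legible.
static void string_to_spv(const std::string& name, const std::string& in_fname,
                          const std::map<std::string, std::string>& defines,
                          bool fp16 = true, bool coopmat = false, bool coopmat2 = false,
                          bool f16acc = false, bool is_embed = false) {
    (void)name; (void)in_fname; (void)defines;
    (void)fp16; (void)coopmat; (void)coopmat2; (void)f16acc; (void)is_embed;
}

int main() {
    // Each emb_* call: fp16 on, coopmat/coopmat2/f16acc off, is_embed on.
    string_to_spv("emb_mul_mat_q4_0", "op_mul_mat_q4_0.comp", {},
                  /*fp16=*/true, /*coopmat=*/false, /*coopmat2=*/false,
                  /*f16acc=*/false, /*is_embed=*/true);
}
```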
